]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:47:47 +0000 (13:47 +0300)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:58:28 +0000 (13:58 +0300)
39 files changed:
examples/talk-llama/CMakeLists.txt
examples/talk-llama/llama-adapter.cpp
examples/talk-llama/llama-adapter.h
examples/talk-llama/llama-arch.cpp
examples/talk-llama/llama-arch.h
examples/talk-llama/llama-chat.cpp
examples/talk-llama/llama-chat.h
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-context.h
examples/talk-llama/llama-cparams.h
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-hparams.cpp
examples/talk-llama/llama-hparams.h
examples/talk-llama/llama-impl.h
examples/talk-llama/llama-kv-cache-iswa.cpp [new file with mode: 0644]
examples/talk-llama/llama-kv-cache-iswa.h [new file with mode: 0644]
examples/talk-llama/llama-kv-cache-unified-iswa.cpp [deleted file]
examples/talk-llama/llama-kv-cache-unified-iswa.h [deleted file]
examples/talk-llama/llama-kv-cache-unified.cpp [deleted file]
examples/talk-llama/llama-kv-cache-unified.h [deleted file]
examples/talk-llama/llama-kv-cache.cpp [new file with mode: 0644]
examples/talk-llama/llama-kv-cache.h
examples/talk-llama/llama-kv-cells.h
examples/talk-llama/llama-memory-hybrid.cpp
examples/talk-llama/llama-memory-hybrid.h
examples/talk-llama/llama-memory-recurrent.cpp
examples/talk-llama/llama-memory-recurrent.h
examples/talk-llama/llama-memory.h
examples/talk-llama/llama-model-loader.cpp
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-quant.cpp
examples/talk-llama/llama-sampling.cpp
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama-vocab.h
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h
examples/talk-llama/talk-llama.cpp

index 13ecced828544610db47afe8f146122d217f60b8..182114c2697c3c565bf64b1a3e5beea83ecbb6d2 100644 (file)
@@ -16,8 +16,8 @@ if (WHISPER_SDL2)
         llama-hparams.cpp
         llama-impl.cpp
         llama-io.cpp
-        llama-kv-cache-unified.cpp
-        llama-kv-cache-unified-iswa.cpp
+        llama-kv-cache.cpp
+        llama-kv-cache-iswa.cpp
         llama-memory-recurrent.cpp
         llama-memory-hybrid.cpp
         llama-memory.cpp
index 8d94034aed95debd4b1ea9269996578744ba809b..d8eef75a7ad70afee2bebb03d023c60731a50cf2 100644 (file)
@@ -6,6 +6,7 @@
 
 #include <map>
 #include <cassert>
+#include <sstream>
 #include <stdexcept>
 
 // vec
@@ -163,13 +164,38 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
 
     // check metadata
     {
+        const gguf_context * gguf_ctx = ctx_gguf.get();
+
+        LLAMA_LOG_INFO("%s: Dumping metadata keys/values.\n", __func__);
+
+        // get metadata as string
+        for (int i = 0; i < gguf_get_n_kv(gguf_ctx); i++) {
+            gguf_type type = gguf_get_kv_type(gguf_ctx, i);
+            const std::string type_name =
+                type == GGUF_TYPE_ARRAY
+                ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(gguf_ctx, i)), gguf_get_arr_n(gguf_ctx, i))
+                : gguf_type_name(type);
+            const char * name = gguf_get_key(gguf_ctx, i);
+            const std::string value = gguf_kv_to_str(gguf_ctx, i);
+
+            if (type != GGUF_TYPE_ARRAY) {
+                adapter.gguf_kv.emplace(name, value);
+            }
+
+            const size_t MAX_VALUE_LEN = 40;
+            std::string print_value = value.size() > MAX_VALUE_LEN ? format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()) : value;
+            replace_all(print_value, "\n", "\\n");
+
+            LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), print_value.c_str());
+        }
+
         auto get_kv_str = [&](const std::string & key) -> std::string {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id));
+            int id = gguf_find_key(gguf_ctx, key.c_str());
+            return id < 0 ? "" : std::string(gguf_get_val_str(gguf_ctx, id));
         };
         auto get_kv_f32 = [&](const std::string & key) -> float {
-            int id = gguf_find_key(ctx_gguf.get(), key.c_str());
-            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id);
+            int id = gguf_find_key(gguf_ctx, key.c_str());
+            return id < 0 ? 0.0f : gguf_get_val_f32(gguf_ctx, id);
         };
         LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
@@ -190,6 +216,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
         }
 
         adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
+
+        // parse alora invocation sequence vector
+        const auto & key = llm_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS);
+        const int kid = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (kid >= 0) {
+            if (gguf_get_kv_type(ctx_gguf.get(), kid) != GGUF_TYPE_ARRAY) {
+                throw std::runtime_error("invalid gguf type for " + key);
+            }
+            const auto arr_type = gguf_get_arr_type(ctx_gguf.get(), kid);
+            if (arr_type != GGUF_TYPE_UINT32) {
+                throw std::runtime_error("invalid gguf element type for " + key);
+            }
+            const size_t seq_len = gguf_get_arr_n(ctx_gguf.get(), kid);
+            const void * data = gguf_get_arr_data(ctx_gguf.get(), kid);
+            adapter.alora_invocation_tokens.resize(seq_len);
+            std::copy(
+                (const llama_token *)data,
+                (const llama_token *)data + seq_len,
+                adapter.alora_invocation_tokens.begin());
+        }
     }
 
     int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
@@ -383,6 +429,57 @@ llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * p
     return nullptr;
 }
 
+int32_t llama_adapter_meta_val_str(const llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size) {
+    const auto & it = adapter->gguf_kv.find(key);
+    if (it == adapter->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_adapter_meta_count(const llama_adapter_lora * adapter) {
+    return (int)adapter->gguf_kv.size();
+}
+
+int32_t llama_adapter_meta_key_by_index(const llama_adapter_lora * adapter, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = adapter->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)adapter->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = adapter->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
+
+uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {
+    if (!adapter) {
+        return 0;
+    }
+    return adapter->alora_invocation_tokens.size();
+}
+
+const llama_token * llama_adapter_get_alora_invocation_tokens(const llama_adapter_lora * adapter) {
+    GGML_ASSERT(adapter);
+    return adapter->alora_invocation_tokens.data();
+}
index 65824e972765bd5b8b845e0c41ec46dc81657a4f..4f65247c0feb1d215e08e6c33f3d0d57bf19c8b8 100644 (file)
@@ -67,6 +67,12 @@ struct llama_adapter_lora {
 
     float alpha;
 
+    // gguf metadata
+    std::unordered_map<std::string, std::string> gguf_kv;
+
+    // activated lora (aLoRA)
+    std::vector<llama_token> alora_invocation_tokens;
+
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
 
index 18dcc6ddfe56714ba4b2223c0681961447d365e5..a4d2973ada5dc9a8289eb236b8a79373312b4f09 100644 (file)
@@ -22,6 +22,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_NOMIC_BERT_MOE,   "nomic-bert-moe"   },
     { LLM_ARCH_NEO_BERT,         "neo-bert"         },
     { LLM_ARCH_JINA_BERT_V2,     "jina-bert-v2"     },
+    { LLM_ARCH_JINA_BERT_V3,     "jina-bert-v3"     },
     { LLM_ARCH_BLOOM,            "bloom"            },
     { LLM_ARCH_STABLELM,         "stablelm"         },
     { LLM_ARCH_QWEN,             "qwen"             },
@@ -44,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA2,           "gemma2"           },
     { LLM_ARCH_GEMMA3,           "gemma3"           },
     { LLM_ARCH_GEMMA3N,          "gemma3n"          },
+    { LLM_ARCH_GEMMA_EMBEDDING,  "gemma-embedding"  },
     { LLM_ARCH_STARCODER2,       "starcoder2"       },
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_MAMBA2,           "mamba2"           },
@@ -68,6 +70,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_T5ENCODER,        "t5encoder"        },
     { LLM_ARCH_JAIS,             "jais"             },
     { LLM_ARCH_NEMOTRON,         "nemotron"         },
+    { LLM_ARCH_NEMOTRON_H,       "nemotron_h"       },
     { LLM_ARCH_EXAONE,           "exaone"           },
     { LLM_ARCH_EXAONE4,          "exaone4"          },
     { LLM_ARCH_RWKV6,            "rwkv6"            },
@@ -93,6 +96,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DREAM,            "dream"            },
     { LLM_ARCH_SMALLTHINKER,     "smallthinker"     },
     { LLM_ARCH_LLADA,            "llada"            },
+    { LLM_ARCH_LLADA_MOE,        "llada-moe"        },
+    { LLM_ARCH_SEED_OSS,         "seed_oss"         },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -133,7 +138,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
+    { LLM_KV_DECODER_BLOCK_COUNT,               "%s.decoder_block_count"               },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            },
+    { LLM_KV_ROUTER_LOGIT_SOFTCAPPING,          "%s.router_logit_softcapping"          },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           },
     { LLM_KV_SWIN_NORM,                         "%s.swin_norm"                         },
     { LLM_KV_RESCALE_EVERY_N_LAYERS,            "%s.rescale_every_n_layers"            },
@@ -164,19 +171,25 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,       "%s.attention.relative_buckets_count"       },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
+    { LLM_KV_ATTENTION_OUTPUT_SCALE,                 "%s.attention.output_scale"                 },
+    { LLM_KV_ATTENTION_TEMPERATURE_LENGTH,           "%s.attention.temperature_length"           },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
 
-    { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
-    { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-    { LLM_KV_ROPE_FREQ_BASE,            "%s.rope.freq_base"                       },
-    { LLM_KV_ROPE_SCALE_LINEAR,         "%s.rope.scale_linear"                    },
-    { LLM_KV_ROPE_SCALING_TYPE,         "%s.rope.scaling.type"                    },
-    { LLM_KV_ROPE_SCALING_FACTOR,       "%s.rope.scaling.factor"                  },
-    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,  "%s.rope.scaling.attn_factor"             },
-    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
-    { LLM_KV_ROPE_SCALING_FINETUNED,    "%s.rope.scaling.finetuned"               },
-    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier"     },
+    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
+    { LLM_KV_ROPE_DIMENSION_SECTIONS,       "%s.rope.dimension_sections"              },
+    { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
+    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    },
+    { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    },
+    { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  },
+    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,      "%s.rope.scaling.attn_factor"             },
+    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" },
+    { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned"               },
+    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL,     "%s.rope.scaling.yarn_log_multiplier"     },
+    { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,  "%s.rope.scaling.yarn_ext_factor"         },
+    { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor"        },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   "%s.rope.scaling.yarn_beta_fast"          },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   "%s.rope.scaling.yarn_beta_slow"          },
 
     { LLM_KV_SPLIT_NO,            "split.no"            },
     { LLM_KV_SPLIT_COUNT,         "split.count"         },
@@ -233,8 +246,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_REP_ID,           "tokenizer.ggml.fim_rep_token_id"         },
     { LLM_KV_TOKENIZER_FIM_SEP_ID,           "tokenizer.ggml.fim_sep_token_id"         },
 
-    { LLM_KV_ADAPTER_TYPE,       "adapter.type"       },
-    { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
+    { LLM_KV_ADAPTER_TYPE,                    "adapter.type"               },
+    { LLM_KV_ADAPTER_LORA_ALPHA,              "adapter.lora.alpha"         },
+    { LLM_KV_ADAPTER_LORA_TASK_NAME,          "adapter.lora.task_name"     },
+    { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,      "adapter.lora.prompt_prefix" },
+    { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },
 
     // deprecated
     { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
@@ -390,12 +406,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
             { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
             { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
             { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
             { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
             { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
             { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
             { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
             { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
         },
@@ -574,6 +594,20 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CLS,             "cls" },
         },
     },
+    {
+        LLM_ARCH_JINA_BERT_V3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_BLOOM,
         {
@@ -1019,6 +1053,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_LAUREL_POST_NORM,     "blk.%d.laurel_post_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA_EMBEDDING,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -1532,6 +1587,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_NEMOTRON_H,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            // mamba(2) ssm layers
+            { LLM_TENSOR_SSM_IN,         "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D,     "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT,         "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A,          "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D,          "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM,       "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT,        "blk.%d.ssm_out" },
+            // attention layers
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            // dense FFN
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_EXAONE,
         {
@@ -2010,6 +2090,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
             { LLM_TENSOR_TOKEN_EMBD,        "token_embd" },
             { LLM_TENSOR_TOKEN_EMBD_NORM,   "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT,            "output" },
         }
     },
     {
@@ -2067,6 +2148,43 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_LLADA_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+        },
+    },
+    {
+        LLM_ARCH_SEED_OSS,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2319,6 +2437,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
         case LLM_ARCH_PLAMO2:
         case LLM_ARCH_GRANITE_HYBRID:
         case LLM_ARCH_LFM2:
+        case LLM_ARCH_NEMOTRON_H:
             return true;
         default:
             return false;
@@ -2329,6 +2448,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
+        case LLM_ARCH_LLADA_MOE:
             return true;
         default:
             return false;
index 7af587e7951bcf65197ad9d9a08aa9f6a060e352..d181ce6784ffb9469dcdf1795ec52705320dcf36 100644 (file)
@@ -26,6 +26,7 @@ enum llm_arch {
     LLM_ARCH_NOMIC_BERT_MOE,
     LLM_ARCH_NEO_BERT,
     LLM_ARCH_JINA_BERT_V2,
+    LLM_ARCH_JINA_BERT_V3,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
@@ -48,6 +49,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
     LLM_ARCH_GEMMA3N,
+    LLM_ARCH_GEMMA_EMBEDDING,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_MAMBA2,
@@ -72,6 +74,7 @@ enum llm_arch {
     LLM_ARCH_T5ENCODER,
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
+    LLM_ARCH_NEMOTRON_H,
     LLM_ARCH_EXAONE,
     LLM_ARCH_EXAONE4,
     LLM_ARCH_RWKV6,
@@ -97,6 +100,8 @@ enum llm_arch {
     LLM_ARCH_DREAM,
     LLM_ARCH_SMALLTHINKER,
     LLM_ARCH_LLADA,
+    LLM_ARCH_LLADA_MOE,
+    LLM_ARCH_SEED_OSS,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -137,7 +142,9 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_DECODER_BLOCK_COUNT,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_SWIN_NORM,
     LLM_KV_RESCALE_EVERY_N_LAYERS,
@@ -168,6 +175,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_OUTPUT_SCALE,
+    LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
@@ -181,6 +190,10 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
     LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
+    LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
+    LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,
 
     LLM_KV_SPLIT_NO,
     LLM_KV_SPLIT_COUNT,
@@ -229,6 +242,9 @@ enum llm_kv {
 
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
+    LLM_KV_ADAPTER_LORA_TASK_NAME,
+    LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,
+    LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS,
 
     LLM_KV_POSNET_EMBEDDING_LENGTH,
     LLM_KV_POSNET_BLOCK_COUNT,
index 0a96a9a579e26773fc5b7290d6d5d41a44b0b6b2..66e6c6a38f1cd4ac29080aa41ea8fdd56dc4e494 100644 (file)
 static std::string trim(const std::string & str) {
     size_t start = 0;
     size_t end = str.size();
-    while (start < end && isspace(str[start])) {
+    while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
         start += 1;
     }
-    while (end > start && isspace(str[end - 1])) {
+    while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
         end -= 1;
     }
     return str.substr(start, end - start);
@@ -69,6 +69,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "gpt-oss",           LLM_CHAT_TEMPLATE_OPENAI_MOE        },
     { "hunyuan-dense",     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE     },
     { "kimi-k2",           LLM_CHAT_TEMPLATE_KIMI_K2           },
+    { "seed_oss",          LLM_CHAT_TEMPLATE_SEED_OSS          },
+    { "grok-2",            LLM_CHAT_TEMPLATE_GROK_2            },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -201,6 +203,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
     } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
+    } else if (tmpl_contains("<seed:bos>")) {
+        return LLM_CHAT_TEMPLATE_SEED_OSS;
+    } else if (tmpl_contains("'Assistant: '  + message['content'] + '<|separator|>")) {
+        return LLM_CHAT_TEMPLATE_GROK_2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -752,6 +758,28 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|im_assistant|>assistant<|im_middle|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) {
+        for (auto message: chat) {
+            std::string role(message->role);
+            ss << "<seed:bos>" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << "<seed:eos>";
+        }
+        if (add_ass) {
+            ss << "<seed:bos>assistant\n";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "System: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "user") {
+                ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << "<|separator|>\n\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
index 35a943856fa528dd81987299ae5d83b9456e5a60..5a87d9ab627bcccd214c41b4c3260449357c1d7b 100644 (file)
@@ -49,6 +49,8 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_OPENAI_MOE,
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
+    LLM_CHAT_TEMPLATE_SEED_OSS,
+    LLM_CHAT_TEMPLATE_GROK_2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
index 7d7abad5d4a2dd94ce53038c1ecd5e6b2ed3b2a4..e6f76421cf1319702d476e751db6e45b3ca498ae 100644 (file)
@@ -35,14 +35,12 @@ llama_context::llama_context(
 
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
-    cparams.defrag_thold     = params.defrag_thold;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor  >= 0.0f ? params.yarn_ext_factor  : hparams.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast   >= 0.0f ? params.yarn_beta_fast   : hparams.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow   >= 0.0f ? params.yarn_beta_slow   : hparams.yarn_beta_slow;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
-    cparams.flash_attn       = params.flash_attn;
     cparams.no_perf          = params.no_perf;
     cparams.pooling_type     = params.pooling_type;
     cparams.warmup           = false;
@@ -87,13 +85,15 @@ llama_context::llama_context(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
+    cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
     // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
     // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
     // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory
     if (cparams.n_batch < GGML_KQ_MASK_PAD) {
         LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
         cparams.n_batch = GGML_KQ_MASK_PAD;
@@ -103,16 +103,6 @@ llama_context::llama_context(
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
 
-    {
-        const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-        supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
-
-        if (!supports_set_rows && !cparams.kv_unified) {
-            LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
-            cparams.kv_unified = true;
-        }
-    }
-
     {
         const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
         graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@@ -130,7 +120,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
-    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: flash_attn    = %s\n",   __func__, llama_flash_attn_type_name(params.flash_attn_type));
     LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
@@ -145,11 +135,6 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
-        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
-                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
-    }
-
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -196,7 +181,7 @@ llama_context::llama_context(
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs
-            if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
+            if (output_reserve(params.n_seq_max) < params.n_seq_max) {
                 throw std::runtime_error("failed to reserve initial output buffer");
             }
 
@@ -285,28 +270,75 @@ llama_context::llama_context(
         }
     }
 
-    // reserve worst-case graph
-    if (!hparams.vocab_only && memory) {
+    if (!hparams.vocab_only) {
+        llama_memory_context_ptr mctx;
+        if (memory) {
+            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+            mctx = memory->init_full();
+            if (!mctx) {
+                throw std::runtime_error("failed to initialize memory module");
+            }
+        }
+
+        cross.v_embd.clear();
+
         const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
+        // avoid reserving graphs with zero outputs - assume one output per sequence
+        n_outputs = n_seqs;
+
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+            bool fa_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                    continue;
+                }
+                ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_fa != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                        "is assigned to device %s (usually due to missing support)\n",
+                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                    fa_device_mismatch = true;
+                    break;
+                }
+            }
+            if (fa_device_mismatch) {
+                cparams.flash_attn = false;
+                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                if (ggml_is_quantized(params.type_v)) {
+                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                }
+            } else {
+                cparams.flash_attn = true;
+                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+            }
+        }
+
+        // reserve worst-case graph
         int n_splits_pp = -1;
         int n_nodes_pp  = -1;
 
         int n_splits_tg = -1;
         int n_nodes_tg  = -1;
 
-        // simulate full KV cache
-
-        const auto mctx = memory->init_full();
-        if (!mctx) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
-
-        cross.v_embd.clear();
-
         // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -444,26 +476,12 @@ llama_memory_t llama_context::get_memory() const {
     return memory.get();
 }
 
-// deprecated
-void llama_context::kv_self_defrag_sched() {
-    if (!memory) {
-        return;
-    }
-
-    memory_force_optimize = true;
-}
-
-// deprecated
-bool llama_context::kv_self_update(bool optimize) {
+bool llama_context::memory_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     {
-        // TODO: remove in the future
-        optimize |= memory_force_optimize;
-        memory_force_optimize = false;
-
         const auto mctx = memory->init_update(this, optimize);
         switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
@@ -908,12 +926,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
-    if (!supports_set_rows) {
-        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-        // overlap with device computation.
-        ggml_backend_sched_reset(sched.get());
-    }
-
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
         //cross.t_embd = t_embd;
@@ -997,8 +1009,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     bool did_optimize = false;
 
-    // handle any pending defrags/shifts
-    kv_self_update(false);
+    // handle any pending shifts/copies
+    memory_update(false);
 
     llama_memory_context_ptr mctx;
 
@@ -1023,7 +1035,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     if (!did_optimize) {
                         did_optimize = true;
 
-                        if (kv_self_update(true)) {
+                        if (memory_update(true)) {
                             LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
 
                             continue;
@@ -1076,7 +1088,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
-            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
+            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
             llama_pos pos_min[LLAMA_MAX_SEQ];
             for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
@@ -1093,7 +1105,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     continue;
                 }
 
-                LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
+                LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
 
                 memory->seq_rm(s, pos_min[s], -1);
             }
@@ -1244,12 +1256,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    if (!supports_set_rows) {
-        // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-        // overlap with device computation.
-        ggml_backend_sched_reset(sched.get());
-    }
-
     return 0;
 }
 
@@ -1363,8 +1369,9 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
+    GGML_ASSERT(n_outputs >= 1);
 
     if (n_tokens % n_seqs != 0) {
         n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
@@ -1398,7 +1405,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+    if (split_only) {
+        ggml_backend_sched_split_graph(sched.get(), gf);
+    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         return nullptr;
     }
@@ -1438,7 +1447,9 @@ ggml_status llama_context::graph_compute(
     if (backend_cpu != nullptr) {
         auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
         auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
-        set_threadpool_fn(backend_cpu, tp);
+        if (set_threadpool_fn) {
+            set_threadpool_fn(backend_cpu, tp);
+        }
     }
 
     // set the number of threads for all the backends
@@ -1877,7 +1888,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     }
 
     if (memory != nullptr) {
-        LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+        LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
         memory->state_write(io);
     }
 
@@ -1963,7 +1974,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     }
 
     if (memory) {
-        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+        LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
 
         memory->state_read(io);
     }
@@ -2248,12 +2259,13 @@ llama_context_params llama_context_default_params() {
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+        /*.flash_attn_type             =*/ LLAMA_FLASH_ATTN_TYPE_AUTO,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
-        /*.yarn_attn_factor            =*/ 1.0f,
-        /*.yarn_beta_fast              =*/ 32.0f,
-        /*.yarn_beta_slow              =*/ 1.0f,
+        /*.yarn_attn_factor            =*/ -1.0f,
+        /*.yarn_beta_fast              =*/ -1.0f,
+        /*.yarn_beta_slow              =*/ -1.0f,
         /*.yarn_orig_ctx               =*/ 0,
         /*.defrag_thold                =*/ -1.0f,
         /*.cb_eval                     =*/ nullptr,
@@ -2264,7 +2276,6 @@ llama_context_params llama_context_default_params() {
         /*.abort_callback_data         =*/ nullptr,
         /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
-        /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
         /*.swa_full                    =*/ true,
@@ -2292,12 +2303,30 @@ llama_context * llama_init_from_model(
         return nullptr;
     }
 
-    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+    if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) {
         LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        params.flash_attn = false;
+        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
+    }
+
+    if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
+        const uint32_t blck_size = ggml_blck_size(params.type_k);
+        if (model->hparams.n_embd_head_k % blck_size != 0) {
+            LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+                __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k);
+            return nullptr;
+        }
+    }
+
+    if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
+        const uint32_t blck_size = ggml_blck_size(params.type_v);
+        if (model->hparams.n_embd_head_v % blck_size != 0) {
+            LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n",
+                __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v);
+            return nullptr;
+        }
     }
 
-    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+    if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }
@@ -2343,16 +2372,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
-// deprecated
-llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
-}
-
-// deprecated
-void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update(false);
-}
-
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
     return ctx->pooling_type();
 }
@@ -2570,168 +2589,6 @@ bool llama_memory_can_shift(llama_memory_t mem) {
     return mem->get_can_shift();
 }
 
-//
-// kv cache
-//
-
-// deprecated
-int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return 0;
-    }
-
-    int32_t res = 0;
-
-    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
-        const llama_pos p0 = kv->seq_pos_min(s);
-        const llama_pos p1 = kv->seq_pos_max(s);
-
-        if (p0 >= 0) {
-            res += (p1 - p0) + 1;
-        }
-    }
-
-    return res;
-}
-
-// deprecated
-// note: this is the same as above - will be removed anyway, so it's ok
-int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return 0;
-    }
-
-    int32_t res = 0;
-
-    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
-        const llama_pos p0 = kv->seq_pos_min(s);
-        const llama_pos p1 = kv->seq_pos_max(s);
-
-        if (p0 >= 0) {
-            res += (p1 - p0) + 1;
-        }
-    }
-
-    return res;
-}
-
-// deprecated
-void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_clear(kv, true);
-}
-
-// deprecated
-bool llama_kv_self_seq_rm(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return true;
-    }
-
-    return llama_memory_seq_rm(kv, seq_id, p0, p1);
-}
-
-// deprecated
-void llama_kv_self_seq_cp(
-        llama_context * ctx,
-         llama_seq_id   seq_id_src,
-         llama_seq_id   seq_id_dst,
-            llama_pos   p0,
-            llama_pos   p1) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
-}
-
-// deprecated
-void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_seq_keep(kv, seq_id);
-}
-
-// deprecated
-void llama_kv_self_seq_add(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1,
-            llama_pos   delta) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
-}
-
-// deprecated
-void llama_kv_self_seq_div(
-        llama_context * ctx,
-         llama_seq_id   seq_id,
-            llama_pos   p0,
-            llama_pos   p1,
-                  int   d) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return;
-    }
-
-    llama_memory_seq_div(kv, seq_id, p0, p1, d);
-}
-
-// deprecated
-llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return -1;
-    }
-
-    return llama_memory_seq_pos_min(kv, seq_id);
-}
-
-// deprecated
-llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return -1;
-    }
-
-    return llama_memory_seq_pos_max(kv, seq_id);
-}
-
-// deprecated
-void llama_kv_self_defrag(llama_context * ctx) {
-    // force defrag
-    ctx->kv_self_defrag_sched();
-}
-
-// deprecated
-bool llama_kv_self_can_shift(const llama_context * ctx) {
-    auto * kv = llama_get_memory(ctx);
-    if (!kv) {
-        return false;
-    }
-
-    return llama_memory_can_shift(kv);
-}
-
 // llama state API
 
 // deprecated
index 230ef8962b8fa41113f0c38d94798a1774a05f35..f23aa8ee1368dae4b40218537385b4a4e56851ec 100644 (file)
@@ -46,10 +46,8 @@ struct llama_context {
 
     llama_memory_t get_memory() const;
 
-    // return true of the KV cache was updated
-    // TODO: remove
-    bool kv_self_update(bool optimize);
-    void kv_self_defrag_sched();
+    // return true if the memory was updated
+    bool memory_update(bool optimize);
 
     enum llama_pooling_type pooling_type() const;
 
@@ -198,7 +196,7 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
 private:
     llm_graph_params graph_params(
@@ -230,9 +228,6 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
-    // TODO: temporary, until the llama_kv_self_defrag() API is removed
-    bool memory_force_optimize = false;
-
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;
@@ -288,10 +283,6 @@ private:
 
     bool has_evaluated_once = false;
 
-    // env: LLAMA_SET_ROWS (temporary)
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    bool supports_set_rows = true;
-
     // env: LLAMA_GRAPH_REUSE_DISABLE
     bool graph_reuse_disable = false;
 
index 38750affc500b74504f53bf65320332ddcef0d09..eae7b839f4857da2df56be6cd8d4d9a5fe362b7a 100644 (file)
@@ -4,7 +4,7 @@
 
 #include <cstdint>
 
-#define LLAMA_MAX_SEQ 64
+#define LLAMA_MAX_SEQ 256
 
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
@@ -24,7 +24,6 @@ struct llama_cparams {
     float yarn_attn_factor;
     float yarn_beta_fast;
     float yarn_beta_slow;
-    float defrag_thold;
 
     bool embeddings;
     bool causal_attn;
index 053c72d6dc8d187087865a1b18792b4dade405f9..9f2e417f1ff4b19be80e6371ff2048b7bad29c7c 100644 (file)
@@ -4,8 +4,8 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 
-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
 
@@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
+    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE) ? "LLAMA_SWA_TYPE_NONE" :
+                          (swa_type == LLAMA_SWA_TYPE_STANDARD) ? "LLAMA_SWA_TYPE_STANDARD" :
+                          (swa_type == LLAMA_SWA_TYPE_CHUNKED) ? "LLAMA_SWA_TYPE_CHUNKED" :
+                          (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
+    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
+    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
+    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
+
+    LLAMA_LOG_DEBUG("    ");
+    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+        LLAMA_LOG_DEBUG("%2d", j);
+    }
+    LLAMA_LOG_DEBUG("\n");
+
+    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
+        LLAMA_LOG_DEBUG(" %2d ", i);
+        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
+            float val = data[i * n_kv + j];
+            if (val == -INFINITY) {
+                LLAMA_LOG_DEBUG(" âˆž");
+            } else {
+                LLAMA_LOG_DEBUG(" 0");
+            }
+        }
+        LLAMA_LOG_DEBUG("\n");
+    }
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
@@ -267,6 +297,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
     float * data = (float *) kq_mask->data;
 
+    // [TAG_NO_CACHE_ISWA]
+    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+
     for (int h = 0; h < 1; ++h) {
         for (int i1 = 0; i1 < n_tokens; ++i1) {
             const llama_seq_id s1 = ubatch->seq_id[i1][0];
@@ -277,32 +310,44 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
+                    if (s0 != s1) {
+                        continue; // skip different sequences
+                    }
+
+                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
+                        continue; // skip future tokens for causal attention
+                    }
+
+                    // TODO: this does not take into account that some layers are SWA and others are note (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                    //    continue; // skip masked tokens for SWA
+                    //}
+
                     // TODO: reimplement this like in llama_kv_cache_unified
-                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
-                        } else {
-                            f = 0.0f;
-                        }
-                        break;
+                    if (hparams.use_alibi) {
+                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                    } else {
+                        f = 0.0f;
                     }
                 }
-
                 data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
             }
         }
     }
+    if (debug) {
+        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
+    }
 }
 
-void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->set_input_v_idxs(self_v_idxs, ubatch);
 
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -314,12 +359,10 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
     res &= self_kq_mask->ne[0] == mctx->get_n_kv();
     res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
 
-    res &= mctx->get_supports_set_rows(); // TODO: tmp
-
     return res;
 }
 
-void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
@@ -331,8 +374,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -350,8 +393,6 @@ bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & pa
     res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
     res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
 
-    res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
-
     return res;
 }
 
@@ -1186,7 +1227,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
@@ -1223,15 +1264,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_tensor * v,
          ggml_tensor * kq_b,
          ggml_tensor * kq_mask,
-         ggml_tensor * v_mla,
          ggml_tensor * sinks,
-             float     kq_scale) const {
+         ggml_tensor * v_mla,
+               float   kq_scale,
+                 int   il) const {
     const bool v_trans = v->nb[1] > v->nb[2];
 
     // split the batch into streams if needed
     const auto n_stream = k->ne[3];
 
-    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
+    q = ggml_view_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream, q->nb[1], q->nb[2], q->nb[3]/n_stream, 0);
 
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
@@ -1260,6 +1302,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+        cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
 
         ggml_flash_attn_ext_add_sinks(cur, sinks);
         ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
@@ -1275,6 +1318,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             // The permutations are noops and only change how the tensor data is interpreted.
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             cur = ggml_mul_mat(ctx0, v_mla, cur);
+            cb(cur, "fattn_mla", il);
             cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
 #endif
@@ -1283,6 +1327,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+        cb(kq, "kq", il);
 
         // note: this op tends to require high floating point range
         //       while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1290,38 +1335,48 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
+            // multiply by attn_output_multiplier
             // and then :
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx0, kq, 30);
+            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
+            cb(kq, "kq_tanh", il);
+            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+            cb(kq, "kq_scaled", il);
         }
 
         if (hparams.attn_soft_cap) {
             kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            cb(kq, "kq_scaled_1", il);
             kq = ggml_tanh (ctx0, kq);
+            cb(kq, "kq_tanh", il);
             kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+            cb(kq, "kq_scaled_2", il);
         }
 
         if (kq_b) {
             kq = ggml_add(ctx0, kq, kq_b);
+            cb(kq, "kq_plus_kq_b", il);
         }
 
         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
         ggml_soft_max_add_sinks(kq, sinks);
+        cb(kq, "kq_soft_max", il);
 
         if (!v_trans) {
             // note: avoid this branch
             v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+            cb(v, "v_cont", il);
         }
 
         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+        cb(kqv, "kqv", il);
 
         // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
         if (v_mla) {
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+            cb(kqv, "kqv_mla", il);
         }
 
         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
@@ -1360,6 +1415,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
             float     kq_scale,
             int       il) const {
@@ -1375,13 +1431,14 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
-    assert(!ubatch.equal_seqs());
+    //       but it might not be worth it: https://github.com/ggml-org/llama.cpp/pull/15636
+    //assert(!ubatch.equal_seqs() || (k_cur->ne[3] == 1 && k_cur->ne[3] == ubatch.n_seqs_unq));
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1399,17 +1456,17 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
            ggml_context * ctx0,
      const llama_ubatch & ubatch,
     const llama_hparams & hparams,
     const llama_cparams & cparams,
-    const llama_kv_cache_unified_context * mctx_cur) {
+    const llama_kv_cache_context * mctx_cur) {
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
 
     {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
 
         const auto n_kv     = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
@@ -1427,22 +1484,23 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
     return inp;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
-    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified * inp,
+        llm_graph_input_attn_kv * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
             float     kq_scale,
             int       il) const {
@@ -1469,7 +1527,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1488,40 +1546,15 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified_iswa * inp,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-            float     kq_scale,
-            int       il) const {
-    return build_attn_with_sinks(
-            inp,
-            wo,
-            wo_b,
-            q_cur,
-            k_cur,
-            v_cur,
-            kq_b,
-            v_mla,
-            nullptr,
-            kq_scale,
-            il);
-}
-
-ggml_tensor * llm_graph_context::build_attn_with_sinks(
-        llm_graph_input_attn_kv_unified_iswa * inp,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
         ggml_tensor * sinks,
+        ggml_tensor * v_mla,
             float     kq_scale,
             int       il) const {
     // these nodes are added to the graph together so that they are not reordered
@@ -1561,7 +1594,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks(
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1600,6 +1633,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * sinks,
         ggml_tensor * v_mla,
             float     kq_scale,
             int       il) const {
@@ -1615,7 +1649,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1636,10 +1670,10 @@ ggml_tensor * llm_graph_context::build_attn(
 // TODO: maybe separate the inner implementation into a separate function
 //       like with the non-sliding window equivalent
 //       once sliding-window hybrid caches are a thing.
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
 
     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
@@ -1656,7 +1690,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     }
 
     {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
@@ -1669,7 +1703,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1792,7 +1826,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
-    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
 
index 6ff49de3a1ce848ac7312bb32856bd0f4a73062f..ca90fdf613f6de0074e465ca414e7b80c070e85a 100644 (file)
@@ -19,8 +19,8 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
+class llama_kv_cache_context;
+class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
 
@@ -78,6 +78,11 @@ struct llm_graph_params;
 
 class llm_graph_input_i {
 public:
+    llm_graph_input_i() {
+        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
+        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
+    }
+
     virtual ~llm_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
@@ -90,6 +95,9 @@ public:
         GGML_UNUSED(params);
         return false;
     }
+protected:
+    // env: LLAMA_GRAPH_INPUT_DEBUG
+    int debug = 0;
 };
 
 using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
@@ -152,7 +160,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +169,7 @@ public:
 
     const llama_hparams hparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +265,17 @@ public:
     const llama_cparams cparams;
 };
 
-class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+class llm_graph_input_attn_kv : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified(
+    llm_graph_input_attn_kv(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    ~llm_graph_input_attn_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -290,20 +298,20 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
-class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified_iswa(
+    llm_graph_input_attn_kv_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified_iswa() = default;
+    ~llm_graph_input_attn_kv_iswa() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -330,7 +338,7 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +359,7 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
             std::unique_ptr<llm_graph_input_rs>              inp_rs,
             const llama_memory_hybrid_context *              mctx) :
         inp_attn(std::move(inp_attn)),
@@ -361,11 +369,11 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
-    std::unique_ptr<llm_graph_input_rs>              inp_rs;
+    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
-    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
-    llm_graph_input_rs              * get_recr() const { return inp_rs.get(); }
+    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -680,14 +688,15 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_attn_mha(
-             ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
-             ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
-             ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
-             ggml_tensor * kq_b,
-             ggml_tensor * kq_mask,
-             ggml_tensor * sinks,
-             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                   float   kq_scale) const;
+            ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            ggml_tensor * sinks,   // [n_head_q]
+            ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+                  float   kq_scale,
+                    int   il) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
 
@@ -699,50 +708,39 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+    llm_graph_input_attn_kv * build_attn_inp_kv() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                  float   kq_scale,
-                    int   il) const;
-
-    // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
-    ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_unified_iswa * inp,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             ggml_tensor * sinks, // [n_head_q]
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
 
@@ -756,6 +754,7 @@ struct llm_graph_context {
             ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;
@@ -765,7 +764,7 @@ struct llm_graph_context {
     //
 
     // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //         `llama_memory_recurrent`
index 7a06368dcda68e1f133fae49f59008db4fb8af07..c04ac58f1af4ba3746c97ddefcfd723c390ad027 100644 (file)
@@ -1,6 +1,7 @@
 #include "llama-hparams.h"
 
 #include "ggml.h"
+#include <cassert>
 
 void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
     if (dense_first) {
@@ -153,3 +154,64 @@ bool llama_hparams::is_swa(uint32_t il) const {
 
     GGML_ABORT("fatal error");
 }
+
+bool llama_hparams::has_kv(uint32_t il) const {
+    if (n_layer_kv_from_start >= 0) {
+        if (il < (uint32_t) n_layer_kv_from_start) {
+            return true;
+        }
+
+        return false;
+    }
+
+    // by default, all layers have kv
+    return true;
+}
+
+uint32_t llama_hparams::n_layer_kv() const {
+    uint32_t res = 0;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (has_kv(il)) {
+            res++;
+        }
+    }
+
+    return res;
+}
+
+bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
+    assert(p0 >= 0 && p1 >= 0);
+
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+            {
+                if (p1 - p0 >= (int32_t) n_swa) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
+                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                if (p0 < pos_chunk_start) {
+                    return true;
+                }
+            } break;
+        case LLAMA_SWA_TYPE_SYMMETRIC:
+            {
+                const int32_t half_n_swa = (int32_t) n_swa / 2;
+                const int32_t pos_diff = p1 - p0;
+
+                // Mask if outside the symmetric window
+                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                    return true;
+                }
+            } break;
+    }
+
+    return false;
+}
index bd23122443271b7fc3d078ad16a8a19e618bba97..202cbbd1b288423d80bc1bfd2399fcc7160e46aa 100644 (file)
@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
 };
 
 enum llama_swa_type {
-    LLAMA_SWA_TYPE_NONE     = 0,
-    LLAMA_SWA_TYPE_STANDARD = 1,
-    LLAMA_SWA_TYPE_CHUNKED  = 2,
+    LLAMA_SWA_TYPE_NONE      = 0,
+    LLAMA_SWA_TYPE_STANDARD  = 1,
+    LLAMA_SWA_TYPE_CHUNKED   = 2,
+    LLAMA_SWA_TYPE_SYMMETRIC = 3,
 };
 
 struct llama_hparams_posnet {
@@ -41,6 +42,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
+     int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
     uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -80,8 +82,9 @@ struct llama_hparams {
     float f_norm_rms_eps;
     float f_norm_group_eps;
 
-    float f_attn_logit_softcapping  = 50.0f;
-    float f_final_logit_softcapping = 30.0f;
+    float f_attn_logit_softcapping   = 50.0f;
+    float f_router_logit_softcapping = 30.0f;
+    float f_final_logit_softcapping  = 30.0f;
 
     // for RWKV
     uint32_t rescale_every_n_layers = 0;
@@ -102,6 +105,11 @@ struct llama_hparams {
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul = 0.0f;
 
+    float    yarn_ext_factor  = -1.0f;
+    float    yarn_attn_factor =  1.0f;
+    float    yarn_beta_fast   = 32.0f;
+    float    yarn_beta_slow   =  1.0f;
+
     std::array<int, 4> rope_sections;
 
     // Sliding Window Attention (SWA)
@@ -134,10 +142,14 @@ struct llama_hparams {
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;
 
+    // grok-2
+    float    f_attn_out_scale = 0.0f;
+    uint32_t attn_temp_length = 0;
+
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
-    bool use_kq_norm   = true;
+    bool use_kq_norm   = false;
 
     // for Classifiers
     uint32_t n_cls_out = 1;
@@ -157,6 +169,7 @@ struct llama_hparams {
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
+    uint32_t    dec_n_layer        = 0;
 
     enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
@@ -221,6 +234,16 @@ struct llama_hparams {
     uint32_t n_pos_per_embd() const;
 
     bool is_swa(uint32_t il) const;
+
+    bool has_kv(uint32_t il) const;
+
+    // number of layers for which has_kv() returns true
+    uint32_t n_layer_kv() const;
+
+    // note that this function uses different SWA parameters from those in the hparams
+    // TODO: think of a better place for this function
+    // TODO: pack the SWA params in a struct?
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
index 02b1d07f8400dc3caa75f1766bf46cdc33685d65..c5163e9225a5e04ee38c426aad139bf15ddc9a90 100644 (file)
@@ -59,3 +59,5 @@ std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
 std::string llama_format_tensor_shape(const struct ggml_tensor * t);
 
 std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
+
+#define LLAMA_TENSOR_NAME_FATTN "__fattn__"
diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp
new file mode 100644 (file)
index 0000000..d734291
--- /dev/null
@@ -0,0 +1,318 @@
+#include "llama-kv-cache-iswa.h"
+
+#include "llama-impl.h"
+#include "llama-batch.h"
+#include "llama-model.h"
+
+#include <algorithm>
+#include <cassert>
+
+//
+// llama_kv_cache_iswa
+//
+
+llama_kv_cache_iswa::llama_kv_cache_iswa(
+        const llama_model & model,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                     bool   offload,
+                     bool   swa_full,
+                     bool   unified,
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max,
+                 uint32_t   n_ubatch,
+                 uint32_t   n_pad,
+    const layer_filter_cb & filter,
+    const  layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
+
+    // chain filters
+    const layer_filter_cb filter_base = [&](int32_t il) {
+        if (filter && !filter(il)) {
+            return false;
+        }
+
+        return !model.hparams.is_swa(il);
+    };
+
+    const layer_filter_cb filter_swa  = [&](int32_t il) {
+        if (filter && !filter(il)) {
+            return false;
+        }
+
+        return  model.hparams.is_swa(il);
+    };
+
+    const uint32_t size_base = kv_size;
+
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+
+    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
+    if (swa_full) {
+        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+        size_swa = size_base;
+    }
+
+    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
+
+    kv_base = std::make_unique<llama_kv_cache>(
+            model, type_k, type_v,
+            v_trans, offload, unified, size_base, n_seq_max, n_pad,
+            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
+
+    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
+
+    kv_swa = std::make_unique<llama_kv_cache>(
+            model, type_k, type_v,
+            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
+            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
+}
+
+void llama_kv_cache_iswa::clear(bool data) {
+    kv_base->clear(data);
+    kv_swa ->clear(data);
+}
+
+bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    bool res = true;
+
+    res = res & kv_base->seq_rm(seq_id, p0, p1);
+    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
+
+    return res;
+}
+
+void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
+    kv_base->seq_keep(seq_id);
+    kv_swa ->seq_keep(seq_id);
+}
+
+void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_base->seq_add(seq_id, p0, p1, shift);
+    kv_swa ->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_base->seq_div(seq_id, p0, p1, d);
+    kv_swa ->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
+    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
+    return kv_swa->seq_pos_min(seq_id);
+}
+
+llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
+    return kv_swa->seq_pos_max(seq_id);
+}
+
+llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);
+
+    // first try simple split
+    do {
+        if (!unified) {
+            // requires equal splits, so we skip the simple split
+            break;
+        }
+
+        balloc.split_reset();
+
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
+            break;
+        }
+
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            break;
+        }
+
+        assert(sinfos_base.size() == sinfos_swa.size());
+
+        return std::make_unique<llama_kv_cache_iswa_context>(
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        balloc.split_reset();
+
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_equal(n_ubatch, !unified);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
+        auto sinfos_base = kv_base->prepare(ubatches);
+        if (sinfos_base.empty()) {
+            break;
+        }
+
+        auto sinfos_swa = kv_swa->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            break;
+        }
+
+        assert(sinfos_base.size() == sinfos_swa.size());
+
+        return std::make_unique<llama_kv_cache_iswa_context>(
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    //       but to do that properly, we first have to refactor the batches to be more flexible
+
+    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+}
+
+llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_iswa_context>(this);
+}
+
+llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
+}
+
+bool llama_kv_cache_iswa::get_can_shift() const {
+    return kv_base->get_size() == kv_swa->get_size();
+}
+
+void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_write(io, seq_id, flags);
+    }
+
+    kv_swa->state_write(io, seq_id, flags);
+}
+
+void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+        kv_base->state_read(io, seq_id, flags);
+    }
+
+    kv_swa->state_read(io, seq_id, flags);
+}
+
+llama_kv_cache * llama_kv_cache_iswa::get_base() const {
+    return kv_base.get();
+}
+
+llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
+    return kv_swa.get();
+}
+
+//
+// llama_kv_cache_iswa_context
+//
+
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv) :
+    ctx_base(kv->get_base()->init_full()),
+    ctx_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
+        llama_context * lctx,
+        bool optimize) :
+    ctx_base(kv->get_base()->init_update(lctx, optimize)),
+    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
+        slot_info_vec_t sinfos_base,
+        slot_info_vec_t sinfos_swa,
+        std::vector<llama_ubatch> ubatches) :
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
+}
+
+llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
+
+bool llama_kv_cache_iswa_context::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    ctx_base->next();
+    ctx_swa ->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_iswa_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
+
+    bool res = true;
+
+    res = res & ctx_base->apply();
+    res = res & ctx_swa ->apply();
+
+    return res;
+}
+
+llama_memory_status llama_kv_cache_iswa_context::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
+}
+
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa()  const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
+}
diff --git a/examples/talk-llama/llama-kv-cache-iswa.h b/examples/talk-llama/llama-kv-cache-iswa.h
new file mode 100644 (file)
index 0000000..5ed134b
--- /dev/null
@@ -0,0 +1,135 @@
+#pragma once
+
+#include "llama-kv-cache.h"
+
+#include <vector>
+
+//
+// llama_kv_cache_iswa
+//
+
+// utilizes two instances of llama_kv_cache
+//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
+
+class llama_kv_cache_iswa : public llama_memory_i {
+public:
+    llama_kv_cache_iswa(
+            const llama_model & model,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   v_trans,
+                         bool   offload,
+                         bool   swa_full,
+                         bool   unified,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max,
+                     uint32_t   n_ubatch,
+                     uint32_t   n_pad,
+        const layer_filter_cb & filter,
+        const  layer_reuse_cb & reuse);
+
+    ~llama_kv_cache_iswa() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_context_ptr init_batch(
+            llama_batch_allocr & balloc,
+            uint32_t n_ubatch,
+            bool embd_all) override;
+
+    llama_memory_context_ptr init_full() override;
+
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
+
+    //
+    // llama_kv_cache_iswa specific API
+    //
+
+    llama_kv_cache * get_base() const;
+    llama_kv_cache * get_swa () const;
+
+private:
+    const llama_hparams & hparams;
+
+    const bool unified;
+
+    std::unique_ptr<llama_kv_cache> kv_base;
+    std::unique_ptr<llama_kv_cache> kv_swa;
+};
+
+class llama_kv_cache_iswa_context : public llama_memory_context_i {
+public:
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
+
+    // used for errors
+    llama_kv_cache_iswa_context(llama_memory_status status);
+
+    // used to create a full-cache context
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv);
+
+    // used to create an update context
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
+    // used to create a batch processing context from a batch
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
+            slot_info_vec_t sinfos_base,
+            slot_info_vec_t sinfos_swa,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_kv_cache_iswa_context();
+
+    //
+    // llama_memory_context_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_iswa_context specific API
+    //
+
+    const llama_kv_cache_context * get_base() const;
+    const llama_kv_cache_context * get_swa()  const;
+
+private:
+    //llama_kv_cache_iswa * kv;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    const llama_memory_context_ptr ctx_base;
+    const llama_memory_context_ptr ctx_swa;
+
+    const llama_memory_status status;
+};
diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp
deleted file mode 100644 (file)
index 1e363ff..0000000
+++ /dev/null
@@ -1,301 +0,0 @@
-#include "llama-kv-cache-unified-iswa.h"
-
-#include "llama-impl.h"
-#include "llama-batch.h"
-#include "llama-model.h"
-
-#include <algorithm>
-#include <cassert>
-
-//
-// llama_kv_cache_unified_iswa
-//
-
-llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   v_trans,
-                     bool   offload,
-                     bool   swa_full,
-                     bool   unified,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max,
-                 uint32_t   n_ubatch,
-                 uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
-
-    const uint32_t size_base = kv_size;
-
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
-
-    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
-    if (swa_full) {
-        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
-                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
-        size_swa = size_base;
-    }
-
-    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
-
-    kv_base = std::make_unique<llama_kv_cache_unified>(
-            model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, LLAMA_SWA_TYPE_NONE);
-
-    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
-
-    kv_swa = std::make_unique<llama_kv_cache_unified>(
-            model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, hparams.swa_type);
-}
-
-void llama_kv_cache_unified_iswa::clear(bool data) {
-    kv_base->clear(data);
-    kv_swa ->clear(data);
-}
-
-bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    bool res = true;
-
-    res = res & kv_base->seq_rm(seq_id, p0, p1);
-    res = res & kv_swa ->seq_rm(seq_id, p0, p1);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
-    kv_base->seq_keep(seq_id);
-    kv_swa ->seq_keep(seq_id);
-}
-
-void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
-    kv_base->seq_add(seq_id, p0, p1, shift);
-    kv_swa ->seq_add(seq_id, p0, p1, shift);
-}
-
-void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    kv_base->seq_div(seq_id, p0, p1, d);
-    kv_swa ->seq_div(seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
-    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
-    return kv_swa->seq_pos_min(seq_id);
-}
-
-llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
-    return kv_swa->seq_pos_max(seq_id);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    GGML_UNUSED(embd_all);
-
-    // first try simple split
-    do {
-        if (!unified) {
-            // requires equal splits, so we skip the simple split
-            break;
-        }
-
-        balloc.split_reset();
-
-        std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_simple(n_ubatch);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
-        }
-
-        if (balloc.get_n_used() < balloc.get_n_tokens()) {
-            // failed to find a suitable split
-            break;
-        }
-
-        auto sinfos_base = kv_base->prepare(ubatches);
-        if (sinfos_base.empty()) {
-            break;
-        }
-
-        auto sinfos_swa = kv_swa->prepare(ubatches);
-        if (sinfos_swa.empty()) {
-            break;
-        }
-
-        assert(sinfos_base.size() == sinfos_swa.size());
-
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
-    } while (false);
-
-    // if it fails, try equal split
-    do {
-        balloc.split_reset();
-
-        std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = balloc.split_equal(n_ubatch, !unified);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
-        }
-
-        if (balloc.get_n_used() < balloc.get_n_tokens()) {
-            // failed to find a suitable split
-            break;
-        }
-
-        auto sinfos_base = kv_base->prepare(ubatches);
-        if (sinfos_base.empty()) {
-            break;
-        }
-
-        auto sinfos_swa = kv_swa->prepare(ubatches);
-        if (sinfos_swa.empty()) {
-            break;
-        }
-
-        assert(sinfos_base.size() == sinfos_swa.size());
-
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
-    } while (false);
-
-    // TODO: if we fail again, we should attempt different splitting strategies
-    //       but to do that properly, we first have to refactor the batches to be more flexible
-
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
-}
-
-bool llama_kv_cache_unified_iswa::get_can_shift() const {
-    return kv_base->get_size() == kv_swa->get_size();
-}
-
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
-        kv_base->state_write(io, seq_id, flags);
-    }
-
-    kv_swa->state_write(io, seq_id, flags);
-}
-
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
-        kv_base->state_read(io, seq_id, flags);
-    }
-
-    kv_swa->state_read(io, seq_id, flags);
-}
-
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
-    return kv_base.get();
-}
-
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
-    return kv_swa.get();
-}
-
-//
-// llama_kv_cache_unified_iswa_context
-//
-
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
-
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
-    ctx_base(kv->get_base()->init_full()),
-    ctx_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
-}
-
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
-        llama_context * lctx,
-        bool optimize) :
-    ctx_base(kv->get_base()->init_update(lctx, optimize)),
-    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
-}
-
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
-        slot_info_vec_t sinfos_base,
-        slot_info_vec_t sinfos_swa,
-        std::vector<llama_ubatch> ubatches) :
-    ubatches(std::move(ubatches)),
-    // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
-    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
-}
-
-llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
-
-bool llama_kv_cache_unified_iswa_context::next() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    ctx_base->next();
-    ctx_swa ->next();
-
-    if (++i_next >= ubatches.size()) {
-        return false;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_unified_iswa_context::apply() {
-    assert(!llama_memory_status_is_fail(status));
-
-    bool res = true;
-
-    res = res & ctx_base->apply();
-    res = res & ctx_swa ->apply();
-
-    return res;
-}
-
-llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
-    return status;
-}
-
-const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return ubatches[i_next];
-}
-
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
-}
-
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa()  const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
-}
diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.h b/examples/talk-llama/llama-kv-cache-unified-iswa.h
deleted file mode 100644 (file)
index 7bc4df7..0000000
+++ /dev/null
@@ -1,133 +0,0 @@
-#pragma once
-
-#include "llama-kv-cache-unified.h"
-
-#include <vector>
-
-//
-// llama_kv_cache_unified_iswa
-//
-
-// utilizes two instances of llama_kv_cache_unified
-//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
-
-class llama_kv_cache_unified_iswa : public llama_memory_i {
-public:
-    llama_kv_cache_unified_iswa(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   v_trans,
-                         bool   offload,
-                         bool   swa_full,
-                         bool   unified,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max,
-                     uint32_t   n_ubatch,
-                     uint32_t   n_pad);
-
-    ~llama_kv_cache_unified_iswa() = default;
-
-    //
-    // llama_memory_i
-    //
-
-    llama_memory_context_ptr init_batch(
-            llama_batch_allocr & balloc,
-            uint32_t n_ubatch,
-            bool embd_all) override;
-
-    llama_memory_context_ptr init_full() override;
-
-    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
-
-    bool get_can_shift() const override;
-
-    void clear(bool data) override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    // state write/load
-
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
-
-    //
-    // llama_kv_cache_unified_iswa specific API
-    //
-
-    llama_kv_cache_unified * get_base() const;
-    llama_kv_cache_unified * get_swa () const;
-
-private:
-    const llama_hparams & hparams;
-
-    const bool unified;
-
-    std::unique_ptr<llama_kv_cache_unified> kv_base;
-    std::unique_ptr<llama_kv_cache_unified> kv_swa;
-};
-
-class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
-public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
-
-    // used for errors
-    llama_kv_cache_unified_iswa_context(llama_memory_status status);
-
-    // used to create a full-cache context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv);
-
-    // used to create an update context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
-            llama_context * lctx,
-            bool optimize);
-
-    // used to create a batch processing context from a batch
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
-            slot_info_vec_t sinfos_base,
-            slot_info_vec_t sinfos_swa,
-            std::vector<llama_ubatch> ubatches);
-
-    virtual ~llama_kv_cache_unified_iswa_context();
-
-    //
-    // llama_memory_context_i
-    //
-
-    bool next()  override;
-    bool apply() override;
-
-    llama_memory_status  get_status() const override;
-    const llama_ubatch & get_ubatch() const override;
-
-    //
-    // llama_kv_cache_unified_iswa_context specific API
-    //
-
-    const llama_kv_cache_unified_context * get_base() const;
-    const llama_kv_cache_unified_context * get_swa()  const;
-
-private:
-    //llama_kv_cache_unified_iswa * kv;
-
-    // the index of the next ubatch to process
-    size_t i_next = 0;
-
-    std::vector<llama_ubatch> ubatches;
-
-    const llama_memory_context_ptr ctx_base;
-    const llama_memory_context_ptr ctx_swa;
-
-    const llama_memory_status status;
-};
diff --git a/examples/talk-llama/llama-kv-cache-unified.cpp b/examples/talk-llama/llama-kv-cache-unified.cpp
deleted file mode 100644 (file)
index 478ebff..0000000
+++ /dev/null
@@ -1,2410 +0,0 @@
-#include "llama-kv-cache-unified.h"
-
-#include "llama-impl.h"
-#include "llama-io.h"
-#include "llama-model.h"
-#include "llama-context.h"
-
-#include <algorithm>
-#include <cassert>
-#include <cmath>
-#include <limits>
-#include <map>
-#include <stdexcept>
-
-//
-// llama_kv_cache_unified
-//
-
-llama_kv_cache_unified::llama_kv_cache_unified(
-        const llama_model &  model,
-          layer_filter_cb && filter,
-                ggml_type    type_k,
-                ggml_type    type_v,
-                     bool    v_trans,
-                     bool    offload,
-                     bool    unified,
-                 uint32_t    kv_size,
-                 uint32_t    n_seq_max,
-                 uint32_t    n_pad,
-                 uint32_t    n_swa,
-           llama_swa_type    swa_type) :
-    model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
-
-    GGML_ASSERT(kv_size % n_pad == 0);
-
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    auto n_layer_cache = hparams.n_layer;
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        n_layer_cache = 20;
-    }
-    if (model.arch == LLM_ARCH_GLM4_MOE) {
-        // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
-    }
-
-    // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                return nullptr;
-            }
-
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
-
-            return ctx;
-        }
-
-        return it->second;
-    };
-
-    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
-
-    v_heads.resize(n_stream);
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        v_heads[s] = 0;
-    }
-
-    v_cells.resize(n_stream);
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        v_cells[s].resize(kv_size);
-    }
-
-    // by default, all sequence ids are mapped to the 0th stream
-    seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
-
-    if (n_stream > 1) {
-        seq_to_stream.resize(n_stream, 0);
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            seq_to_stream[s] = s;
-        }
-    }
-
-    // [TAG_V_CACHE_VARIABLE]
-    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
-        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
-                __func__, hparams.n_embd_v_gqa_max());
-    }
-
-    for (uint32_t il = 0; il < n_layer_cache; il++) {
-        if (filter && !filter(il)) {
-            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
-            continue;
-        }
-
-        // [TAG_V_CACHE_VARIABLE]
-        const uint32_t n_embd_k_gqa =            hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
-
-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
-
-        if (offload) {
-            auto * dev = model.dev_layer(il);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        }
-
-        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
-
-        ggml_context * ctx = ctx_for_buft(buft);
-        if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for kv cache");
-        }
-
-        ggml_tensor * k;
-        ggml_tensor * v;
-
-        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
-        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
-
-        ggml_format_name(k, "cache_k_l%d", il);
-        ggml_format_name(v, "cache_v_l%d", il);
-
-        std::vector<ggml_tensor *> k_stream;
-        std::vector<ggml_tensor *> v_stream;
-
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
-            v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
-        }
-
-        map_layer_ids[il] = layers.size();
-
-        layers.push_back({ il, k, v, k_stream, v_stream, });
-    }
-
-    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
-    if (model.arch == LLM_ARCH_GEMMA3N) {
-        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
-
-        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
-            if (filter && !filter(il)) {
-                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
-                continue;
-            }
-
-            const bool     is_swa   = hparams.is_swa(il);
-            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
-
-            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
-            map_layer_ids[il] = map_layer_ids[il_reuse];
-
-            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
-        }
-    }
-
-    // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for kv cache");
-        }
-
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-
-        ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
-    }
-
-    {
-        const size_t memory_size_k = size_k_bytes();
-        const size_t memory_size_v = size_v_bytes();
-
-        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-    }
-
-    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
-    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
-
-    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows;
-
-    if (!supports_set_rows) {
-        // ref: https://github.com/ggml-org/llama.cpp/pull/14363
-        GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support");
-    }
-
-    if (!supports_set_rows) {
-        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
-    }
-}
-
-void llama_kv_cache_unified::clear(bool data) {
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        v_cells[s].reset();
-        v_heads[s] = 0;
-    }
-
-    if (data) {
-        for (auto & buf : bufs) {
-            ggml_backend_buffer_clear(buf.get(), 0);
-        }
-    }
-}
-
-bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    if (seq_id >= 0) {
-        auto & cells = v_cells[seq_to_stream[seq_id]];
-        auto & head  = v_heads[seq_to_stream[seq_id]];
-
-        uint32_t new_head = cells.size();
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            if (!cells.pos_in(i, p0, p1)) {
-                continue;
-            }
-
-            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
-                if (new_head == cells.size()) {
-                    new_head = i;
-                }
-            }
-        }
-
-        // If we freed up a slot, set head to it so searching can start there.
-        if (new_head != cells.size() && new_head < head) {
-            head = new_head;
-        }
-    } else {
-        // match any sequence
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            auto & cells = v_cells[s];
-            auto & head  = v_heads[s];
-
-            uint32_t new_head = cells.size();
-
-            for (uint32_t i = 0; i < cells.size(); ++i) {
-                if (!cells.pos_in(i, p0, p1)) {
-                    continue;
-                }
-
-                cells.rm(i);
-
-                if (new_head == cells.size()) {
-                    new_head = i;
-                }
-            }
-
-            // If we freed up a slot, set head to it so searching can start there.
-            if (new_head != cells.size() && new_head < head) {
-                head = new_head;
-            }
-        }
-    }
-
-    return true;
-}
-
-void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
-    GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
-
-    const auto s0 = seq_to_stream[seq_id_src];
-    const auto s1 = seq_to_stream[seq_id_dst];
-
-    if (s0 == s1) {
-        // since both sequences are in the same stream, no data copy is necessary
-        // we just have to update the cells meta data
-
-        auto & cells = v_cells[s0];
-
-        if (seq_id_src == seq_id_dst) {
-            return;
-        }
-
-        if (p0 < 0) {
-            p0 = 0;
-        }
-
-        if (p1 < 0) {
-            p1 = std::numeric_limits<llama_pos>::max();
-        }
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            if (!cells.pos_in(i, p0, p1)) {
-                continue;
-            }
-
-            if (cells.seq_has(i, seq_id_src)) {
-                cells.seq_add(i, seq_id_dst);
-            }
-        }
-
-        return;
-    }
-
-    // cross-stream sequence copies require to copy the actual buffer data
-
-    bool is_full = true;
-
-    if (p0 > 0 && p0 + 1 < (int) get_size()) {
-        is_full = false;
-    }
-
-    if (p1 > 0 && p1 + 1 < (int) get_size()) {
-        is_full = false;
-    }
-
-    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
-
-    // enqueue the copy operation - the buffer copy will be performed during the next update
-    sc_info.ssrc.push_back(s0);
-    sc_info.sdst.push_back(s1);
-
-    v_cells[s1].reset();
-    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
-        if (v_cells[s0].seq_has(i, seq_id_src)) {
-            llama_pos pos   = v_cells[s0].pos_get(i);
-            llama_pos shift = v_cells[s0].get_shift(i);
-
-            if (shift != 0) {
-                pos -= shift;
-                assert(pos >= 0);
-            }
-
-            v_cells[s1].pos_set(i, pos);
-            v_cells[s1].seq_add(i, seq_id_dst);
-
-            if (shift != 0) {
-                v_cells[s1].pos_add(i, shift);
-            }
-        }
-    }
-
-    v_heads[s1] = v_heads[s0];
-
-    //for (uint32_t s = 0; s < n_stream; ++s) {
-    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
-    //}
-}
-
-void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    auto & cells = v_cells[seq_to_stream[seq_id]];
-    auto & head  = v_heads[seq_to_stream[seq_id]];
-
-    uint32_t new_head = cells.size();
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (cells.seq_keep(i, seq_id)) {
-            if (new_head == cells.size()) {
-                new_head = i;
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cells.size() && new_head < head) {
-        head = new_head;
-    }
-}
-
-void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    auto & cells = v_cells[seq_to_stream[seq_id]];
-    auto & head  = v_heads[seq_to_stream[seq_id]];
-
-    if (shift == 0) {
-        return;
-    }
-
-    uint32_t new_head = cells.size();
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    // If there is no range then return early to avoid looping over all cells.
-    if (p0 == p1) {
-        return;
-    }
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
-        }
-
-        if (cells.seq_has(i, seq_id)) {
-            if (cells.pos_add(i, shift)) {
-                if (new_head == cells.size()) {
-                    new_head = i;
-                }
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    // Otherwise we just start the next search from the beginning.
-    head = new_head != cells.size() ? new_head : 0;
-}
-
-void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    auto & cells = v_cells[seq_to_stream[seq_id]];
-
-    if (d == 1) {
-        return;
-    }
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) {
-        return;
-    }
-
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
-        }
-
-        if (cells.seq_has(i, seq_id)) {
-            cells.pos_div(i, d);
-        }
-    }
-}
-
-llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-    return cells.seq_pos_min(seq_id);
-}
-
-llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-    return cells.seq_pos_max(seq_id);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified::init_batch(
-            llama_batch_allocr & balloc,
-            uint32_t n_ubatch,
-            bool embd_all) {
-    GGML_UNUSED(embd_all);
-
-    do {
-        balloc.split_reset();
-
-        std::vector<llama_ubatch> ubatches;
-        while (true) {
-            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
-
-            if (ubatch.n_tokens == 0) {
-                break;
-            }
-
-            ubatches.push_back(std::move(ubatch)); // NOLINT
-        }
-
-        if (balloc.get_n_used() < balloc.get_n_tokens()) {
-            // failed to find a suitable split
-            break;
-        }
-
-        auto sinfos = prepare(ubatches);
-        if (sinfos.empty()) {
-            break;
-        }
-
-        return std::make_unique<llama_kv_cache_unified_context>(
-                this, std::move(sinfos), std::move(ubatches));
-    } while (false);
-
-    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_context>(this);
-}
-
-llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
-    bool do_shift = get_has_shift();
-
-    defrag_info dinfo;
-
-    // see if we need to defrag
-    if (n_stream == 1) {
-        // note : for now do not consider defrag for n_stream > 1
-        const auto & cells = v_cells[seq_to_stream[0]];
-
-        bool do_defrag = optimize;
-
-        const auto thold = lctx->get_cparams().defrag_thold;
-
-        if (!do_defrag && thold > 0.0f) {
-            const auto n_kv = cells.used_max_p1();
-
-            // - do not defrag small contexts (i.e. < 2048 tokens)
-            // - count the padding towards the number of used tokens
-            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
-
-            if (fragmentation > thold) {
-                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-                do_defrag = true;
-            }
-        }
-
-        if (do_defrag) {
-            dinfo = defrag_prepare(lctx->graph_max_nodes());
-        }
-    }
-
-    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
-}
-
-llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::slot_info_vec_t res;
-
-    struct state_t {
-        slot_info sinfo; // slot info for the ubatch
-
-        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
-
-        std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
-    };
-
-    // remember the old state of the cells so we can restore it in the end
-    std::vector<state_t> states;
-
-    bool success = true;
-
-    for (const auto & ubatch : ubatches) {
-        // non-continuous slots require support for ggml_set_rows()
-        const bool cont = supports_set_rows ? false : true;
-
-        // only find a suitable slot for the ubatch. don't modify the cells yet
-        const auto sinfo_new = find_slot(ubatch, cont);
-        if (sinfo_new.empty()) {
-            success = false;
-            break;
-        }
-
-        // remeber the position that we found
-        res.push_back(sinfo_new);
-
-        // store the old state of the cells in the recovery stack
-        {
-            state_t state = { sinfo_new, v_heads, {} };
-
-            for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
-                auto & cells = v_cells[sinfo_new.strm[s]];
-
-                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
-            }
-
-            states.push_back(std::move(state));
-        }
-
-        // now emplace the ubatch
-        apply_ubatch(sinfo_new, ubatch);
-    }
-
-    GGML_ASSERT(!states.empty() || !success);
-
-    // iterate backwards and restore the cells to their original state
-    for (auto it = states.rbegin(); it != states.rend(); ++it) {
-        const auto & sinfo = it->sinfo;
-
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            auto & cells = v_cells[sinfo.strm[s]];
-            auto & head  = v_heads[sinfo.strm[s]];
-
-            cells.set(sinfo.idxs[s], it->v_cells[s]);
-            head = it->v_heads_old[s];
-        }
-    }
-
-    if (!success) {
-        return {};
-    }
-
-    return res;
-}
-
-bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
-    bool updated = false;
-
-    auto * sched = lctx->get_sched();
-
-    if (!sc_info.empty()) {
-        assert(n_stream > 1 && "stream copy should never happen with a single stream");
-
-        llama_synchronize(lctx);
-
-        const size_t n_copy = sc_info.ssrc.size();
-
-        for (size_t i = 0; i < n_copy; ++i) {
-            const auto ssrc = sc_info.ssrc[i];
-            const auto sdst = sc_info.sdst[i];
-
-            assert(ssrc < n_stream);
-            assert(sdst < n_stream);
-
-            LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
-
-            assert(ssrc != sdst);
-
-            for (uint32_t il = 0; il < layers.size(); ++il) {
-                const auto & layer = layers[il];
-
-                ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
-                ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
-            }
-        }
-    }
-
-    if (do_shift) {
-        if (!get_can_shift()) {
-            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
-        }
-
-        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
-        // apply K-shift if needed
-        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched);
-
-            auto * res = lctx->get_gf_res_reserve();
-
-            res->reset();
-
-            auto * gf = build_graph_shift(res, lctx);
-            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
-                return updated;
-            }
-
-            res->set_inputs(nullptr);
-
-            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
-                return updated;
-            }
-
-            updated = true;
-        }
-
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            auto & cells = v_cells[s];
-
-            cells.reset_shift();
-        }
-    }
-
-    if (!dinfo.empty()) {
-        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
-        // note: for now do not consider defrag for n_stream > 1
-        auto & cells = v_cells[seq_to_stream[0]];
-        auto & head  = v_heads[seq_to_stream[0]];
-
-        // apply moves:
-        {
-            const auto n_kv = dinfo.ids.size();
-
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                assert(dinfo.ids[i] <= n_kv);
-
-                if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
-                    continue;
-                }
-
-                cells.mv(i, dinfo.ids[i]);
-            }
-
-            // reset the head so we can find the first free slot during the next ubatch
-            head = 0;
-        }
-
-        ggml_backend_sched_reset(sched);
-
-        auto * res = lctx->get_gf_res_reserve();
-
-        res->reset();
-
-        auto * gf = build_graph_defrag(res, lctx, dinfo);
-        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-            return updated;
-        }
-
-        res->set_inputs(nullptr);
-
-        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-            return updated;
-        }
-
-        updated = true;
-    }
-
-    return updated;
-}
-
-llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-
-    if (debug > 0) {
-        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
-            const auto seq_id = ubatch.seq_id_unq[s];
-            const auto stream_id = seq_to_stream[seq_id];
-            const auto & cells = v_cells[stream_id];
-            const uint32_t head_cur = v_heads[stream_id];
-
-            LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
-                    __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
-
-            if ((debug == 2 && n_swa > 0) || debug > 2) {
-                std::string ss;
-                for (uint32_t i = 0; i < cells.size(); ++i) {
-                    if (cells.is_empty(i)) {
-                        ss += '.';
-                    } else {
-                        assert(cells.seq_count(i) >= 1);
-
-                        if (cells.seq_count(i) == 1) {
-                            ss += std::to_string(cells.seq_get(i));
-                        } else {
-                            ss += 'M';
-                        }
-                    }
-                    if (i%256 == 255) {
-                        ss += " *";
-                        ss += '\n';
-                    }
-                }
-                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
-            }
-
-            if ((debug == 2 && n_swa > 0) || debug > 2) {
-                std::string ss;
-                for (uint32_t i = 0; i < cells.size(); ++i) {
-                    std::string cur;
-                    if (cells.is_empty(i)) {
-                        cur = '.';
-                    } else {
-                        cur = std::to_string(cells.pos_get(i));
-                    }
-                    const int n = cur.size();
-                    for (int j = 0; j < 5 - n; ++j) {
-                        cur += ' ';
-                    }
-                    ss += cur;
-                    if (i%256 == 255) {
-                        ss += " *";
-                    }
-                    if (i%64 == 63) {
-                        ss += '\n';
-                    }
-                }
-                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
-            }
-
-            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
-                if (cells.seq_pos_min(s) < 0) {
-                    continue;
-                }
-
-                LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
-            }
-        }
-    }
-
-    uint32_t n_tokens = ubatch.n_tokens;
-    uint32_t n_seqs   = 1;
-
-    if (n_stream > 1) {
-        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
-
-        n_seqs   = ubatch.n_seqs_unq;
-        n_tokens = n_tokens / n_seqs;
-    }
-
-    slot_info res = {
-        /*.s0   =*/ LLAMA_MAX_SEQ,
-        /*.s1   =*/ 0,
-        /*.strm =*/ { },
-        /*.idxs =*/ { },
-    };
-
-    res.resize(n_seqs);
-
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        const auto seq_id = ubatch.seq_id_unq[s];
-
-        if (n_stream > 1) {
-            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
-            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
-        }
-
-        res.s0 = std::min<llama_seq_id>(res.s0, seq_to_stream[seq_id]);
-        res.s1 = std::max<llama_seq_id>(res.s1, seq_to_stream[seq_id]);
-
-        res.strm[s] = seq_to_stream[seq_id];
-        res.idxs[s].reserve(n_tokens);
-
-        const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-        uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
-
-        // if we have enough unused cells before the current head ->
-        //   better to start searching from the beginning of the cache, hoping to fill it
-        if (head_cur > cells.get_used() + 2*n_tokens) {
-            head_cur = 0;
-        }
-
-        if (n_tokens > cells.size()) {
-            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-            return { };
-        }
-
-        uint32_t n_tested = 0;
-
-        // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
-        // for non-continuous slots, we test the tokens one by one
-        const uint32_t n_test = cont ? n_tokens : 1;
-
-        while (true) {
-            if (head_cur + n_test > cells.size()) {
-                n_tested += cells.size() - head_cur;
-                head_cur = 0;
-                continue;
-            }
-
-            for (uint32_t i = 0; i < n_test; i++) {
-                const auto idx = head_cur;
-
-                head_cur++;
-                n_tested++;
-
-                //const llama_pos    pos    = ubatch.pos[i];
-                //const llama_seq_id seq_id = ubatch.seq_id[i][0];
-
-                // can we use this cell? either:
-                //  - the cell is empty
-                //  - the cell is occupied only by one sequence:
-                //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
-                //    - mask SWA, using current max pos for that sequence in the cache
-                //                always insert in the cell with minimum pos
-                bool can_use = cells.is_empty(idx);
-
-                if (!can_use && cells.seq_count(idx) == 1) {
-                    const llama_pos pos_cell = cells.pos_get(idx);
-
-                    // (disabled) causal mask
-                    // note: it's better to purge any "future" tokens beforehand
-                    //if (cells.seq_has(idx, seq_id)) {
-                    //    can_use = pos_cell >= pos;
-                    //}
-
-                    if (!can_use) {
-                        const llama_seq_id seq_id_cell = cells.seq_get(idx);
-
-                        // SWA mask
-                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                            can_use = true;
-                        }
-                    }
-                }
-
-                if (can_use) {
-                    res.idxs[s].push_back(idx);
-                } else {
-                    if (cont) {
-                        break;
-                    }
-                }
-            }
-
-            if (res.idxs[s].size() == n_tokens) {
-                break;
-            }
-
-            if (cont) {
-                res.idxs[s].clear();
-            }
-
-            if (n_tested >= cells.size()) {
-                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-                return { };
-            }
-        }
-
-        // we didn't find a suitable slot - return empty result
-        if (res.idxs[s].size() < n_tokens) {
-            return { };
-        }
-    }
-
-    assert(res.s1 >= res.s0);
-
-    return res;
-}
-
-void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
-    // keep track of the max sequence position that we would overwrite with this ubatch
-    // for non-SWA cache, this would be always empty
-    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-        seq_pos_max_rm[s] = -1;
-    }
-
-    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
-
-    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
-            const uint32_t i = s*sinfo.size() + ii;
-
-            auto & cells = v_cells[sinfo.strm[s]];
-
-            const auto idx = sinfo.idxs[s][ii];
-
-            if (!cells.is_empty(idx)) {
-                assert(cells.seq_count(idx) == 1);
-
-                const llama_seq_id seq_id = cells.seq_get(idx);
-                const llama_pos    pos    = cells.pos_get(idx);
-
-                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
-
-                cells.rm(idx);
-            }
-
-            cells.pos_set(idx, ubatch.pos[i]);
-
-            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-                cells.seq_add(idx, ubatch.seq_id[i][s]);
-            }
-        }
-    }
-
-    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
-    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
-    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-        if (seq_pos_max_rm[s] == -1) {
-            continue;
-        }
-
-        GGML_ASSERT(s < seq_to_stream.size());
-
-        auto & cells = v_cells[seq_to_stream[s]];
-
-        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
-            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
-                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
-
-            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
-        }
-    }
-
-    // move the head at the end of the slot
-    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-        auto & head = v_heads[sinfo.strm[s]];
-
-        head = sinfo.idxs[s].back() + 1;
-    }
-}
-
-bool llama_kv_cache_unified::get_can_shift() const {
-    return true;
-}
-
-uint32_t llama_kv_cache_unified::get_size() const {
-    const auto & cells = v_cells[seq_to_stream[0]];
-
-    return cells.size();
-}
-
-uint32_t llama_kv_cache_unified::get_n_stream() const {
-    return n_stream;
-}
-
-bool llama_kv_cache_unified::get_has_shift() const {
-    bool result = false;
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        result |= v_cells[s].get_has_shift();
-    }
-
-    return result;
-}
-
-uint32_t llama_kv_cache_unified::get_n_kv() const {
-    uint32_t result = 0;
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        const auto & cells = v_cells[s];
-
-        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
-    }
-
-    return result;
-}
-
-bool llama_kv_cache_unified::get_supports_set_rows() const {
-    return supports_set_rows;
-}
-
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
-    const int32_t ikv = map_layer_ids.at(il);
-
-    auto * k = layers[ikv].k;
-
-    const uint64_t kv_size      = get_size();
-    const uint64_t n_embd_k_gqa = k->ne[0];
-
-    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
-
-    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
-
-    return ggml_view_4d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
-            ggml_row_size(k->type, hparams.n_embd_head_k),
-            ggml_row_size(k->type, n_embd_k_gqa),
-            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
-            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
-}
-
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
-    const int32_t ikv = map_layer_ids.at(il);
-
-    auto * v = layers[ikv].v;
-
-    const uint64_t kv_size      = get_size();
-    const uint64_t n_embd_v_gqa = v->ne[0];
-
-    // [TAG_V_CACHE_VARIABLE]
-    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
-
-    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
-
-    if (!v_trans) {
-        // note: v->nb[1] <= v->nb[2]
-        return ggml_view_4d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
-                ggml_row_size(v->type, hparams.n_embd_head_v),            // v->nb[1]
-                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
-    }
-
-    // note: v->nb[1] > v->nb[2]
-    return ggml_view_4d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),    // v->nb[1]
-            ggml_row_size(v->type, kv_size),                          // v->nb[2]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
-}
-
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
-    const int32_t ikv = map_layer_ids.at(il);
-
-    auto * k = layers[ikv].k;
-
-    const int64_t n_embd_k_gqa = k->ne[0];
-    const int64_t n_tokens = k_cur->ne[2];
-
-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
-
-    if (k_idxs && supports_set_rows) {
-        if (k->ne[2] > 1) {
-            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
-        }
-
-        return ggml_set_rows(ctx, k, k_cur, k_idxs);
-    }
-
-    // TODO: fallback to old ggml_cpy() method for backwards compatibility
-    //       will be removed when ggml_set_rows() is adopted by all backends
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
-
-    ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*n_embd_k_gqa,
-            ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
-
-    return ggml_cpy(ctx, k_cur, k_view);
-}
-
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
-    const int32_t ikv = map_layer_ids.at(il);
-
-    auto * v = layers[ikv].v;
-
-    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
-    const int64_t n_tokens     = v_cur->ne[2];
-
-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
-
-    if (v_idxs && supports_set_rows) {
-        if (!v_trans) {
-            if (v->ne[2] > 1) {
-                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
-            }
-
-            return ggml_set_rows(ctx, v, v_cur, v_idxs);
-        }
-
-        // [TAG_V_CACHE_VARIABLE]
-        if (n_embd_v_gqa < v->ne[0]) {
-            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
-        }
-
-        // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
-
-        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
-
-        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
-    }
-
-    // TODO: fallback to old ggml_cpy() method for backwards compatibility
-    //       will be removed when ggml_set_rows() is adopted by all backends
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS");
-
-    ggml_tensor * v_view = nullptr;
-
-    if (!v_trans) {
-        v_view = ggml_view_1d(ctx, v,
-                n_tokens*n_embd_v_gqa,
-                ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
-    } else {
-        v_cur = ggml_transpose(ctx, v_cur);
-
-        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
-                (v->ne[1]    )*ggml_element_size(v),
-                (sinfo.head())*ggml_element_size(v));
-    }
-
-    return ggml_cpy(ctx, v_cur, v_view);
-}
-
-ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
-    const uint32_t n_tokens = ubatch.n_tokens;
-
-    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
-
-    ggml_set_input(k_idxs);
-
-    return k_idxs;
-}
-
-ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
-    const uint32_t n_tokens = ubatch.n_tokens;
-
-    ggml_tensor * v_idxs;
-
-    if (!v_trans) {
-        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
-    } else {
-        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
-    }
-
-    ggml_set_input(v_idxs);
-
-    return v_idxs;
-}
-
-void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
-    if (!supports_set_rows) {
-        return;
-    }
-
-    const uint32_t n_tokens = ubatch->n_tokens;
-    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    int64_t * data = (int64_t *) dst->data;
-
-    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-        const int64_t offs = sinfo.strm[s]*get_size();
-
-        for (uint32_t i = 0; i < sinfo.size(); ++i) {
-            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
-    if (!supports_set_rows) {
-        return;
-    }
-
-    const uint32_t n_tokens = ubatch->n_tokens;
-    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    int64_t * data = (int64_t *) dst->data;
-
-    if (!v_trans) {
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            const int64_t offs = sinfo.strm[s]*get_size();
-
-            for (uint32_t i = 0; i < sinfo.size(); ++i) {
-                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
-            }
-        }
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        const int64_t kv_size = get_size();
-
-        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
-
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
-
-            for (uint32_t i = 0; i < sinfo.size(); ++i) {
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
-                }
-            }
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-
-    int32_t * data = (int32_t *) dst->data;
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        const auto & cells = v_cells[s];
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const uint32_t n_tokens = ubatch->n_tokens;
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    float * data = (float *) dst->data;
-
-    const int64_t n_kv     = dst->ne[0];
-    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
-
-    GGML_ASSERT(n_tokens%n_stream == 0);
-
-    // n_tps == n_tokens_per_stream
-    const int64_t n_tps     = n_tokens/n_stream;
-    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
-
-    std::fill(data, data + ggml_nelements(dst), -INFINITY);
-
-    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-    //   Causal mask:
-    //      xxx-------
-    //      xxxx------
-    //      xxxxx-----
-    //   Non-causal mask:
-    //      xxxxx-----
-    //      xxxxx-----
-    //      xxxxx-----
-    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    // TODO: optimize this section
-    for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            for (uint32_t ii = 0; ii < n_tps; ++ii) {
-                const uint32_t i = s*n_tps + ii;
-
-                const llama_seq_id seq_id = ubatch->seq_id[i][0];
-
-                const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-                const llama_pos p1 = ubatch->pos[i];
-
-                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
-
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    if (cells.is_empty(j)) {
-                        continue;
-                    }
-
-                    // mask the token if not the same sequence
-                    if (!cells.seq_has(j, seq_id)) {
-                        continue;
-                    }
-
-                    const llama_pos p0 = cells.pos_get(j);
-
-                    // mask future tokens
-                    if (causal_attn && p0 > p1) {
-                        continue;
-                    }
-
-                    // apply SWA if any
-                    if (is_masked_swa(p0, p1)) {
-                        continue;
-                    }
-
-                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
-                }
-            }
-        }
-    }
-}
-
-void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    const int64_t n_tokens = ubatch->n_tokens;
-
-    GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
-    const auto & cells = v_cells[0];
-
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
-
-    int32_t * data = (int32_t *) dst->data;
-
-    const int32_t n_kv = dst->ne[0];
-
-    for (int h = 0; h < 1; ++h) {
-        for (int i = 0; i < n_tokens; ++i) {
-            for (int j = 0; j < n_kv; ++j) {
-                // the position when the cells is empty is irrelevant - it will be masked out later in the attention
-                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
-
-                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
-            }
-        }
-    }
-}
-
-size_t llama_kv_cache_unified::total_size() const {
-    size_t size = 0;
-
-    for (const auto & buf : bufs) {
-        size += ggml_backend_buffer_get_size(buf.get());
-    }
-
-    return size;
-}
-
-size_t llama_kv_cache_unified::size_k_bytes() const {
-    size_t size_k_bytes = 0;
-
-    for (const auto & layer : layers) {
-        size_k_bytes += ggml_nbytes(layer.k);
-    }
-
-    return size_k_bytes;
-}
-
-size_t llama_kv_cache_unified::size_v_bytes() const {
-    size_t size_v_bytes = 0;
-
-    for (const auto & layer : layers) {
-        size_v_bytes += ggml_nbytes(layer.v);
-    }
-
-    return size_v_bytes;
-}
-
-ggml_tensor * llama_kv_cache_unified::build_rope_shift(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_tensor * cur,
-                ggml_tensor * shift,
-                ggml_tensor * factors,
-                      float   freq_base,
-                      float   freq_scale) const {
-    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
-    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
-    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
-
-    const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
-                                // @ngxson : this is a workaround
-                                // for M-RoPE, we want to rotate the whole vector when doing KV shift
-                                // a normal RoPE should work, we just need to use the correct ordering
-                                // ref: https://github.com/ggml-org/llama.cpp/pull/13870
-                                ? LLAMA_ROPE_TYPE_NEOX
-                                : hparams.rope_type;
-
-    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
-                                    ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
-                                    : cparams.yarn_attn_factor;
-
-    ggml_tensor * tmp;
-
-    if (ggml_is_quantized(cur->type)) {
-        // dequantize to f32 -> RoPE -> quantize back
-        tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
-
-        tmp = ggml_rope_ext(ctx, tmp,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
-        tmp = ggml_cpy(ctx, tmp, cur);
-    } else {
-        // we rotate only the first n_rot dimensions
-        tmp = ggml_rope_ext_inplace(ctx, cur,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-    }
-
-    return tmp;
-}
-
-class llm_graph_input_k_shift : public llm_graph_input_i {
-public:
-    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_k_shift() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * k_shift; // I32 [kv_size*n_stream]
-
-    const llama_kv_cache_unified * kv_self;
-};
-
-void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    if (k_shift) {
-        kv_self->set_input_k_shift(k_shift);
-    }
-}
-
-ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
-    auto * ctx = res->get_ctx();
-    auto * gf  = res->get_gf();
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-  //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    auto inp = std::make_unique<llm_graph_input_k_shift>(this);
-
-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
-    ggml_set_input(inp->k_shift);
-
-    const auto & cparams = lctx->get_cparams();
-
-    for (const auto & layer : layers) {
-        const uint32_t il = layer.il;
-
-        const int64_t n_head_kv    = hparams.n_head_kv(il);
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
-        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
-
-        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
-
-        ggml_tensor * k =
-            ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, get_size()*n_stream,
-                ggml_row_size(layer.k->type, n_embd_head_k),
-                ggml_row_size(layer.k->type, n_embd_k_gqa),
-                0);
-
-        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
-
-        ggml_build_forward_expand(gf, cur);
-    }
-
-    res->add_input(std::move(inp));
-
-    return gf;
-}
-
-ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
-         llm_graph_result * res,
-            llama_context * lctx,
-        const defrag_info & dinfo) const {
-    auto * ctx = res->get_ctx();
-    auto * gf  = res->get_gf();
-
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
-
-    const auto & cells = v_cells[0];
-
-    const auto & ids = dinfo.ids;
-
-    const auto & cparams = lctx->get_cparams();
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-    //
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(v_l[il]->type);
-        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os =  i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os =  i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k,
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    ggml_row_size(layer.k->type, n_embd_k_gqa*i));
-
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k,
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    ggml_row_size(layer.k->type, n_embd_k_gqa*id));
-
-            ggml_tensor * view_v_src;
-            ggml_tensor * view_v_dst;
-
-            if (cparams.flash_attn) {
-                // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx, layer.v,
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(layer.v->type, n_embd_v_gqa),
-                        ggml_row_size(layer.v->type, n_embd_v_gqa*i));
-
-                view_v_dst = ggml_view_2d(ctx, layer.v,
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(layer.v->type, n_embd_v_gqa),
-                        ggml_row_size(layer.v->type, n_embd_v_gqa*id));
-            } else {
-                view_v_src = ggml_view_2d(ctx, layer.v,
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, cells.size()),
-                        ggml_row_size(layer.v->type, i));
-
-                view_v_dst = ggml_view_2d(ctx, layer.v,
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(layer.v->type, cells.size()),
-                        ggml_row_size(layer.v->type, id));
-            }
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
-        }
-
-        i += nm - 1;
-    }
-
-    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-#endif
-
-    return gf;
-}
-
-llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
-    GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
-
-    const auto & cells = v_cells[0];
-
-    const uint32_t n_layer = layers.size();
-
-    const uint32_t n_kv   = cells.used_max_p1();
-    const uint32_t n_used = cells.get_used();
-
-    assert(n_used <= n_kv);
-
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
-
-    // determine which KV cells to move where
-    defrag_info res;
-    auto & ids = res.ids;
-
-    ids.resize(n_kv, n_kv);
-
-    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        if (!cells.is_empty(i0)) {
-            ids[i0] = i0;
-
-            continue;
-        }
-
-        // found a hole - fill it with data from the end of the cache
-
-        uint32_t nh = 1;
-
-        // determine the size of the hole
-        while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
-            nh++;
-        }
-
-        uint32_t nf = 0;
-        uint32_t is = n_kv - 1;
-
-        // starting from the end, find nh non-empty cells
-        for (; is > i0; --is) {
-            if (cells.is_empty(is) || ids[is] != n_kv) {
-                continue;
-            }
-
-            // non-empty cell which is not yet moved
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        // this can only happen if `n_used` is not accurate, which would be a bug
-        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
-
-        nf = 0;
-
-        uint32_t i1 = is;
-
-        // are we moving a continuous block of memory?
-        bool cont = false;
-
-        // should we stop searching for the next move?
-        bool stop = false;
-
-        // go back and move the nf cells to the hole
-        for (; i1 < n_kv; ++i1) {
-            if (cells.is_empty(i1) || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
-                cont = false;
-                continue;
-            }
-
-            // this cell goes to (i0 + nf)
-            ids[i1] = i0 + nf;
-
-            if (!cont) {
-                n_moves++;
-                cont = true;
-            }
-
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
-        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
-
-        i0 += nh - 1;
-    }
-
-    if (n_moves == 0) {
-        return {};
-    }
-
-    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
-
-    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
-
-    return res;
-}
-
-bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    assert(p0 >= 0 && p1 >= 0);
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:
-            {
-            } break;
-        case LLAMA_SWA_TYPE_STANDARD:
-            {
-                if (p1 - p0 >= (int32_t) n_swa) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_CHUNKED:
-            {
-                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
-                if (p0 < pos_chunk_start) {
-                    return true;
-                }
-            } break;
-    }
-
-    return false;
-}
-
-void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    GGML_UNUSED(flags);
-
-    io.write(&n_stream, sizeof(n_stream));
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        cell_ranges_t cr { s, {} };
-
-        uint32_t cell_count = 0;
-
-        const auto & cells = v_cells[s];
-
-        // Count the number of cells with the specified seq_id
-        // Find all the ranges of cells with this seq id (or all, when -1)
-        uint32_t cell_range_begin = cells.size();
-
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
-                ++cell_count;
-                if (cell_range_begin == cells.size()) {
-                    cell_range_begin = i;
-                }
-            } else {
-                if (cell_range_begin != cells.size()) {
-                    cr.data.emplace_back(cell_range_begin, i);
-                    cell_range_begin = cells.size();
-                }
-            }
-        }
-
-        if (cell_range_begin != cells.size()) {
-            cr.data.emplace_back(cell_range_begin, cells.size());
-        }
-
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cr.data) {
-            cell_count_check += range.second - range.first;
-        }
-        GGML_ASSERT(cell_count == cell_count_check);
-
-        io.write(&cell_count, sizeof(cell_count));
-
-        // skip empty streams
-        if (cell_count == 0) {
-            continue;
-        }
-
-        state_write_meta(io, cr, seq_id);
-        state_write_data(io, cr);
-    }
-}
-
-void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    GGML_UNUSED(flags);
-
-    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
-
-    uint32_t n_stream_cur;
-    io.read_to(&n_stream_cur, sizeof(n_stream_cur));
-    if (n_stream_cur != n_stream) {
-        throw std::runtime_error("n_stream mismatch");
-    }
-
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        uint32_t cell_count;
-        io.read_to(&cell_count, sizeof(cell_count));
-
-        if (cell_count == 0) {
-            continue;
-        }
-
-        const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
-
-        bool res = true;
-        res = res && state_read_meta(io, strm, cell_count, seq_id);
-        res = res && state_read_data(io, strm, cell_count);
-
-        if (!res) {
-            if (seq_id == -1) {
-                clear(true);
-            } else {
-                seq_rm(seq_id, -1, -1);
-            }
-            throw std::runtime_error("failed to restore kv cache");
-        }
-    }
-}
-
-void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
-    const auto & cells = v_cells[cr.strm];
-
-    for (const auto & range : cr.data) {
-        for (uint32_t i = range.first; i < range.second; ++i) {
-            std::vector<llama_seq_id> seq_ids;
-
-            for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
-                if (cur == seq_id || seq_id == -1) {
-                    if (cells.seq_has(i, cur)) {
-                        seq_ids.push_back(cur);
-                    }
-                }
-            }
-
-            const llama_pos pos     = cells.pos_get(i);
-            const uint32_t n_seq_id = seq_ids.size();
-
-            io.write(&pos,      sizeof(pos));
-            io.write(&n_seq_id, sizeof(n_seq_id));
-
-            for (const auto & seq_id : seq_ids) {
-                io.write(&seq_id, sizeof(seq_id));
-            }
-        }
-    }
-}
-
-void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
-    const auto & cells = v_cells[cr.strm];
-
-    const uint32_t v_trans = this->v_trans ? 1 : 0;
-    const uint32_t n_layer = layers.size();
-
-    io.write(&v_trans, sizeof(v_trans));
-    io.write(&n_layer, sizeof(n_layer));
-
-    std::vector<uint8_t> tmp_buf;
-
-    // Iterate and write all the keys first, each row is a cell
-    // Get whole range at a time
-    for (const auto & layer : layers) {
-        const uint32_t il = layer.il;
-
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        auto * k = layer.k_stream[cr.strm];
-
-        // Write key type
-        const int32_t k_type_i = (int32_t) k->type;
-        io.write(&k_type_i, sizeof(k_type_i));
-
-        // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
-        io.write(&k_size_row, sizeof(k_size_row));
-
-        // Read each range of cells of k_size length each into tmp_buf and write out
-        for (const auto & range : cr.data) {
-            const size_t range_size = range.second - range.first;
-            const size_t buf_size = range_size * k_size_row;
-            io.write_tensor(k, range.first * k_size_row, buf_size);
-        }
-    }
-
-    if (!v_trans) {
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            auto * v = layer.v_stream[cr.strm];
-
-            // Write value type
-            const int32_t v_type_i = (int32_t) v->type;
-            io.write(&v_type_i, sizeof(v_type_i));
-
-            // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
-            io.write(&v_size_row, sizeof(v_size_row));
-
-            // Read each range of cells of v_size length each into tmp_buf and write out
-            for (const auto & range : cr.data) {
-                const size_t range_size = range.second - range.first;
-                const size_t buf_size = range_size * v_size_row;
-                io.write_tensor(v, range.first * v_size_row, buf_size);
-            }
-        }
-    } else {
-        // When v is transposed, we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = cells.size();
-
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            auto * v = layer.v_stream[cr.strm];
-
-            // Write value type
-            const int32_t v_type_i = (int32_t) v->type;
-            io.write(&v_type_i, sizeof(v_type_i));
-
-            // Write element size
-            const uint32_t v_size_el = ggml_type_size(v->type);
-            io.write(&v_size_el, sizeof(v_size_el));
-
-            // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
-
-            // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                // Read each range of cells of v_size_el length each into tmp_buf and write out
-                for (const auto & range : cr.data) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                    const size_t buf_size = range_size * v_size_el;
-                    io.write_tensor(v, src_offset, buf_size);
-                }
-            }
-        }
-    }
-}
-
-bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
-    auto & cells = v_cells[strm];
-    auto & head  = v_heads[strm];
-
-    if (dest_seq_id != -1) {
-        // single sequence
-        seq_rm(dest_seq_id, -1, -1);
-
-        llama_batch_allocr balloc(hparams.n_pos_per_embd());
-
-        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
-
-        ubatch.seq_id_unq[0] = dest_seq_id;
-
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_pos pos;
-            uint32_t n_seq_id;
-
-            io.read_to(&pos,      sizeof(pos));
-            io.read_to(&n_seq_id, sizeof(n_seq_id));
-
-            if (n_seq_id != 1) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
-                return false;
-            }
-
-            // read the sequence id, but directly discard it - we will use dest_seq_id instead
-            {
-                llama_seq_id seq_id;
-                io.read_to(&seq_id, sizeof(seq_id));
-            }
-
-            ubatch.pos[i]      = pos;
-            ubatch.n_seq_id[i] = n_seq_id;
-            ubatch.seq_id[i]   = &dest_seq_id;
-        }
-
-        const auto sinfo = find_slot(ubatch, true);
-        if (sinfo.empty()) {
-            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-            return false;
-        }
-
-        apply_ubatch(sinfo, ubatch);
-
-        const auto head_cur = sinfo.head();
-
-        // keep the head at the old position because we will read the KV data into it in state_read_data()
-        head = head_cur;
-
-        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
-
-        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
-        GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
-        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
-    } else {
-        // whole KV cache restore
-
-        if (cell_count > cells.size()) {
-            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
-            return false;
-        }
-
-        clear(true);
-
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_pos pos;
-            uint32_t  n_seq_id;
-
-            io.read_to(&pos,      sizeof(pos));
-            io.read_to(&n_seq_id, sizeof(n_seq_id));
-
-            cells.pos_set(i, pos);
-
-            for (uint32_t j = 0; j < n_seq_id; ++j) {
-                llama_seq_id seq_id;
-                io.read_to(&seq_id, sizeof(seq_id));
-
-                if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
-                    return false;
-                }
-
-                cells.seq_add(i, seq_id);
-            }
-        }
-
-        head = 0;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
-    auto & cells = v_cells[strm];
-    auto & head  = v_heads[strm];
-
-    uint32_t v_trans;
-    uint32_t n_layer;
-
-    io.read_to(&v_trans, sizeof(v_trans));
-    io.read_to(&n_layer, sizeof(n_layer));
-
-    if (n_layer != layers.size()) {
-        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
-        return false;
-    }
-
-    if (cell_count > cells.size()) {
-        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
-        return false;
-    }
-
-    if (this->v_trans != (bool) v_trans) {
-        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
-        return false;
-    }
-
-    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-    for (const auto & layer : layers) {
-        const uint32_t il = layer.il;
-
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        auto * k = layer.k_stream[strm];
-
-        // Read type of key
-        int32_t k_type_i_ref;
-        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) k->type;
-        if (k_type_i != k_type_i_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-            return false;
-        }
-
-        // Read row size of key
-        uint64_t k_size_row_ref;
-        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
-            return false;
-        }
-
-        if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
-        }
-    }
-
-    if (!this->v_trans) {
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            auto * v = layer.v_stream[strm];
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t) v->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return false;
-            }
-
-            // Read row size of value
-            uint64_t v_size_row_ref;
-            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
-            if (v_size_row != v_size_row_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
-                return false;
-            }
-
-            if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
-            }
-        }
-    } else {
-        // For each layer, read the values for each cell (transposed)
-        for (const auto & layer : layers) {
-            const uint32_t il = layer.il;
-
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            auto * v = layer.v_stream[strm];
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t) v->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return false;
-            }
-
-            // Read element size of value
-            uint32_t v_size_el_ref;
-            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-            const size_t v_size_el = ggml_type_size(v->type);
-            if (v_size_el != v_size_el_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
-                return false;
-            }
-
-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
-                return false;
-            }
-
-            if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
-                }
-            }
-        }
-    }
-
-    return true;
-}
-
-//
-// llama_kv_cache_unified_context
-//
-
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
-
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
-    n_kv = kv->get_size();
-
-    const uint32_t n_stream = kv->get_n_stream();
-
-    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
-    sinfos.resize(1);
-    sinfos[0].s0 = 0;
-    sinfos[0].s1 = n_stream - 1;
-    sinfos[0].idxs.resize(n_stream);
-    for (uint32_t s = 0; s < n_stream; ++s) {
-        sinfos[0].strm.push_back(s);
-        sinfos[0].idxs[s].resize(1, 0);
-    }
-}
-
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv,
-        llama_context * lctx,
-        bool do_shift,
-        defrag_info dinfo,
-        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) {
-    if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) {
-        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
-    }
-}
-
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::slot_info_vec_t sinfos,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
-}
-
-llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
-
-bool llama_kv_cache_unified_context::next() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    if (++i_cur >= ubatches.size()) {
-        return false;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_unified_context::apply() {
-    assert(!llama_memory_status_is_fail(status));
-
-    // no ubatches -> this is a KV cache update
-    if (ubatches.empty()) {
-        kv->update(lctx, do_shift, dinfo, sc_info);
-
-        return true;
-    }
-
-    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
-
-    n_kv = kv->get_n_kv();
-
-    return true;
-}
-
-llama_memory_status llama_kv_cache_unified_context::get_status() const {
-    return status;
-}
-
-const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return ubatches[i_cur];
-}
-
-uint32_t llama_kv_cache_unified_context::get_n_kv() const {
-    return n_kv;
-}
-
-bool llama_kv_cache_unified_context::get_supports_set_rows() const {
-    return kv->get_supports_set_rows();
-}
-
-ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
-    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
-    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
-    return kv->build_input_k_idxs(ctx, ubatch);
-}
-
-ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
-    return kv->build_input_v_idxs(ctx, ubatch);
-}
-
-void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
-    kv->set_input_k_shift(dst);
-}
-
-void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
-}
-
-void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
-}
-
-void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    kv->set_input_kq_mask(dst, ubatch, causal_attn);
-}
-
-void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_pos_bucket(dst, ubatch);
-}
-
-uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
diff --git a/examples/talk-llama/llama-kv-cache-unified.h b/examples/talk-llama/llama-kv-cache-unified.h
deleted file mode 100644 (file)
index 07a7c9e..0000000
+++ /dev/null
@@ -1,399 +0,0 @@
-#pragma once
-
-#include "llama-batch.h"
-#include "llama-graph.h"
-#include "llama-kv-cells.h"
-#include "llama-memory.h"
-
-#include <unordered_map>
-#include <vector>
-
-struct llama_cparams;
-struct llama_hparams;
-struct llama_model;
-struct llama_context;
-
-//
-// llama_kv_cache_unified
-//
-
-class llama_kv_cache_unified : public llama_memory_i {
-public:
-    static uint32_t get_padding(const llama_cparams & cparams);
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
-    struct defrag_info {
-        bool empty() const {
-            return ids.empty();
-        }
-
-        // contains information about which cell moves where:
-        //  - cell i moves to ids[i]
-        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
-        std::vector<uint32_t> ids;
-    };
-
-    struct stream_copy_info {
-        bool empty() const {
-            assert(ssrc.size() == sdst.size());
-            return ssrc.empty();
-        }
-
-        std::vector<uint32_t> ssrc;
-        std::vector<uint32_t> sdst;
-    };
-
-    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
-    //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
-    struct slot_info {
-        // data for ggml_set_rows
-        using idx_vec_t = std::vector<uint32_t>;
-
-        // number of streams: ns = s1 - s0 + 1
-        llama_seq_id s0;
-        llama_seq_id s1;
-
-        std::vector<llama_seq_id> strm; // [ns]
-        std::vector<idx_vec_t>    idxs; // [ns]
-
-        uint32_t head() const {
-            GGML_ASSERT(idxs.size() == 1);
-            GGML_ASSERT(!idxs[0].empty());
-
-            return idxs[0][0];
-        }
-
-        void resize(size_t n) {
-            strm.resize(n);
-            idxs.resize(n);
-        }
-
-        size_t size() const {
-            GGML_ASSERT(idxs.size() == strm.size());
-            GGML_ASSERT(!idxs.empty());
-
-            return idxs[0].size();
-        }
-
-        size_t n_stream() const {
-            return strm.size();
-        }
-
-        bool empty() const {
-            return idxs.empty();
-        }
-
-        void clear() {
-            idxs.clear();
-        }
-    };
-
-    using slot_info_vec_t = std::vector<slot_info>;
-
-    llama_kv_cache_unified(
-            const llama_model &  model,
-              layer_filter_cb && filter,
-                    ggml_type    type_k,
-                    ggml_type    type_v,
-                         bool    v_trans,
-                         bool    offload,
-                         bool    unified,
-                     uint32_t    kv_size,
-                     uint32_t    n_seq_max,
-                     uint32_t    n_pad,
-                     uint32_t    n_swa,
-               llama_swa_type    swa_type);
-
-    ~llama_kv_cache_unified() = default;
-
-    //
-    // llama_memory_i
-    //
-
-    llama_memory_context_ptr init_batch(
-            llama_batch_allocr & balloc,
-            uint32_t n_ubatch,
-            bool embd_all) override;
-
-    llama_memory_context_ptr init_full() override;
-
-    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
-
-    bool get_can_shift() const override;
-
-    void clear(bool data) override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    // state write/load
-
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
-
-    //
-    // llama_kv_cache_unified specific API
-    //
-
-    uint32_t get_size()     const;
-    uint32_t get_n_stream() const;
-
-    bool get_has_shift() const;
-
-    //
-    // graph_build API
-    //
-
-    uint32_t get_n_kv() const;
-
-    // TODO: temporary
-    bool get_supports_set_rows() const;
-
-    // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
-
-    // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
-
-    //
-    // preparation API
-    //
-
-    // find places for the provided ubatches in the cache, returns the slot infos
-    // return empty vector on failure
-    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
-
-    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
-
-    // find a slot of kv cells that can hold the ubatch
-    // if cont == true, then the slot must be continuous
-    // return empty slot_info on failure
-    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
-
-    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
-    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
-
-    //
-    // input API
-    //
-
-    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
-    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
-
-    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
-    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
-
-    void set_input_k_shift(ggml_tensor * dst) const;
-
-    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
-
-private:
-    const llama_model & model;
-    const llama_hparams & hparams;
-
-    struct kv_layer {
-        // layer index in the model
-        // note: can be different from the layer index in the KV cache
-        uint32_t il;
-
-        ggml_tensor * k;
-        ggml_tensor * v;
-
-        std::vector<ggml_tensor *> k_stream;
-        std::vector<ggml_tensor *> v_stream;
-    };
-
-    bool v_trans = true;  // the value tensor is transposed
-
-    const uint32_t n_seq_max = 1;
-    const uint32_t n_stream  = 1;
-
-    // required padding
-    const uint32_t n_pad = 1;
-
-    // SWA
-    const uint32_t n_swa = 0;
-
-    // env: LLAMA_KV_CACHE_DEBUG
-    int debug = 0;
-
-    // env: LLAMA_SET_ROWS (temporary)
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
-    bool supports_set_rows = true;
-
-    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
-    std::vector<ggml_context_ptr>        ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-    std::vector<uint32_t> v_heads;
-
-    std::vector<llama_kv_cells_unified> v_cells;
-
-    // maps from a sequence id to a stream id
-    std::vector<uint32_t> seq_to_stream;
-
-    // pending stream copies that will be applied during the next update
-    stream_copy_info sc_info;
-
-    std::vector<kv_layer> layers;
-
-    // model layer id -> KV cache layer id
-    std::unordered_map<int32_t, int32_t> map_layer_ids;
-
-    // return non-empty vector if cells have been moved
-    defrag_info defrag_prepare(int32_t n_max_nodes) const;
-
-    size_t total_size() const;
-
-    size_t size_k_bytes() const;
-    size_t size_v_bytes() const;
-
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
-
-    ggml_tensor * build_rope_shift(
-            const llama_cparams & cparams,
-                   ggml_context * ctx,
-                    ggml_tensor * cur,
-                    ggml_tensor * shift,
-                    ggml_tensor * factors,
-                          float   freq_base,
-                          float   freq_scale) const;
-
-    ggml_cgraph * build_graph_shift(
-               llm_graph_result * res,
-                  llama_context * lctx) const;
-
-    ggml_cgraph * build_graph_defrag(
-               llm_graph_result * res,
-                  llama_context * lctx,
-              const defrag_info & dinfo) const;
-
-    struct cell_ranges_t {
-        uint32_t strm;
-
-        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
-    };
-
-    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
-
-    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
-};
-
-class llama_kv_cache_unified_context : public llama_memory_context_i {
-public:
-    // some shorthands
-    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
-    using defrag_info      = llama_kv_cache_unified::defrag_info;
-    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
-
-    // used for errors
-    llama_kv_cache_unified_context(llama_memory_status status);
-
-    // used to create a full-cache context
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv);
-
-    // used to create an update context
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv,
-            llama_context * lctx,
-            bool do_shift,
-            defrag_info dinfo,
-            stream_copy_info sc_info);
-
-    // used to create a batch procesing context from a batch
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv,
-            slot_info_vec_t sinfos,
-            std::vector<llama_ubatch> ubatches);
-
-    virtual ~llama_kv_cache_unified_context();
-
-    //
-    // llama_memory_context_i
-    //
-
-    bool next()  override;
-    bool apply() override;
-
-    llama_memory_status  get_status() const override;
-    const llama_ubatch & get_ubatch() const override;
-
-    //
-    // llama_kv_cache_unified_context specific API
-    //
-
-    uint32_t get_n_kv() const;
-
-    // TODO: temporary
-    bool get_supports_set_rows() const;
-
-    // get views of the current state of the cache
-    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
-    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
-
-    // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
-
-    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
-    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
-
-    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
-    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
-
-    void set_input_k_shift   (ggml_tensor * dst) const;
-    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
-    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
-
-private:
-    llama_memory_status status;
-
-    llama_kv_cache_unified * kv;
-    llama_context * lctx;
-
-    //
-    // update context
-    //
-
-    bool do_shift = false;
-
-    defrag_info dinfo;
-
-    stream_copy_info sc_info;
-
-    //
-    // batch processing context
-    //
-
-    // the index of the cur ubatch to process
-    size_t i_cur = 0;
-
-    slot_info_vec_t sinfos;
-
-    std::vector<llama_ubatch> ubatches;
-
-    //
-    // data needed for building the compute graph for the current ubatch:
-    //
-
-    // a heuristic, to avoid attending the full cache if it is not yet utilized
-    // as the cache gets filled, the benefit from this heuristic disappears
-    int32_t n_kv;
-};
diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp
new file mode 100644 (file)
index 0000000..885be07
--- /dev/null
@@ -0,0 +1,2012 @@
+#include "llama-kv-cache.h"
+
+#include "llama-impl.h"
+#include "llama-io.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <limits>
+#include <map>
+#include <stdexcept>
+
+//
+// llama_kv_cache
+//
+
+llama_kv_cache::llama_kv_cache(
+        const llama_model & model,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                     bool   offload,
+                     bool   unified,
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max,
+                 uint32_t   n_pad,
+                 uint32_t   n_swa,
+           llama_swa_type   swa_type,
+    const layer_filter_cb & filter,
+    const  layer_reuse_cb & reuse) :
+    model(model), hparams(model.hparams), v_trans(v_trans),
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+
+    GGML_ASSERT(kv_size % n_pad == 0);
+
+    const uint32_t n_layer_kv = hparams.n_layer_kv();
+
+    // create a context for each buffer type
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                return nullptr;
+            }
+
+            ctx_map[buft] = ctx;
+            ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
+
+    v_heads.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_heads[s] = 0;
+    }
+
+    v_cells.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_cells[s].resize(kv_size);
+    }
+
+    // by default, all sequence ids are mapped to the 0th stream
+    seq_to_stream.resize(LLAMA_MAX_SEQ, 0);
+
+    if (n_stream > 1) {
+        seq_to_stream.resize(n_stream, 0);
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            seq_to_stream[s] = s;
+        }
+    }
+
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
+    }
+
+    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+        if (!hparams.has_kv(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
+            continue;
+        }
+
+        if (filter && !filter(il)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
+            continue;
+        }
+
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa =            hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
+
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(il);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name);
+
+        ggml_context * ctx = ctx_for_buft(buft);
+        if (!ctx) {
+            throw std::runtime_error("failed to create ggml context for kv cache");
+        }
+
+        ggml_tensor * k;
+        ggml_tensor * v;
+
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
+
+        ggml_format_name(k, "cache_k_l%d", il);
+        ggml_format_name(v, "cache_v_l%d", il);
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
+
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+            v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+        }
+
+        map_layer_ids[il] = layers.size();
+
+        layers.push_back({ il, k, v, k_stream, v_stream, });
+    }
+
+    if (reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
+
+        for (uint32_t il = 0; il < hparams.n_layer; il++) {
+            const int32_t il_reuse = reuse(il);
+
+            if (il_reuse < 0) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
+                continue;
+            }
+
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
+                continue;
+            }
+
+            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+
+            map_layer_ids[il] = map_layer_ids[il_reuse];
+
+            LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
+        }
+    }
+
+    // allocate tensors and initialize the buffers to avoid NaNs in the padding
+    for (auto it : ctx_map) {
+        auto * buft = it.first;
+        auto * ctx  = it.second;
+
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            throw std::runtime_error("failed to allocate buffer for kv cache");
+        }
+
+        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+
+        ggml_backend_buffer_clear(buf, 0);
+        bufs.emplace_back(buf);
+    }
+
+    {
+        const size_t memory_size_k = size_k_bytes();
+        const size_t memory_size_v = size_v_bytes();
+
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+    }
+
+    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
+    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+}
+
+void llama_kv_cache::clear(bool data) {
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        v_cells[s].reset();
+        v_heads[s] = 0;
+    }
+
+    if (data) {
+        for (auto & buf : bufs) {
+            ggml_backend_buffer_clear(buf.get(), 0);
+        }
+    }
+}
+
+bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    if (seq_id >= 0) {
+        auto & cells = v_cells[seq_to_stream[seq_id]];
+        auto & head  = v_heads[seq_to_stream[seq_id]];
+
+        uint32_t new_head = cells.size();
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
+            }
+        }
+
+        // If we freed up a slot, set head to it so searching can start there.
+        if (new_head != cells.size() && new_head < head) {
+            head = new_head;
+        }
+    } else {
+        // match any sequence
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            auto & cells = v_cells[s];
+            auto & head  = v_heads[s];
+
+            uint32_t new_head = cells.size();
+
+            for (uint32_t i = 0; i < cells.size(); ++i) {
+                if (!cells.pos_in(i, p0, p1)) {
+                    continue;
+                }
+
+                cells.rm(i);
+
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
+            }
+
+            // If we freed up a slot, set head to it so searching can start there.
+            if (new_head != cells.size() && new_head < head) {
+                head = new_head;
+            }
+        }
+    }
+
+    return true;
+}
+
+void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
+    GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
+
+    const auto s0 = seq_to_stream[seq_id_src];
+    const auto s1 = seq_to_stream[seq_id_dst];
+
+    if (s0 == s1) {
+        // since both sequences are in the same stream, no data copy is necessary
+        // we just have to update the cells meta data
+
+        auto & cells = v_cells[s0];
+
+        if (seq_id_src == seq_id_dst) {
+            return;
+        }
+
+        if (p0 < 0) {
+            p0 = 0;
+        }
+
+        if (p1 < 0) {
+            p1 = std::numeric_limits<llama_pos>::max();
+        }
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            if (cells.seq_has(i, seq_id_src)) {
+                cells.seq_add(i, seq_id_dst);
+            }
+        }
+
+        return;
+    }
+
+    // cross-stream sequence copies require to copy the actual buffer data
+
+    bool is_full = true;
+
+    if (p0 > 0 && p0 + 1 < (int) get_size()) {
+        is_full = false;
+    }
+
+    if (p1 > 0 && p1 + 1 < (int) get_size()) {
+        is_full = false;
+    }
+
+    GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
+
+    // enqueue the copy operation - the buffer copy will be performed during the next update
+    sc_info.ssrc.push_back(s0);
+    sc_info.sdst.push_back(s1);
+
+    v_cells[s1].reset();
+    for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
+        if (v_cells[s0].seq_has(i, seq_id_src)) {
+            llama_pos pos   = v_cells[s0].pos_get(i);
+            llama_pos shift = v_cells[s0].get_shift(i);
+
+            if (shift != 0) {
+                pos -= shift;
+                assert(pos >= 0);
+            }
+
+            v_cells[s1].pos_set(i, pos);
+            v_cells[s1].seq_add(i, seq_id_dst);
+
+            if (shift != 0) {
+                v_cells[s1].pos_add(i, shift);
+            }
+        }
+    }
+
+    v_heads[s1] = v_heads[s0];
+
+    //for (uint32_t s = 0; s < n_stream; ++s) {
+    //    LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
+    //}
+}
+
+void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
+    uint32_t new_head = cells.size();
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_keep(i, seq_id)) {
+            if (new_head == cells.size()) {
+                new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cells.size() && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+    auto & head  = v_heads[seq_to_stream[seq_id]];
+
+    if (shift == 0) {
+        return;
+    }
+
+    uint32_t new_head = cells.size();
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over all cells.
+    if (p0 == p1) {
+        return;
+    }
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
+
+        if (cells.seq_has(i, seq_id)) {
+            if (cells.pos_add(i, shift)) {
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    head = new_head != cells.size() ? new_head : 0;
+}
+
+void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    auto & cells = v_cells[seq_to_stream[seq_id]];
+
+    if (d == 1) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
+
+        if (cells.seq_has(i, seq_id)) {
+            cells.pos_div(i, d);
+        }
+    }
+}
+
+llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+    return cells.seq_pos_min(seq_id);
+}
+
+llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
+    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
+
+    const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+    return cells.seq_pos_max(seq_id);
+}
+
+llama_memory_context_ptr llama_kv_cache::init_batch(
+            llama_batch_allocr & balloc,
+            uint32_t n_ubatch,
+            bool embd_all) {
+    GGML_UNUSED(embd_all);
+
+    do {
+        balloc.split_reset();
+
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
+        auto sinfos = prepare(ubatches);
+        if (sinfos.empty()) {
+            break;
+        }
+
+        return std::make_unique<llama_kv_cache_context>(
+                this, std::move(sinfos), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+}
+
+llama_memory_context_ptr llama_kv_cache::init_full() {
+    return std::make_unique<llama_kv_cache_context>(this);
+}
+
+llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(optimize);
+
+    bool do_shift = get_has_shift();
+
+    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(sc_info));
+}
+
+llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache::slot_info_vec_t res;
+
+    struct state_t {
+        slot_info sinfo; // slot info for the ubatch
+
+        std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+        std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
+    };
+
+    // remember the old state of the cells so we can restore it in the end
+    std::vector<state_t> states;
+
+    bool success = true;
+
+    for (const auto & ubatch : ubatches) {
+        // only find a suitable slot for the ubatch. don't modify the cells yet
+        const auto sinfo_new = find_slot(ubatch, false);
+        if (sinfo_new.empty()) {
+            success = false;
+            break;
+        }
+
+        // remeber the position that we found
+        res.push_back(sinfo_new);
+
+        // store the old state of the cells in the recovery stack
+        {
+            state_t state = { sinfo_new, v_heads, {} };
+
+            for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) {
+                auto & cells = v_cells[sinfo_new.strm[s]];
+
+                state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+            }
+
+            states.push_back(std::move(state));
+        }
+
+        // now emplace the ubatch
+        apply_ubatch(sinfo_new, ubatch);
+    }
+
+    GGML_ASSERT(!states.empty() || !success);
+
+    // iterate backwards and restore the cells to their original state
+    for (auto it = states.rbegin(); it != states.rend(); ++it) {
+        const auto & sinfo = it->sinfo;
+
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            auto & cells = v_cells[sinfo.strm[s]];
+            auto & head  = v_heads[sinfo.strm[s]];
+
+            cells.set(sinfo.idxs[s], it->v_cells[s]);
+            head = it->v_heads_old[s];
+        }
+    }
+
+    if (!success) {
+        return {};
+    }
+
+    return res;
+}
+
+bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
+    bool updated = false;
+
+    auto * sched = lctx->get_sched();
+
+    if (!sc_info.empty()) {
+        assert(n_stream > 1 && "stream copy should never happen with a single stream");
+
+        llama_synchronize(lctx);
+
+        const size_t n_copy = sc_info.ssrc.size();
+
+        for (size_t i = 0; i < n_copy; ++i) {
+            const auto ssrc = sc_info.ssrc[i];
+            const auto sdst = sc_info.sdst[i];
+
+            assert(ssrc < n_stream);
+            assert(sdst < n_stream);
+
+            LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst);
+
+            assert(ssrc != sdst);
+
+            for (uint32_t il = 0; il < layers.size(); ++il) {
+                const auto & layer = layers[il];
+
+                ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
+                ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
+            }
+        }
+    }
+
+    if (do_shift) {
+        if (!get_can_shift()) {
+            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
+        }
+
+        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
+
+        // apply K-shift if needed
+        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched);
+
+            auto * res = lctx->get_gf_res_reserve();
+
+            res->reset();
+
+            auto * gf = build_graph_shift(res, lctx);
+            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__);
+                return updated;
+            }
+
+            res->set_inputs(nullptr);
+
+            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+                LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
+                return updated;
+            }
+
+            updated = true;
+        }
+
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            auto & cells = v_cells[s];
+
+            cells.reset_shift();
+        }
+    }
+
+    return updated;
+}
+
+llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
+
+    if (debug > 0) {
+        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+            const auto seq_id = ubatch.seq_id_unq[s];
+            const auto stream_id = seq_to_stream[seq_id];
+            const auto & cells = v_cells[stream_id];
+            const uint32_t head_cur = v_heads[stream_id];
+
+            LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+                    __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
+
+            if ((debug == 2 && n_swa > 0) || debug > 2) {
+                std::string ss;
+                for (uint32_t i = 0; i < cells.size(); ++i) {
+                    if (cells.is_empty(i)) {
+                        ss += '.';
+                    } else {
+                        assert(cells.seq_count(i) >= 1);
+
+                        if (cells.seq_count(i) == 1) {
+                            ss += std::to_string(cells.seq_get(i));
+                        } else {
+                            ss += 'M';
+                        }
+                    }
+                    if (i%256 == 255) {
+                        ss += " *";
+                        ss += '\n';
+                    }
+                }
+                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
+            }
+
+            if ((debug == 2 && n_swa > 0) || debug > 2) {
+                std::string ss;
+                for (uint32_t i = 0; i < cells.size(); ++i) {
+                    std::string cur;
+                    if (cells.is_empty(i)) {
+                        cur = '.';
+                    } else {
+                        cur = std::to_string(cells.pos_get(i));
+                    }
+                    const int n = cur.size();
+                    for (int j = 0; j < 5 - n; ++j) {
+                        cur += ' ';
+                    }
+                    ss += cur;
+                    if (i%256 == 255) {
+                        ss += " *";
+                    }
+                    if (i%64 == 63) {
+                        ss += '\n';
+                    }
+                }
+                LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
+            }
+
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+                if (cells.seq_pos_min(s) < 0) {
+                    continue;
+                }
+
+                LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
+            }
+        }
+    }
+
+    uint32_t n_tokens = ubatch.n_tokens;
+    uint32_t n_seqs   = 1;
+
+    if (n_stream > 1) {
+        GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
+
+        n_seqs   = ubatch.n_seqs_unq;
+        n_tokens = n_tokens / n_seqs;
+    }
+
+    slot_info res = {
+        /*.s0   =*/ LLAMA_MAX_SEQ,
+        /*.s1   =*/ 0,
+        /*.strm =*/ { },
+        /*.idxs =*/ { },
+    };
+
+    res.resize(n_seqs);
+
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const auto seq_id = ubatch.seq_id_unq[s];
+
+        if (n_stream > 1) {
+            GGML_ASSERT(ubatch.n_seq_id[s*n_tokens]    == 1);
+            GGML_ASSERT(ubatch.seq_id  [s*n_tokens][0] == seq_id);
+        }
+
+        res.s0 = std::min<uint32_t>(res.s0, seq_to_stream[seq_id]);
+        res.s1 = std::max<uint32_t>(res.s1, seq_to_stream[seq_id]);
+
+        res.strm[s] = seq_to_stream[seq_id];
+        res.idxs[s].reserve(n_tokens);
+
+        const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+        uint32_t head_cur = v_heads[seq_to_stream[seq_id]];
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (head_cur > cells.get_used() + 2*n_tokens) {
+            head_cur = 0;
+        }
+
+        if (n_tokens > cells.size()) {
+            LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+            return { };
+        }
+
+        uint32_t n_tested = 0;
+
+        // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+        // for non-continuous slots, we test the tokens one by one
+        const uint32_t n_test = cont ? n_tokens : 1;
+
+        while (true) {
+            if (head_cur + n_test > cells.size()) {
+                n_tested += cells.size() - head_cur;
+                head_cur = 0;
+                continue;
+            }
+
+            for (uint32_t i = 0; i < n_test; i++) {
+                const auto idx = head_cur;
+
+                head_cur++;
+                n_tested++;
+
+                //const llama_pos    pos    = ubatch.pos[i];
+                //const llama_seq_id seq_id = ubatch.seq_id[i][0];
+
+                // can we use this cell? either:
+                //  - the cell is empty
+                //  - the cell is occupied only by one sequence:
+                //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
+                //    - mask SWA, using current max pos for that sequence in the cache
+                //                always insert in the cell with minimum pos
+                bool can_use = cells.is_empty(idx);
+
+                if (!can_use && cells.seq_count(idx) == 1) {
+                    const llama_pos pos_cell = cells.pos_get(idx);
+
+                    // (disabled) causal mask
+                    // note: it's better to purge any "future" tokens beforehand
+                    //if (cells.seq_has(idx, seq_id)) {
+                    //    can_use = pos_cell >= pos;
+                    //}
+
+                    if (!can_use) {
+                        const llama_seq_id seq_id_cell = cells.seq_get(idx);
+
+                        // SWA mask
+                        if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                            can_use = true;
+                        }
+                    }
+                }
+
+                if (can_use) {
+                    res.idxs[s].push_back(idx);
+                } else {
+                    if (cont) {
+                        break;
+                    }
+                }
+            }
+
+            if (res.idxs[s].size() == n_tokens) {
+                break;
+            }
+
+            if (cont) {
+                res.idxs[s].clear();
+            }
+
+            if (n_tested >= cells.size()) {
+                //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+                return { };
+            }
+        }
+
+        // we didn't find a suitable slot - return empty result
+        if (res.idxs[s].size() < n_tokens) {
+            return { };
+        }
+    }
+
+    assert(res.s1 >= res.s0);
+
+    return res;
+}
+
+void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
+    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
+
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+            const uint32_t i = s*sinfo.size() + ii;
+
+            auto & cells = v_cells[sinfo.strm[s]];
+
+            const auto idx = sinfo.idxs[s][ii];
+
+            if (!cells.is_empty(idx)) {
+                assert(cells.seq_count(idx) == 1);
+
+                const llama_seq_id seq_id = cells.seq_get(idx);
+                const llama_pos    pos    = cells.pos_get(idx);
+
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(idx);
+            }
+
+            cells.pos_set(idx, ubatch.pos[i]);
+
+            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+                cells.seq_add(idx, ubatch.seq_id[i][s]);
+            }
+        }
+    }
+
+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        GGML_ASSERT(s < seq_to_stream.size());
+
+        auto & cells = v_cells[seq_to_stream[s]];
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
+    // move the head at the end of the slot
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        auto & head = v_heads[sinfo.strm[s]];
+
+        head = sinfo.idxs[s].back() + 1;
+    }
+}
+
+bool llama_kv_cache::get_can_shift() const {
+    return true;
+}
+
+uint32_t llama_kv_cache::get_size() const {
+    const auto & cells = v_cells[seq_to_stream[0]];
+
+    return cells.size();
+}
+
+uint32_t llama_kv_cache::get_n_stream() const {
+    return n_stream;
+}
+
+bool llama_kv_cache::get_has_shift() const {
+    bool result = false;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        result |= v_cells[s].get_has_shift();
+    }
+
+    return result;
+}
+
+uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
+    uint32_t result = 0;
+
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const auto & cells = v_cells[sinfo.strm[s]];
+
+        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+    }
+
+    return result;
+}
+
+ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * k = layers[ikv].k;
+
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_k_gqa = k->ne[0];
+
+    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
+            ggml_row_size(k->type, hparams.n_embd_head_k),
+            ggml_row_size(k->type, n_embd_k_gqa),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
+}
+
+ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * v = layers[ikv].v;
+
+    const uint64_t kv_size      = get_size();
+    const uint64_t n_embd_v_gqa = v->ne[0];
+
+    // [TAG_V_CACHE_VARIABLE]
+    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+    if (!v_trans) {
+        // note: v->nb[1] <= v->nb[2]
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
+                ggml_row_size(v->type, hparams.n_embd_head_v),          // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),                   // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size),           // v->nb[3]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
+    }
+
+    // note: v->nb[1] > v->nb[2]
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v),  // v->nb[1]
+            ggml_row_size(v->type, kv_size),                        // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa),           // v->nb[3]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
+}
+
+ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
+    GGML_UNUSED(sinfo);
+
+    const int32_t ikv = map_layer_ids.at(il);
+
+    ggml_tensor * k = layers[ikv].k;
+
+    const int64_t n_embd_head = k_cur->ne[0];
+    const int64_t n_head      = k_cur->ne[1];
+    const int64_t n_tokens    = k_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
+
+    // we can merge dims 0 and 1
+    // TODO: add ggml helper function for this?
+    GGML_ASSERT(ggml_row_size(k_cur->type, n_embd_head) == k_cur->nb[1]);
+
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_gqa, n_tokens, k_cur->nb[2], 0);
+
+    const int64_t n_stream = k->ne[2];
+
+    if (n_stream > 1) {
+        const int64_t kv_size = get_size();
+
+        assert(n_embd_gqa == k->ne[0]);
+        assert(kv_size    == k->ne[1]);
+
+        // merge the buffer across all streams because the idxs are global
+        k = ggml_reshape_2d(ctx, k, n_embd_gqa, kv_size*n_stream);
+    }
+
+    // store the current K values into the cache
+    return ggml_set_rows(ctx, k, k_cur, k_idxs);
+}
+
+ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
+    GGML_UNUSED(sinfo);
+
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * v = layers[ikv].v;
+
+    const int64_t n_embd_head = v_cur->ne[0];
+    const int64_t n_head      = v_cur->ne[1];
+    const int64_t n_tokens    = v_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_head*n_head;
+
+    // we can merge dims 0 and 1
+    GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);
+
+    const int64_t n_stream = v->ne[2];
+
+    // take this branch when FA is enabled (the V cache is not transposed)
+    if (!v_trans) {
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_gqa, n_tokens, v_cur->nb[2], 0);
+
+        if (n_stream > 1) {
+            const int64_t kv_size = get_size();
+
+            assert(n_embd_gqa == v->ne[0]);
+            assert(kv_size    == v->ne[1]);
+
+            // merge the buffer across all streams because the idxs are global
+            v = ggml_reshape_2d(ctx, v, n_embd_gqa, kv_size*n_stream);
+        }
+
+        return ggml_set_rows(ctx, v, v_cur, v_idxs);
+    }
+
+    if (ggml_row_size(v_cur->type, n_embd_gqa) == v_cur->nb[2]) {
+        // we can merge dims 0, 1 and 2
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens);
+    } else {
+        // otherwise -> make a copy to get contiguous data
+        v_cur = ggml_cont_2d   (ctx, v_cur, n_embd_gqa, n_tokens);
+    }
+
+    // [TAG_V_CACHE_VARIABLE]
+    if (n_embd_gqa < v->ne[0]) {
+        v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_gqa, 0, 0, 0);
+    }
+
+    // in this branch the v_idxs are constructed in such a way that each row is a single head element
+    ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, ggml_nelements(v));
+
+    v_cur = ggml_reshape_2d(ctx, v_cur, 1, ggml_nelements(v_cur));
+
+    return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+}
+
+ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+    }
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        const int64_t offs = sinfo.strm[s]*get_size();
+
+        for (uint32_t i = 0; i < sinfo.size(); ++i) {
+            data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+        }
+    }
+}
+
+void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    if (!v_trans) {
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            const int64_t offs = sinfo.strm[s]*get_size();
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
+            }
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa;
+
+            for (uint32_t i = 0; i < sinfo.size(); ++i) {
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i];
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        const auto & cells = v_cells[s];
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+        }
+    }
+}
+
+void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    float * data = (float *) dst->data;
+
+    const int64_t n_kv     = dst->ne[0];
+    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch
+
+    GGML_ASSERT(n_tokens%n_stream == 0);
+
+    // n_tps == n_tokens_per_stream
+    const int64_t n_tps     = n_tokens/n_stream;
+    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);
+
+    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+
+    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+    //   Causal mask:
+    //      xxx-------
+    //      xxxx------
+    //      xxxxx-----
+    //   Non-causal mask:
+    //      xxxxx-----
+    //      xxxxx-----
+    //      xxxxx-----
+    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+    // TODO: optimize this section
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            for (uint32_t ii = 0; ii < n_tps; ++ii) {
+                const uint32_t i = s*n_tps + ii;
+
+                const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+                const auto & cells = v_cells[seq_to_stream[seq_id]];
+
+                const llama_pos p1 = ubatch->pos[i];
+
+                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
+
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    if (cells.is_empty(j)) {
+                        continue;
+                    }
+
+                    // mask the token if not the same sequence
+                    if (!cells.seq_has(j, seq_id)) {
+                        continue;
+                    }
+
+                    const llama_pos p0 = cells.pos_get(j);
+
+                    // mask future tokens
+                    if (causal_attn && p0 > p1) {
+                        continue;
+                    }
+
+                    // apply SWA if any
+                    if (is_masked_swa(p0, p1)) {
+                        continue;
+                    }
+
+                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
+    const auto & cells = v_cells[0];
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing
+
+    int32_t * data = (int32_t *) dst->data;
+
+    const int32_t n_kv = dst->ne[0];
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_kv; ++j) {
+                // the position when the cells is empty is irrelevant - it will be masked out later in the attention
+                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
+
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
+            }
+        }
+    }
+}
+
+size_t llama_kv_cache::total_size() const {
+    size_t size = 0;
+
+    for (const auto & buf : bufs) {
+        size += ggml_backend_buffer_get_size(buf.get());
+    }
+
+    return size;
+}
+
+size_t llama_kv_cache::size_k_bytes() const {
+    size_t size_k_bytes = 0;
+
+    for (const auto & layer : layers) {
+        size_k_bytes += ggml_nbytes(layer.k);
+    }
+
+    return size_k_bytes;
+}
+
+size_t llama_kv_cache::size_v_bytes() const {
+    size_t size_v_bytes = 0;
+
+    for (const auto & layer : layers) {
+        size_v_bytes += ggml_nbytes(layer.v);
+    }
+
+    return size_v_bytes;
+}
+
+ggml_tensor * llama_kv_cache::build_rope_shift(
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_tensor * cur,
+                ggml_tensor * shift,
+                ggml_tensor * factors,
+                      float   freq_base,
+                      float   freq_scale) const {
+    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
+
+    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
+    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
+    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
+
+    const auto & n_rot     = hparams.n_rot;
+    const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
+                                // @ngxson : this is a workaround
+                                // for M-RoPE, we want to rotate the whole vector when doing KV shift
+                                // a normal RoPE should work, we just need to use the correct ordering
+                                // ref: https://github.com/ggml-org/llama.cpp/pull/13870
+                                ? LLAMA_ROPE_TYPE_NEOX
+                                : hparams.rope_type;
+
+    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
+    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
+                                    ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
+                                    : cparams.yarn_attn_factor;
+
+    ggml_tensor * tmp;
+
+    if (ggml_is_quantized(cur->type)) {
+        // dequantize to f32 -> RoPE -> quantize back
+        tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
+
+        tmp = ggml_rope_ext(ctx, tmp,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+
+        tmp = ggml_cpy(ctx, tmp, cur);
+    } else {
+        // we rotate only the first n_rot dimensions
+        tmp = ggml_rope_ext_inplace(ctx, cur,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+    }
+
+    return tmp;
+}
+
+class llm_graph_input_k_shift : public llm_graph_input_i {
+public:
+    llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_k_shift() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * k_shift; // I32 [kv_size*n_stream]
+
+    const llama_kv_cache * kv_self;
+};
+
+void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    if (k_shift) {
+        kv_self->set_input_k_shift(k_shift);
+    }
+}
+
+ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
+    auto * ctx = res->get_ctx();
+    auto * gf  = res->get_gf();
+
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+  //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    auto inp = std::make_unique<llm_graph_input_k_shift>(this);
+
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
+    ggml_set_input(inp->k_shift);
+
+    const auto & cparams = lctx->get_cparams();
+
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
+        const int64_t n_head_kv    = hparams.n_head_kv(il);
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+        ggml_tensor * k =
+            ggml_view_3d(ctx, layer.k,
+                n_embd_head_k, n_head_kv, get_size()*n_stream,
+                ggml_row_size(layer.k->type, n_embd_head_k),
+                ggml_row_size(layer.k->type, n_embd_k_gqa),
+                0);
+
+        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    res->add_input(std::move(inp));
+
+    return gf;
+}
+
+bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
+}
+
+void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    GGML_UNUSED(flags);
+
+    io.write(&n_stream, sizeof(n_stream));
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        cell_ranges_t cr { s, {} };
+
+        uint32_t cell_count = 0;
+
+        const auto & cells = v_cells[s];
+
+        // Count the number of cells with the specified seq_id
+        // Find all the ranges of cells with this seq id (or all, when -1)
+        uint32_t cell_range_begin = cells.size();
+
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
+                ++cell_count;
+                if (cell_range_begin == cells.size()) {
+                    cell_range_begin = i;
+                }
+            } else {
+                if (cell_range_begin != cells.size()) {
+                    cr.data.emplace_back(cell_range_begin, i);
+                    cell_range_begin = cells.size();
+                }
+            }
+        }
+
+        if (cell_range_begin != cells.size()) {
+            cr.data.emplace_back(cell_range_begin, cells.size());
+        }
+
+        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+        uint32_t cell_count_check = 0;
+        for (const auto & range : cr.data) {
+            cell_count_check += range.second - range.first;
+        }
+        GGML_ASSERT(cell_count == cell_count_check);
+
+        io.write(&cell_count, sizeof(cell_count));
+
+        // skip empty streams
+        if (cell_count == 0) {
+            continue;
+        }
+
+        state_write_meta(io, cr, seq_id);
+        state_write_data(io, cr);
+    }
+}
+
+void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    GGML_UNUSED(flags);
+
+    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
+
+    uint32_t n_stream_cur;
+    io.read_to(&n_stream_cur, sizeof(n_stream_cur));
+    if (n_stream_cur != n_stream) {
+        throw std::runtime_error("n_stream mismatch");
+    }
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        uint32_t cell_count;
+        io.read_to(&cell_count, sizeof(cell_count));
+
+        if (cell_count == 0) {
+            continue;
+        }
+
+        const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
+
+        bool res = true;
+        res = res && state_read_meta(io, strm, cell_count, seq_id);
+        res = res && state_read_data(io, strm, cell_count);
+
+        if (!res) {
+            if (seq_id == -1) {
+                clear(true);
+            } else {
+                seq_rm(seq_id, -1, -1);
+            }
+            throw std::runtime_error("failed to restore kv cache");
+        }
+    }
+}
+
+void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
+    const auto & cells = v_cells[cr.strm];
+
+    for (const auto & range : cr.data) {
+        for (uint32_t i = range.first; i < range.second; ++i) {
+            std::vector<llama_seq_id> seq_ids;
+
+            for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
+                if (cur == seq_id || seq_id == -1) {
+                    if (cells.seq_has(i, cur)) {
+                        seq_ids.push_back(cur);
+                    }
+                }
+            }
+
+            const llama_pos pos     = cells.pos_get(i);
+            const uint32_t n_seq_id = seq_ids.size();
+
+            io.write(&pos,      sizeof(pos));
+            io.write(&n_seq_id, sizeof(n_seq_id));
+
+            for (const auto & seq_id : seq_ids) {
+                io.write(&seq_id, sizeof(seq_id));
+            }
+        }
+    }
+}
+
+void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+    const auto & cells = v_cells[cr.strm];
+
+    const uint32_t v_trans = this->v_trans ? 1 : 0;
+    const uint32_t n_layer = layers.size();
+
+    io.write(&v_trans, sizeof(v_trans));
+    io.write(&n_layer, sizeof(n_layer));
+
+    std::vector<uint8_t> tmp_buf;
+
+    // Iterate and write all the keys first, each row is a cell
+    // Get whole range at a time
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        auto * k = layer.k_stream[cr.strm];
+
+        // Write key type
+        const int32_t k_type_i = (int32_t) k->type;
+        io.write(&k_type_i, sizeof(k_type_i));
+
+        // Write row size of key
+        const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
+        io.write(&k_size_row, sizeof(k_size_row));
+
+        // Read each range of cells of k_size length each into tmp_buf and write out
+        for (const auto & range : cr.data) {
+            const size_t range_size = range.second - range.first;
+            const size_t buf_size = range_size * k_size_row;
+            io.write_tensor(k, range.first * k_size_row, buf_size);
+        }
+    }
+
+    if (!v_trans) {
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            auto * v = layer.v_stream[cr.strm];
+
+            // Write value type
+            const int32_t v_type_i = (int32_t) v->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write row size of value
+            const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
+            io.write(&v_size_row, sizeof(v_size_row));
+
+            // Read each range of cells of v_size length each into tmp_buf and write out
+            for (const auto & range : cr.data) {
+                const size_t range_size = range.second - range.first;
+                const size_t buf_size = range_size * v_size_row;
+                io.write_tensor(v, range.first * v_size_row, buf_size);
+            }
+        }
+    } else {
+        // When v is transposed, we also need the element size and get the element ranges from each row
+        const uint32_t kv_size = cells.size();
+
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            auto * v = layer.v_stream[cr.strm];
+
+            // Write value type
+            const int32_t v_type_i = (int32_t) v->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write element size
+            const uint32_t v_size_el = ggml_type_size(v->type);
+            io.write(&v_size_el, sizeof(v_size_el));
+
+            // Write GQA embedding size
+            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+            // For each row, we get the element values of each cell
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                // Read each range of cells of v_size_el length each into tmp_buf and write out
+                for (const auto & range : cr.data) {
+                    const size_t range_size = range.second - range.first;
+                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                    const size_t buf_size = range_size * v_size_el;
+                    io.write_tensor(v, src_offset, buf_size);
+                }
+            }
+        }
+    }
+}
+
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    auto & cells = v_cells[strm];
+    auto & head  = v_heads[strm];
+
+    if (dest_seq_id != -1) {
+        // single sequence
+        seq_rm(dest_seq_id, -1, -1);
+
+        llama_batch_allocr balloc(hparams.n_pos_per_embd());
+
+        llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
+
+        ubatch.seq_id_unq[0] = dest_seq_id;
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id != 1) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                return false;
+            }
+
+            // read the sequence id, but directly discard it - we will use dest_seq_id instead
+            {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+            }
+
+            ubatch.pos[i]      = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i]   = &dest_seq_id;
+        }
+
+        const auto sinfo = find_slot(ubatch, true);
+        if (sinfo.empty()) {
+            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+            return false;
+        }
+
+        apply_ubatch(sinfo, ubatch);
+
+        const auto head_cur = sinfo.head();
+
+        // keep the head at the old position because we will read the KV data into it in state_read_data()
+        head = head_cur;
+
+        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+
+        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
+        // Assume that this is one contiguous block of cells
+        GGML_ASSERT(head_cur + cell_count <= cells.size());
+        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
+        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
+    } else {
+        // whole KV cache restore
+
+        if (cell_count > cells.size()) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+            return false;
+        }
+
+        clear(true);
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t  n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            cells.pos_set(i, pos);
+
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+
+                if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
+                    return false;
+                }
+
+                cells.seq_add(i, seq_id);
+            }
+        }
+
+        head = 0;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+    auto & cells = v_cells[strm];
+    auto & head  = v_heads[strm];
+
+    uint32_t v_trans;
+    uint32_t n_layer;
+
+    io.read_to(&v_trans, sizeof(v_trans));
+    io.read_to(&n_layer, sizeof(n_layer));
+
+    if (n_layer != layers.size()) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
+        return false;
+    }
+
+    if (cell_count > cells.size()) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
+        return false;
+    }
+
+    if (this->v_trans != (bool) v_trans) {
+        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+        return false;
+    }
+
+    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+    for (const auto & layer : layers) {
+        const uint32_t il = layer.il;
+
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        auto * k = layer.k_stream[strm];
+
+        // Read type of key
+        int32_t k_type_i_ref;
+        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+        const int32_t k_type_i = (int32_t) k->type;
+        if (k_type_i != k_type_i_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+            return false;
+        }
+
+        // Read row size of key
+        uint64_t k_size_row_ref;
+        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
+        if (k_size_row != k_size_row_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+            return false;
+        }
+
+        if (cell_count) {
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+        }
+    }
+
+    if (!this->v_trans) {
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            auto * v = layer.v_stream[strm];
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t) v->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of value
+            uint64_t v_size_row_ref;
+            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
+            if (v_size_row != v_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+            }
+        }
+    } else {
+        // For each layer, read the values for each cell (transposed)
+        for (const auto & layer : layers) {
+            const uint32_t il = layer.il;
+
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            auto * v = layer.v_stream[strm];
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t) v->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read element size of value
+            uint32_t v_size_el_ref;
+            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+            const size_t v_size_el = ggml_type_size(v->type);
+            if (v_size_el != v_size_el_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                return false;
+            }
+
+            // Read GQA embedding size
+            uint32_t n_embd_v_gqa_ref;
+            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
+                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+//
+// llama_kv_cache_context
+//
+
+llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
+
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+    n_kv = kv->get_size();
+
+    const uint32_t n_stream = kv->get_n_stream();
+
+    // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+    sinfos.resize(1);
+    sinfos[0].s0 = 0;
+    sinfos[0].s1 = n_stream - 1;
+    sinfos[0].idxs.resize(n_stream);
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        sinfos[0].strm.push_back(s);
+        sinfos[0].idxs[s].resize(1, 0);
+    }
+}
+
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv,
+        llama_context * lctx,
+        bool do_shift,
+        stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), sc_info(std::move(sc_info)) {
+    if (!do_shift && this->sc_info.empty()) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
+    }
+}
+
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv,
+        llama_kv_cache::slot_info_vec_t sinfos,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
+}
+
+llama_kv_cache_context::~llama_kv_cache_context() = default;
+
+bool llama_kv_cache_context::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    if (++i_cur >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
+
+    // no ubatches -> this is a KV cache update
+    if (ubatches.empty()) {
+        kv->update(lctx, do_shift, sc_info);
+
+        return true;
+    }
+
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
+    n_kv = kv->get_n_kv(sinfos[i_cur]);
+
+    return true;
+}
+
+llama_memory_status llama_kv_cache_context::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_cur];
+}
+
+uint32_t llama_kv_cache_context::get_n_kv() const {
+    return n_kv;
+}
+
+ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
+    return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
+    return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
+}
+
+void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
+    kv->set_input_k_shift(dst);
+}
+
+void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+    kv->set_input_kq_mask(dst, ubatch, causal_attn);
+}
+
+void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_pos_bucket(dst, ubatch);
+}
+
+uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
index 2d04705f27857ea078761f3f065bf0bf1f3674f2..30de013f5f7f36b0901716ae63cc10e7bbd907eb 100644 (file)
 #pragma once
 
-#include "llama.h"
-#include "llama-io.h"
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cells.h"
 #include "llama-memory.h"
 
-struct llama_kv_cache : public llama_memory_i {
-    virtual ~llama_kv_cache() = default;
+#include <unordered_map>
+#include <vector>
 
-    // split the input batch into a set of ubatches and verify that they can fit into the cache
-    // return a state object containing the ubatches and KV cache state required to process them
-    // check the llama_memory_state_i::get_status() for the result
-    virtual llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+struct llama_cparams;
+struct llama_hparams;
+struct llama_model;
+struct llama_context;
+
+//
+// llama_kv_cache
+//
+
+class llama_kv_cache : public llama_memory_i {
+public:
+    static uint32_t get_padding(const llama_cparams & cparams);
+
+    struct stream_copy_info {
+        bool empty() const {
+            assert(ssrc.size() == sdst.size());
+            return ssrc.empty();
+        }
+
+        std::vector<uint32_t> ssrc;
+        std::vector<uint32_t> sdst;
+    };
+
+    // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+    //   KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+    struct slot_info {
+        // data for ggml_set_rows
+        using idx_vec_t = std::vector<uint32_t>;
+
+        // number of streams: ns = s1 - s0 + 1
+        uint32_t s0;
+        uint32_t s1;
+
+        std::vector<llama_seq_id> strm; // [ns]
+        std::vector<idx_vec_t>    idxs; // [ns]
+
+        uint32_t head() const {
+            GGML_ASSERT(idxs.size() == 1);
+            GGML_ASSERT(!idxs[0].empty());
+
+            return idxs[0][0];
+        }
+
+        void resize(size_t n) {
+            strm.resize(n);
+            idxs.resize(n);
+        }
+
+        size_t size() const {
+            GGML_ASSERT(idxs.size() == strm.size());
+            GGML_ASSERT(!idxs.empty());
+
+            return idxs[0].size();
+        }
+
+        size_t n_stream() const {
+            return strm.size();
+        }
+
+        bool empty() const {
+            return idxs.empty();
+        }
+
+        void clear() {
+            idxs.clear();
+        }
+    };
+
+    using slot_info_vec_t = std::vector<slot_info>;
+
+    llama_kv_cache(
+            const llama_model & model,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   v_trans,
+                         bool   offload,
+                         bool   unified,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max,
+                     uint32_t   n_pad,
+                     uint32_t   n_swa,
+               llama_swa_type   swa_type,
+        const layer_filter_cb & filter,
+        const  layer_reuse_cb & reuse);
+
+    ~llama_kv_cache() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_context_ptr init_batch(
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
+            bool embd_all) override;
+
+    llama_memory_context_ptr init_full() override;
+
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
-    // simulate full cache, used for allocating worst-case compute buffers
-    virtual llama_memory_state_ptr init_full() = 0;
+    // state write/load
 
-    // process any pending defrag/shift/etc. operations
-    // optionally call once before processing a new batch
-    // return true if any operations were performed
-    virtual bool update(llama_context & lctx) = 0;
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
-    // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
-    // TODO: change to
-    //   llama_memory_state_ptr init_defrag(float thold) = 0;
     //
-    virtual void defrag_sched(float thold) = 0;
+    // llama_kv_cache specific API
+    //
+
+    uint32_t get_size()     const;
+    uint32_t get_n_stream() const;
+
+    bool get_has_shift() const;
+
+    //
+    // graph_build API
+    //
+
+    uint32_t get_n_kv(const slot_info & sinfo) const;
+
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+
+    // store k_cur and v_cur in the cache based on the provided head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
+
+    //
+    // preparation API
+    //
+
+    // find places for the provided ubatches in the cache, returns the slot infos
+    // return empty vector on failure
+    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
+
+    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
+
+    // find a slot of kv cells that can hold the ubatch
+    // if cont == true, then the slot must be continuous
+    // return empty slot_info on failure
+    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
+
+    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
+
+    //
+    // input API
+    //
+
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
+    void set_input_k_shift(ggml_tensor * dst) const;
+
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+private:
+    const llama_model & model;
+    const llama_hparams & hparams;
+
+    struct kv_layer {
+        // layer index in the model
+        // note: can be different from the layer index in the KV cache
+        uint32_t il;
+
+        ggml_tensor * k;
+        ggml_tensor * v;
+
+        std::vector<ggml_tensor *> k_stream;
+        std::vector<ggml_tensor *> v_stream;
+    };
+
+    bool v_trans = true;  // the value tensor is transposed
+
+    const uint32_t n_seq_max = 1;
+    const uint32_t n_stream  = 1;
+
+    // required padding
+    const uint32_t n_pad = 1;
+
+    // SWA
+    const uint32_t n_swa = 0;
+
+    // env: LLAMA_KV_CACHE_DEBUG
+    int debug = 0;
+
+    // this is the SWA type of the cache - not to be confused with the model SWA type
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+    std::vector<uint32_t> v_heads;
+
+    std::vector<llama_kv_cells> v_cells;
+
+    // maps from a sequence id to a stream id
+    std::vector<uint32_t> seq_to_stream;
+
+    // pending stream copies that will be applied during the next update
+    stream_copy_info sc_info;
+
+    std::vector<kv_layer> layers;
+
+    // model layer id -> KV cache layer id
+    std::unordered_map<int32_t, int32_t> map_layer_ids;
+
+    size_t total_size() const;
+
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
+
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+
+    ggml_tensor * build_rope_shift(
+            const llama_cparams & cparams,
+                   ggml_context * ctx,
+                    ggml_tensor * cur,
+                    ggml_tensor * shift,
+                    ggml_tensor * factors,
+                          float   freq_base,
+                          float   freq_scale) const;
+
+    ggml_cgraph * build_graph_shift(
+               llm_graph_result * res,
+                  llama_context * lctx) const;
+
+    struct cell_ranges_t {
+        uint32_t strm;
+
+        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
+    };
+
+    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
+};
+
+class llama_kv_cache_context : public llama_memory_context_i {
+public:
+    // some shorthands
+    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
+    using stream_copy_info = llama_kv_cache::stream_copy_info;
+
+    // used for errors
+    llama_kv_cache_context(llama_memory_status status);
+
+    // used to create a full-cache context
+    llama_kv_cache_context(
+            llama_kv_cache * kv);
+
+    // used to create an update context
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
+            llama_context * lctx,
+            bool do_shift,
+            stream_copy_info sc_info);
+
+    // used to create a batch procesing context from a batch
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
+            slot_info_vec_t sinfos,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_kv_cache_context();
+
+    //
+    // llama_memory_context_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_kv_cache_context specific API
+    //
+
+    uint32_t get_n_kv() const;
+
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
+
+    // store k_cur and v_cur in the cache based on the provided head location
+    // note: the heads in k_cur and v_cur should be layed out contiguously in memory
+    //   - k_cur  [n_embd_head_k, n_head_k, n_tokens]
+    //   - k_idxs [n_tokens]
+    //   - v_cur  [n_embd_head_v, n_head_v, n_tokens]
+    //   - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
+
+    // create destination indices for each head of the current batch for where it would be written in the KV cache
+    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
+    //   helps understand the implementation logic of cpy_k and cpy_v
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+    void set_input_k_shift   (ggml_tensor * dst) const;
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+private:
+    llama_memory_status status;
+
+    llama_kv_cache * kv;
+    llama_context * lctx;
+
+    //
+    // update context
+    //
+
+    bool do_shift = false;
+
+    stream_copy_info sc_info;
+
+    //
+    // batch processing context
+    //
+
+    // the index of the cur ubatch to process
+    size_t i_cur = 0;
 
-    // getters
-    virtual bool get_can_shift() const = 0;
+    slot_info_vec_t sinfos;
 
-    bool get_can_edit() const override { return get_can_shift(); }
+    std::vector<llama_ubatch> ubatches;
 
     //
-    // state write/read
+    // data needed for building the compute graph for the current ubatch:
     //
 
-    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
-    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
+    // a heuristic, to avoid attending the full cache if it is not yet utilized
+    // as the cache gets filled, the benefit from this heuristic disappears
+    int32_t n_kv;
 };
index 0d0dd316fd0415f0952952750f510a3361f8f9ea..8f6bf01456c8fb734230b00170f6d2b84b869d89 100644 (file)
@@ -11,7 +11,7 @@
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
-class llama_kv_cells_unified {
+class llama_kv_cells {
 public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
@@ -77,30 +77,30 @@ public:
     }
 
     // move cell isrc to idst (used during defrag)
-    void mv(uint32_t isrc, uint32_t idst) {
-        assert(isrc < pos.size());
-        assert(idst < pos.size());
+    //void mv(uint32_t isrc, uint32_t idst) {
+    //    assert(isrc < pos.size());
+    //    assert(idst < pos.size());
 
-        assert(pos[idst] == -1);
-        assert(pos[isrc] != -1);
+    //    assert(pos[idst] == -1);
+    //    assert(pos[isrc] != -1);
 
-        pos  [idst] = pos  [isrc];
-        shift[idst] = shift[isrc];
-        seq  [idst] = seq  [isrc];
+    //    pos  [idst] = pos  [isrc];
+    //    shift[idst] = shift[isrc];
+    //    seq  [idst] = seq  [isrc];
 
-        pos  [isrc] = -1;
-        shift[isrc] =  0;
-        seq  [isrc].reset();
+    //    pos  [isrc] = -1;
+    //    shift[isrc] =  0;
+    //    seq  [isrc].reset();
 
-        used.erase (isrc);
-        used.insert(idst);
-    }
+    //    used.erase (isrc);
+    //    used.insert(idst);
+    //}
 
     // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
-    llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
+    llama_kv_cells cp(uint32_t i, uint32_t n) const {
         assert(i + n <= pos.size());
 
-        llama_kv_cells_unified res;
+        llama_kv_cells res;
 
         res.resize(n);
 
@@ -117,8 +117,8 @@ public:
     }
 
     // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
-        llama_kv_cells_unified res;
+    llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells res;
 
         res.resize(idxs.size());
 
@@ -135,7 +135,7 @@ public:
     }
 
     // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
-    void set(uint32_t i, const llama_kv_cells_unified & other) {
+    void set(uint32_t i, const llama_kv_cells & other) {
         assert(i + other.pos.size() <= pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
@@ -165,7 +165,7 @@ public:
     }
 
     // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
         assert(idxs.size() == other.pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
index cbeeb21344ecede21046e05666b2b4863f1a4187..ba61ebaa885feffc62ac012563736f59fa4325eb 100644 (file)
@@ -9,32 +9,29 @@
 //
 
 llama_memory_hybrid::llama_memory_hybrid(
-    const llama_model & model,
-                         /* attn */
-            ggml_type    type_k,
-            ggml_type    type_v,
-                 bool    v_trans,
-             uint32_t    kv_size,
-             uint32_t    n_pad,
-             uint32_t    n_swa,
-       llama_swa_type    swa_type,
-                         /* recurrent */
-            ggml_type    type_r,
-            ggml_type    type_s,
-             uint32_t    rs_size,
-                         /* common */
-             uint32_t    n_seq_max,
-                 bool    offload,
-                 bool    unified,
-                         /* layer filters */
-      layer_filter_cb && filter_attn,
-      layer_filter_cb && filter_recr) :
+        const llama_model & model,
+                            /* attn */
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                 uint32_t   kv_size,
+                 uint32_t   n_pad,
+                 uint32_t   n_swa,
+           llama_swa_type   swa_type,
+                            /* recurrent */
+                ggml_type   type_r,
+                ggml_type   type_s,
+                 uint32_t   rs_size,
+                            /* common */
+                 uint32_t   n_seq_max,
+                     bool   offload,
+                     bool   unified,
+                            /* layer filters */
+    const layer_filter_cb & filter_attn,
+    const layer_filter_cb & filter_recr) :
     hparams(model.hparams),
-    mem_attn(new llama_kv_cache_unified(
+    mem_attn(new llama_kv_cache(
         model,
-        filter_attn == nullptr ?
-            [&](int32_t il) { return !hparams.is_recurrent(il); }
-            : filter_attn,
         type_k,
         type_v,
         v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max,
         n_pad,
         n_swa,
-        swa_type
+        swa_type,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
+            : filter_attn,
+        nullptr
     )),
     mem_recr(new llama_memory_recurrent(
         model,
-        filter_recr == nullptr ?
-            [&](int32_t il) { return hparams.is_recurrent(il); }
-            : filter_recr,
         type_r,
         type_s,
         offload,
         rs_size,
-        n_seq_max
+        n_seq_max,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return hparams.is_recurrent(il); }
+            : filter_recr
     )) {}
 
 llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
@@ -179,7 +180,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
     mem_recr->state_read(io, seq_id);
 }
 
-llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
     return mem_attn.get();
 }
 
@@ -210,7 +211,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
         std::vector<llama_ubatch>   ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                        this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
@@ -248,8 +249,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
+const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
 }
 
 const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
index acdbc26bfb624c1a7ac22ce55d4808c6e3398148..11a35651782974023d48be9ddba531c0c79e0a26 100644 (file)
@@ -2,7 +2,7 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 #include "llama-memory.h"
 #include "llama-memory-recurrent.h"
 
 // llama_memory_hybrid
 //
 
-// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+// utilizes instances of llama_memory_recurrent and llama_kv_cache to
 //   support models where each layer may be either attention-based or recurrent
 
 class llama_memory_hybrid : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_hybrid(
         const llama_model & model,
                             /* attn */
-                ggml_type    type_k,
-                ggml_type    type_v,
-                     bool    v_trans,
-                 uint32_t    kv_size,
-                 uint32_t    n_pad,
-                 uint32_t    n_swa,
-           llama_swa_type    swa_type,
-                             /* recurrent */
-                ggml_type    type_r,
-                ggml_type    type_s,
-                 uint32_t    rs_size,
-                             /* common */
-                 uint32_t    n_seq_max,
-                     bool    offload,
-                     bool    unified,
-                             /* layer filters */
-          layer_filter_cb && filter_attn = nullptr,
-          layer_filter_cb && filter_recr = nullptr);
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                 uint32_t   kv_size,
+                 uint32_t   n_pad,
+                 uint32_t   n_swa,
+           llama_swa_type   swa_type,
+                            /* recurrent */
+                ggml_type   type_r,
+                ggml_type   type_s,
+                 uint32_t   rs_size,
+                            /* common */
+                 uint32_t   n_seq_max,
+                     bool   offload,
+                     bool   unified,
+                            /* layer filters */
+    const layer_filter_cb & filter_attn = nullptr,
+    const layer_filter_cb & filter_recr = nullptr);
 
     ~llama_memory_hybrid() = default;
 
@@ -81,19 +77,19 @@ public:
     // llama_memory_hybrid specific API
     //
 
-    llama_kv_cache_unified * get_mem_attn() const;
+    llama_kv_cache * get_mem_attn() const;
     llama_memory_recurrent * get_mem_recr() const;
 
 private:
     const llama_hparams & hparams;
 
-    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_kv_cache> mem_attn;
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -125,7 +121,7 @@ public:
     // llama_memory_hybrid_context
     //
 
-    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_kv_cache_context * get_attn() const;
     const llama_memory_recurrent_context * get_recr() const;
 
 private:
index 849675c418891d9da3436d5c2a6a035cb7b95ce9..08716ed91aed124fcb71c70fec4be6d2ec8e94c7 100644 (file)
 //
 
 llama_memory_recurrent::llama_memory_recurrent(
-        const llama_model &  model,
-          layer_filter_cb && filter,
-                ggml_type    type_r,
-                ggml_type    type_s,
-                     bool    offload,
-                 uint32_t    mem_size,
-                 uint32_t    n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model & model,
+                ggml_type   type_r,
+                ggml_type   type_s,
+                     bool   offload,
+                 uint32_t   mem_size,
+                 uint32_t   n_seq_max,
+    const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     head = 0;
index 95c617b2c94bdb53480d7129a0f1474e7cfbbfea..c4daf00495bc2a43192914e4b19dbdf3079708af 100644 (file)
 //
 
 // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
-//       see the implementation of llama_kv_cache_unified_context_i for an example how to do it
+//       see the implementation of llama_kv_cache_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:
-
-    // this callback is used to filter out layers that should not be included in the cache
-    using layer_filter_cb = std::function<bool(int32_t il)>;
-
     llama_memory_recurrent(
-            const llama_model &  model,
-              layer_filter_cb && filter,
-                    ggml_type    type_r,
-                    ggml_type    type_s,
-                         bool    offload,
-                     uint32_t    mem_size,
-                     uint32_t    n_seq_max);
+            const llama_model & model,
+                    ggml_type   type_r,
+                    ggml_type   type_s,
+                         bool   offload,
+                     uint32_t   mem_size,
+                     uint32_t   n_seq_max,
+        const layer_filter_cb & filter);
 
     ~llama_memory_recurrent() = default;
 
index 171d312cc99d91c3b0665322de6f8e59f2c29de0..ccd1f073b0848c28920dfcd909c753e5482f4c3a 100644 (file)
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <memory>
+#include <functional>
 
 struct llama_ubatch;
 
@@ -36,8 +37,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
 
 // the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
-//   - llama_kv_cache_unified_context
-//   - llama_kv_cache_unified_iswa_context
+//   - llama_kv_cache_context
+//   - llama_kv_cache_iswa_context
 //   ...
 //
 // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
 struct llama_memory_i {
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    // this callback is used to specify which layers should reuse memory from other layers
+    // return negative value to indicate that the layer il should not reuse memory
+    using layer_reuse_cb = std::function<int32_t(int32_t il)>;
+
     virtual ~llama_memory_i() = default;
 
     // split the input batch into a set of ubatches and verify that they can fit into the cache
@@ -77,7 +85,7 @@ struct llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_context_ptr init_full() = 0;
 
-    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // prepare for any pending memory updates, such as shifts, copies, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
     virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
@@ -109,8 +117,3 @@ struct llama_memory_i {
 };
 
 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
-
-// TODO: temporary until the llama_kv_cache is removed from the public API
-struct llama_kv_cache : public llama_memory_i {
-    virtual ~llama_kv_cache() = default;
-};
index f71c40f8e3f330cc18c2dbc752d32540be029efa..8182a9adf53a63e119a55d1cff3c321b0fece578 100644 (file)
@@ -788,6 +788,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri
 }
 
 struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) {
+    LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str());
     const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
 
     if (cur == NULL) {
index 23a26f0c64ea6d463f57582cd13c4d835080198b..981e57083c48d90610fd413496e90e14d858ad7f 100644 (file)
@@ -6,8 +6,8 @@
 #include "llama-cparams.h"
 #include "llama-model-loader.h"
 
-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
 
@@ -36,6 +36,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_80M:           return "80M";
         case LLM_TYPE_109M:          return "109M";
         case LLM_TYPE_137M:          return "137M";
+        case LLM_TYPE_140M:          return "140M";
         case LLM_TYPE_160M:          return "160M";
         case LLM_TYPE_190M:          return "190M";
         case LLM_TYPE_220M:          return "220M";
@@ -44,12 +45,15 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_270M:          return "270M";
         case LLM_TYPE_335M:          return "335M";
         case LLM_TYPE_350M:          return "350M";
+        case LLM_TYPE_360M:          return "360M";
         case LLM_TYPE_410M:          return "410M";
         case LLM_TYPE_450M:          return "450M";
         case LLM_TYPE_475M:          return "475M";
+        case LLM_TYPE_558M:          return "558M";
         case LLM_TYPE_700M:          return "700M";
         case LLM_TYPE_770M:          return "770M";
         case LLM_TYPE_780M:          return "780M";
+        case LLM_TYPE_950M:          return "950M";
         case LLM_TYPE_0_3B:          return "0.3B";
         case LLM_TYPE_0_5B:          return "0.5B";
         case LLM_TYPE_0_6B:          return "0.6B";
@@ -83,9 +87,11 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_32B:           return "32B";
         case LLM_TYPE_34B:           return "34B";
         case LLM_TYPE_35B:           return "35B";
+        case LLM_TYPE_36B:           return "36B";
         case LLM_TYPE_40B:           return "40B";
         case LLM_TYPE_65B:           return "65B";
         case LLM_TYPE_70B:           return "70B";
+        case LLM_TYPE_120B:          return "120B";
         case LLM_TYPE_142B:          return "142B";
         case LLM_TYPE_236B:          return "236B";
         case LLM_TYPE_290B:          return "290B";
@@ -619,19 +625,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
                 ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP,   hparams.n_moe_layer_step);
 
-                hparams.swa_type      = LLAMA_SWA_TYPE_CHUNKED;
-                hparams.n_swa         = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.set_swa_pattern(4);   // pattern: 3 chunked - 1 full
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa == 0) {
+                    hparams.swa_type             = LLAMA_SWA_TYPE_NONE;
+                    hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
+                } else {
+                    hparams.swa_type      = LLAMA_SWA_TYPE_CHUNKED;
+                    hparams.n_swa         = 8192;
+                    hparams.set_swa_pattern(4);   // pattern: 3 chunked - 1 full
+                }
 
                 switch (hparams.n_expert) {
+                    case 0: {
+                        // MobileLLM (no MoE)
+                        switch (hparams.n_embd) {
+                            case 2048: type = LLM_TYPE_140M; break;
+                            case 4096: type = LLM_TYPE_360M; break;
+                            case 6144: type = LLM_TYPE_950M; break;
+                            default:   type = LLM_TYPE_UNKNOWN;
+                        }
+                    } break;
                     case 16:  type = LLM_TYPE_17B_16E; break;
                     case 128: type = LLM_TYPE_17B_128E; break;
                     default:  type = LLM_TYPE_UNKNOWN;
                 }
 
-                if (type == LLM_TYPE_17B_128E) {
-                    hparams.use_kq_norm = false;
-                }
+                hparams.use_kq_norm = type != LLM_TYPE_17B_128E;
             } break;
         case LLM_ARCH_ARCEE:
             {
@@ -682,7 +701,30 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GROK:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                // defaults for old GGUFs
+                hparams.yarn_beta_fast = 8.0f;
+                hparams.f_logit_scale = 0.5773502691896257f;
+                hparams.f_embedding_scale = 78.38367176906169f;
+                hparams.f_attn_out_scale = 0.08838834764831845f;
+                hparams.f_attn_logit_softcapping = 30.0f;
+                hparams.f_router_logit_softcapping = 30.0f;
+                // no final_logit_softcapping in grok-1
+                hparams.f_final_logit_softcapping = 0.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,  hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,   hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_LOGIT_SCALE,                  hparams.f_logit_scale, false);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE,              hparams.f_embedding_scale, false);
+                ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE,       hparams.f_attn_out_scale, false);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,       hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING,     hparams.f_router_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,      hparams.f_final_logit_softcapping, false);
+
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH,  hparams.attn_temp_length, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,  hparams.yarn_ext_factor, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   hparams.yarn_beta_fast, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   hparams.yarn_beta_slow, false);
 
                 switch (hparams.n_layer) {
                     case 64: type = LLM_TYPE_314B; break;
@@ -770,6 +812,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_JINA_BERT_V3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
+
+                switch (hparams.n_layer) {
+                    case 24:
+                        type = LLM_TYPE_558M; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -898,6 +952,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.causal_attn = false;
             }
             break;
+        case LLM_ARCH_LLADA_MOE:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                // diffusion language model uses non-causal attention
+                hparams.causal_attn = false;
+                switch (hparams.n_layer) {
+                    case 16: type = LLM_TYPE_A1_7B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_QWEN2MOE:
             {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp, false);
@@ -1095,7 +1161,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
-                    case 18: type = LLM_TYPE_537M; break;
+                    case 18: type = LLM_TYPE_270M; break;
                     case 26: type = LLM_TYPE_1B; break;
                     case 34: type = LLM_TYPE_4B; break;
                     case 48: type = LLM_TYPE_12B; break;
@@ -1113,6 +1179,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.set_swa_pattern(5);
 
+                hparams.n_layer_kv_from_start     = 20;
                 hparams.rope_freq_base_train_swa  = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;
                 hparams.f_attention_scale         = 1.0f;
@@ -1126,6 +1193,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GEMMA_EMBEDDING:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
+                hparams.set_swa_pattern(6);
+
+                hparams.causal_attn = false; // embeddings do not use causal attention
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);
+
+                switch (hparams.n_layer) {
+                    case 24: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+                hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));
+
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1279,6 +1366,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.set_swa_pattern(4);
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                }
+
                 switch (hparams.n_layer) {
                     case 16: type = LLM_TYPE_1B; break;
                     case 32: type = LLM_TYPE_7B; break;
@@ -1287,6 +1382,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_SEED_OSS:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 64: type = LLM_TYPE_36B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_OLMOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1464,12 +1567,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // Expert gating function (GLM-4.5 uses sigmoid)
                 ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
                 if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
-                    hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
+                    hparams.expert_gating_func =  LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
                 }
 
                 // NextN/MTP parameters
                 ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS,        hparams.nextn_predict_layers, false);
 
+                // TODO: when MTP is implemented, this should probably be updated if needed
+                hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
+
                 switch (hparams.n_layer) {
                     case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
                     case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
@@ -1495,6 +1601,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.dec_start_token_id = dec_start_token_id;
                 }
 
+                hparams.dec_n_layer = hparams.n_layer;
+                ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false);
+
                 switch (hparams.n_layer) {
                     case 6:  type = LLM_TYPE_60M;  break; // t5-small
                     case 8:  type = LLM_TYPE_80M;  break; // flan-t5-small
@@ -1543,6 +1652,27 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_NEMOTRON_H:
+            {
+                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
+                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
+                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
+                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
+                ml.get_key(LLM_KV_SSM_GROUP_COUNT,    hparams.ssm_n_group);
+
+                // A layer is recurrent IFF the n_head_kv value is set to 0 and
+                // the n_ff value is set to 0
+                for (uint32_t i = 0; i < hparams.n_layer; ++i) {
+                    hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0);
+                }
+
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 56: type = LLM_TYPE_9B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_EXAONE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1834,7 +1964,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.set_swa_pattern(2);
 
-                // TODO: switch (hparams.n_layer)
+                switch (hparams.n_layer) {
+                    case 24: type = LLM_TYPE_20B; break;
+                    case 36: type = LLM_TYPE_120B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
             } break;
         case LLM_ARCH_LFM2:
             {
@@ -2289,6 +2423,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 }
                 break;
+            case LLM_ARCH_LLADA_MOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe");
+                    GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe");
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                    }
+                } break;
             case LLM_ARCH_LLAMA4:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -2302,9 +2470,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
-                    GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
                     for (int i = 0; i < n_layer; ++i) {
-                        bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
+                        bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0;
 
                         auto & layer = layers[i];
 
@@ -2465,6 +2632,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
+                    const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -2479,12 +2647,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff,   n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+
                         layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd,   n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff_exp, n_expert}, 0);
 
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        if (!layer.ffn_post_norm) {
+                            layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                        }
                     }
                 } break;
             case LLM_ARCH_DBRX:
@@ -2613,6 +2788,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             case LLM_ARCH_BERT:
             case LLM_ARCH_NOMIC_BERT:
             case LLM_ARCH_NOMIC_BERT_MOE:
+            case LLM_ARCH_JINA_BERT_V3:
                 {
                     tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
                     type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
@@ -2648,24 +2824,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
 
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
 
                         layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
                         layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
 
                         if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
-                            layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
                             layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff,   n_expert}, 0);
                             layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff,   n_embd, n_expert}, 0);
                             layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,   "weight", i), {n_embd, n_expert}, 0);
                         } else {
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
-
-                            if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) {
-                                layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-                                layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
-                                layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                            } else {
+                            layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                            if (arch == LLM_ARCH_NOMIC_BERT) {
                                 layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                             }
                         }
@@ -3433,6 +3607,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 } break;
             case LLM_ARCH_GEMMA3:
+            case LLM_ARCH_GEMMA_EMBEDDING:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -3962,6 +4137,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_SEED_OSS:
+                {
+                    const uint32_t head_dim             = hparams.n_embd_head_k;
+                    const int64_t n_qo_dim              = n_head * head_dim;
+                    const int64_t n_kv_dim              = n_head_kv * head_dim;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_qo_dim}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_kv_dim}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_kv_dim}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0);
+
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_qo_dim},   TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_kv_dim},   TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_kv_dim},   TENSOR_NOT_REQUIRED);
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                    }
+                } break;
+
             case LLM_ARCH_OLMOE:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4305,6 +4517,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
+                    // n_layer:     number of encoder_layers
+                    // dec_n_layer: number of decoder_layers
+                    const int dec_n_layer = hparams.dec_n_layer;
+                    if (dec_n_layer > n_layer) {
+                        layers.resize(dec_n_layer);
+                    }
+
+                    // load encoder layers
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -4320,6 +4540,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
                         layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                         layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+
+                    // load decoder layers
+                    for (int i = 0; i < dec_n_layer; ++i) {
+                        auto & layer = layers[i];
 
                         layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
                         layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
@@ -4621,6 +4846,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
                     }
                 } break;
+            case LLM_ARCH_NEMOTRON_H:
+                {
+                    // mamba2 Mixer SSM params
+                    // NOTE: int64_t for tensor dimensions
+                    const int64_t d_conv     = hparams.ssm_d_conv;
+                    const int64_t d_inner    = hparams.ssm_d_inner;
+                    const int64_t d_state    = hparams.ssm_d_state;
+                    const int64_t n_ssm_head = hparams.ssm_dt_rank;
+                    const int64_t n_group    = hparams.ssm_n_group;
+                    const int64_t d_in_proj  = 2*d_inner + 2*n_group*d_state + n_ssm_head;
+
+                    // embeddings
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                        // if output is NULL, init from the input tok embed, duplicated to allow offloading
+                        if (output == NULL) {
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        // all blocks use the attn norm
+                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (hparams.is_recurrent(i)) {
+                            // ssm layers
+                            layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
+
+                            layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
+                            layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED);
+
+                            layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
+
+                            // no "weight" suffix for these
+                            layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
+                            layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
+
+                            layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
+
+                            // out_proj
+                            layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+                        } else if (hparams.n_ff(i) == 0) {
+                            // attention layers (with optional bias)
+                            const int64_t n_head_i = hparams.n_head(i);
+                            const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
+                            const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias",   i), {n_embd},         TENSOR_NOT_REQUIRED);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias",   i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias",   i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
+                            layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd},         TENSOR_NOT_REQUIRED);
+                        } else {
+                            // mlp layers
+                            layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  hparams.n_ff(i), n_embd}, 0);
+                            layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   hparams.n_ff(i)}, 0);
+                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias",   i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
+                        }
+                    }
+                } break;
             case LLM_ARCH_EXAONE:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -5469,8 +5763,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_LFM2:
                 {
-                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
                     tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    output   = create_tensor(tn(LLM_TENSOR_OUTPUT,          "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -5790,7 +6089,8 @@ void llama_model::print_info() const {
         arch == LLM_ARCH_JAMBA ||
         arch == LLM_ARCH_FALCON_H1 ||
         arch == LLM_ARCH_PLAMO2 ||
-        arch == LLM_ARCH_GRANITE_HYBRID) {
+        arch == LLM_ARCH_GRANITE_HYBRID ||
+        arch == LLM_ARCH_NEMOTRON_H) {
         LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
         LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
         LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
@@ -5981,7 +6281,7 @@ struct llm_build_llama : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6043,9 +6343,17 @@ struct llm_build_llama : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
+                if (hparams.use_kq_norm) {
+                    // Llama4TextL2Norm
+                    Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
+                    Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
+                    cb(Qcur, "Qcur_normed", il);
+                    cb(Kcur, "Kcur_normed", il);
+                }
+
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
             }
 
@@ -6141,7 +6449,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
         ggml_tensor * inp_attn_scale = nullptr;
         inp_attn_scale = build_inp_attn_scale();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6150,7 +6458,8 @@ struct llm_build_llama_iswa : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
+            const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
+                                  (il + 1) % hparams.n_no_rope_layer_step != 0;
 
             // norm
             cur = build_norm(inpL,
@@ -6219,7 +6528,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
             }
 
@@ -6320,7 +6629,7 @@ struct llm_build_deci : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6396,7 +6705,7 @@ struct llm_build_deci : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -6476,7 +6785,7 @@ struct llm_build_baichuan : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6528,7 +6837,7 @@ struct llm_build_baichuan : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -6598,7 +6907,7 @@ struct llm_build_xverse : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6643,7 +6952,7 @@ struct llm_build_xverse : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -6712,7 +7021,7 @@ struct llm_build_falcon : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6743,9 +7052,7 @@ struct llm_build_falcon : public llm_graph_context {
 
                 ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
@@ -6766,7 +7073,7 @@ struct llm_build_falcon : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -6830,13 +7137,10 @@ struct llm_build_grok : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        // multiply by embedding_multiplier_scale of 78.38367176906169
-        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
-
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6896,7 +7200,7 @@ struct llm_build_grok : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -6904,26 +7208,22 @@ struct llm_build_grok : public llm_graph_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
-            // Grok
-            // if attn_out_norm is present then apply it before adding the input
-            if (model.layers[il].attn_out_norm) {
-                cur = build_norm(cur,
-                        model.layers[il].attn_out_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "attn_out_norm", il);
-            }
+            cur = build_norm(cur,
+                    model.layers[il].attn_out_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_out_norm", il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            // MoE branch
             cur = build_norm(ffn_inp,
                     model.layers[il].ffn_norm, NULL,
                     LLM_NORM_RMS, il);
             cb(cur, "ffn_norm", il);
 
-            cur = build_moe_ffn(cur,
+            // MoE branch
+            ggml_tensor * moe_out = build_moe_ffn(cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -6934,18 +7234,28 @@ struct llm_build_grok : public llm_graph_context {
                     false, 0.0,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
-            cb(cur, "ffn_moe_out", il);
+            cb(moe_out, "ffn_moe_out", il);
 
-            // Grok
-            // if layer_out_norm is present then apply it before adding the input
-            // Idea: maybe ffn_out_norm is a better name
-            if (model.layers[il].layer_out_norm) {
-                cur = build_norm(cur,
-                        model.layers[il].layer_out_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "layer_out_norm", il);
+            if (model.layers[il].ffn_up) {
+                ggml_tensor * ffn_out = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(ffn_out, "ffn_out", il);
+
+                cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2);
+                cb(cur, "ffn_out", il);
+            } else {
+                cur = moe_out;
             }
 
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_post_norm", il);
+
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -6968,10 +7278,14 @@ struct llm_build_grok : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // Grok
-        // multiply logits by output_multiplier_scale of 0.5773502691896257
+        cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
 
-        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
+        // final logit soft-capping
+        if (hparams.f_final_logit_softcapping) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+            cur = ggml_tanh(ctx0, cur);
+            cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        }
 
         cb(cur, "result_output", -1);
         res->t_logits = cur;
@@ -6996,7 +7310,7 @@ struct llm_build_dbrx : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7023,9 +7337,7 @@ struct llm_build_dbrx : public llm_graph_context {
 
                 Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
@@ -7045,7 +7357,7 @@ struct llm_build_dbrx : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -7120,7 +7432,7 @@ struct llm_build_starcoder : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -7145,13 +7457,9 @@ struct llm_build_starcoder : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -7159,7 +7467,7 @@ struct llm_build_starcoder : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -7225,7 +7533,7 @@ struct llm_build_refact : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7258,7 +7566,7 @@ struct llm_build_refact : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -7367,13 +7675,17 @@ struct llm_build_bert : public llm_graph_context {
                         cb(cur, "bqkv", il);
                     }
 
-                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                    Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                    Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
                 } else {
                     Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
                 }
 
                 if (model.layers[il].attn_q_norm) {
@@ -7381,6 +7693,8 @@ struct llm_build_bert : public llm_graph_context {
                             model.layers[il].attn_q_norm,
                             model.layers[il].attn_q_norm_b,
                             LLM_NORM, il);
+
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 }
 
                 if (model.layers[il].attn_k_norm) {
@@ -7388,14 +7702,12 @@ struct llm_build_bert : public llm_graph_context {
                             model.layers[il].attn_k_norm,
                             model.layers[il].attn_k_norm_b,
                             LLM_NORM, il);
-                }
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                }
 
                 // RoPE
-                if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                     Qcur = ggml_rope_ext(
                             ctx0, Qcur, inp_pos, nullptr,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -7415,7 +7727,7 @@ struct llm_build_bert : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7454,7 +7766,7 @@ struct llm_build_bert : public llm_graph_context {
                         0.0f,
                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
                 cb(cur, "ffn_moe_out", il);
-            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
@@ -7537,9 +7849,7 @@ struct llm_build_neo_bert : public llm_graph_context {
 
                 Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 // RoPE
                 Qcur = ggml_rope_ext(
@@ -7560,7 +7870,7 @@ struct llm_build_neo_bert : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7621,7 +7931,7 @@ struct llm_build_bloom : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         inpL = build_norm(inpL,
                 model.tok_norm,
@@ -7646,13 +7956,9 @@ struct llm_build_bloom : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -7660,7 +7966,7 @@ struct llm_build_bloom : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -7728,7 +8034,7 @@ struct llm_build_mpt : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         if (model.pos_embd) {
             // inp_pos - contains the positions
@@ -7768,13 +8074,9 @@ struct llm_build_mpt : public llm_graph_context {
                     cb(cur, "wqkv_clamped", il);
                 }
 
-                ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd));
-                ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 // Q/K Layernorm
                 if (model.layers[il].attn_q_norm) {
@@ -7782,32 +8084,23 @@ struct llm_build_mpt : public llm_graph_context {
                             model.layers[il].attn_q_norm,
                             model.layers[il].attn_q_norm_b,
                             LLM_NORM, il);
-                    cb(Qcur, "Qcur", il);
 
                     Kcur = build_norm(Kcur,
                             model.layers[il].attn_k_norm,
                             model.layers[il].attn_k_norm_b,
                             LLM_NORM, il);
-                    cb(Kcur, "Kcur", il);
-                } else {
-                    Qcur = ggml_cont(ctx0, Qcur);
-                    cb(Qcur, "Qcur", il);
 
-                    Kcur = ggml_cont(ctx0, Kcur);
-                    cb(Kcur, "Kcur", il);
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 }
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -7877,7 +8170,7 @@ struct llm_build_stablelm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7953,7 +8246,7 @@ struct llm_build_stablelm : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -8029,7 +8322,7 @@ struct llm_build_qwen : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8049,11 +8342,9 @@ struct llm_build_qwen : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,   n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd));
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_ext(
@@ -8074,7 +8365,7 @@ struct llm_build_qwen : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -8144,7 +8435,7 @@ struct llm_build_qwen2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8194,7 +8485,7 @@ struct llm_build_qwen2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -8308,8 +8599,9 @@ struct llm_build_dream : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr,
-                                 nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -8408,8 +8700,9 @@ struct llm_build_llada : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr,
-                                 1.0f / sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -8469,7 +8762,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         int sections[4];
         std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -8522,7 +8815,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -8590,7 +8883,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8649,7 +8942,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -8749,7 +9042,7 @@ struct llm_build_qwen3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8802,7 +9095,7 @@ struct llm_build_qwen3 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -8870,7 +9163,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8923,7 +9216,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -9000,7 +9293,7 @@ struct llm_build_phi2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9026,21 +9319,17 @@ struct llm_build_phi2 : public llm_graph_context {
 
                     Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                     Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
                 } else {
                     Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
                 }
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9063,7 +9352,7 @@ struct llm_build_phi2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -9129,13 +9418,13 @@ struct llm_build_phi3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9164,21 +9453,17 @@ struct llm_build_phi3 : public llm_graph_context {
 
                     Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd));
                     Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa));
                 } else {
                     Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
                 }
 
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9200,7 +9485,7 @@ struct llm_build_phi3 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -9287,7 +9572,7 @@ struct llm_build_plamo : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9334,7 +9619,7 @@ struct llm_build_plamo : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -9403,7 +9688,7 @@ struct llm_build_gpt2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -9428,21 +9713,17 @@ struct llm_build_gpt2 : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -9513,7 +9794,7 @@ struct llm_build_codeshell : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9534,9 +9815,7 @@ struct llm_build_codeshell : public llm_graph_context {
 
                 ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
@@ -9556,7 +9835,7 @@ struct llm_build_codeshell : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -9626,7 +9905,7 @@ struct llm_build_orion : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9685,7 +9964,7 @@ struct llm_build_orion : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -9753,7 +10032,7 @@ struct llm_build_internlm2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9812,7 +10091,7 @@ struct llm_build_internlm2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -9889,7 +10168,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10000,7 +10279,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
+                        q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -10084,7 +10363,7 @@ struct llm_build_gemma : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10130,7 +10409,7 @@ struct llm_build_gemma : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -10200,7 +10479,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10245,7 +10524,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -10334,7 +10613,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10387,7 +10666,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -10459,7 +10738,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     const int64_t n_embd_altup;
     const int64_t n_altup;
     const int     i_altup_act;
-    const int     n_layer_kv = 20; // number of layers having KV [KV_REUSE]
     const int     n_layer_sparsity = 10; // number of layers using activation sparsity
     const float   f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
 
@@ -10485,7 +10763,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
         ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
@@ -10509,8 +10787,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
 
         for (int il = 0; il < n_layer; ++il) {
             // this block is made to be closely resemble Gemma3p5DecoderLayer on python code
-            const bool has_kv = (il < n_layer_kv);
-
             const float freq_base_l  = model.get_rope_freq_base (cparams, il);
             const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
@@ -10530,7 +10806,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
             ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
 
             // self-attention
-            if (has_kv) {
+            if (hparams.has_kv(il)) {
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
@@ -10568,9 +10844,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
             } else {
-                // no KV layers
+                // reuse KV cache of earlier layers
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -10586,7 +10862,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                     model.layers[il].wo, NULL,
-                    Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
+                    Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
             }
 
             cur = build_norm(cur,
@@ -10864,8 +11140,8 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
         ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
         all_coefs = ggml_scale_bias(ctx0, all_coefs, 1.0f, 1.0f); // + 1.0
         cb(all_coefs, "all_coefs", il);
-        all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
-        all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
+        all_coefs = ggml_transpose(ctx0, all_coefs); // [n_tokens, n_altup]
+        all_coefs = ggml_cont_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
 
         innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
         ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
@@ -10876,6 +11152,137 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
     }
 };
 
+struct llm_build_gemma_embedding_iswa : public llm_graph_context {
+    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
+        auto * inp_attn = build_attn_inp_kv_iswa();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 // TODO: move up next to build_starcoder
 struct llm_build_starcoder2 : public llm_graph_context {
     llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
@@ -10892,7 +11299,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10951,7 +11358,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -11378,7 +11785,9 @@ struct llm_build_jamba : public llm_graph_context_mamba {
                 cb(Vcur, "Vcur", il);
 
                 // No RoPE :)
-                cur = build_attn(inp_hybrid->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_hybrid->get_attn(),
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -11461,7 +11870,7 @@ struct llm_build_command_r : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11536,7 +11945,7 @@ struct llm_build_command_r : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -11608,7 +12017,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11671,7 +12080,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -11743,7 +12152,7 @@ struct llm_build_olmo : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11802,7 +12211,7 @@ struct llm_build_olmo : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -11856,6 +12265,7 @@ struct llm_build_olmo : public llm_graph_context {
     }
 };
 
+template <bool iswa>
 struct llm_build_olmo2 : public llm_graph_context {
     llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -11871,7 +12281,14 @@ struct llm_build_olmo2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv();
+        }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11904,17 +12321,36 @@ struct llm_build_olmo2 : public llm_graph_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                Qcur = ggml_rope_ext(
+                const bool is_swa = hparams.is_swa(il);
+
+                if (is_swa) {
+                    // For sliding window layers, Olmo3 use regular rope with no yarn rope scaling.
+                    // This is achieved here by setting freq_scale and attn_factor to 1.
+                    // We also set ext_factor to 0 to avoid a few unnecessary computations.
+                    Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, 1.0,
+                        0.0, 1.0, beta_fast, beta_slow
+                        );
+
+                    Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, 1.0,
+                        0.0, 1.0, beta_fast, beta_slow
+                        );
+                } else {
+                    Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
 
-                Kcur = ggml_rope_ext(
+                    Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+                }
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -11922,7 +12358,7 @@ struct llm_build_olmo2 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -12000,7 +12436,7 @@ struct llm_build_olmoe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12055,7 +12491,7 @@ struct llm_build_olmoe : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -12113,30 +12549,27 @@ struct llm_build_olmoe : public llm_graph_context {
     }
 };
 
-struct llm_build_openelm : public llm_graph_context {
-    llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+struct llm_build_llada_moe : public llm_graph_context {
+    llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
+
         inpL = build_inp_embd(model.tok_embd);
 
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_no_cache();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
         for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_head    = hparams.n_head(il);
-            const int64_t n_head_kv = hparams.n_head_kv(il);
-            const int64_t n_head_qkv = 2*n_head_kv + n_head;
-
-            cur = inpL;
-            ggml_tensor * residual = cur;
+            ggml_tensor * inpSA = inpL;
 
             // norm
             cur = build_norm(inpL,
@@ -12144,26 +12577,155 @@ struct llm_build_openelm : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
-            // self-attention
+            // self_attention
             {
-                cur = build_lora_mm(model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
-
-                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, cur->nb[1], cur->nb[2], 0);
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head);
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = build_norm(Qcur,
-                        model.layers[il].attn_q_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(Qcur, "Qcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, false,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_openelm : public llm_graph_context {
+    llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head    = hparams.n_head(il);
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_qkv = 2*n_head_kv + n_head;
+
+            cur = inpL;
+            ggml_tensor * residual = cur;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
+
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, cur->nb[1], cur->nb[2], 0);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+                cb(Vcur, "Vcur", il);
+
+                Qcur = build_norm(Qcur,
+                        model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur", il);
 
                 Kcur = build_norm(Kcur,
                         model.layers[il].attn_k_norm, NULL,
@@ -12188,7 +12750,7 @@ struct llm_build_openelm : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -12257,7 +12819,7 @@ struct llm_build_gptneox : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12278,9 +12840,7 @@ struct llm_build_gptneox : public llm_graph_context {
 
                 ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                 ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
 
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
@@ -12300,7 +12860,7 @@ struct llm_build_gptneox : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -12403,7 +12963,7 @@ struct llm_build_arctic : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12450,7 +13010,7 @@ struct llm_build_arctic : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -12541,7 +13101,7 @@ struct llm_build_deepseek : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -12605,7 +13165,7 @@ struct llm_build_deepseek : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -12718,7 +13278,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12833,7 +13393,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
                     // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
                     cur = build_attn(inp_attn,
                             model.layers[il].wo, NULL,
-                            Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il);
+                            Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
                 } else {
                     ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
                     cb(kv, "kv", il);
@@ -12867,7 +13427,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
                     // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
                     cur = build_attn(inp_attn,
                             model.layers[il].wo, NULL,
-                            Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                            Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 }
             }
 
@@ -12965,7 +13525,7 @@ struct llm_build_bitnet : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13034,7 +13594,7 @@ struct llm_build_bitnet : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         NULL, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 
                 cur = build_norm(cur,
                         model.layers[il].attn_sub_norm, NULL,
@@ -13157,7 +13717,7 @@ struct llm_build_t5_enc : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo_enc, nullptr,
-                        Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -13229,12 +13789,14 @@ struct llm_build_t5_dec : public llm_graph_context {
 
         const int64_t n_outputs_enc = embd_enc->ne[1];
 
-        auto * inp_attn_self  = build_attn_inp_kv_unified();
+        auto * inp_attn_self  = build_attn_inp_kv();
         auto * inp_attn_cross = build_attn_inp_cross();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-        for (int il = 0; il < n_layer; ++il) {
+        const int64_t dec_n_layer = hparams.dec_n_layer;
+
+        for (int il = 0; il < dec_n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
             // norm
@@ -13263,7 +13825,7 @@ struct llm_build_t5_dec : public llm_graph_context {
 
                 cur = build_attn(inp_attn_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -13295,7 +13857,7 @@ struct llm_build_t5_dec : public llm_graph_context {
 
                 cur = build_attn(inp_attn_cross,
                         model.layers[il].wo_cross, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
                 cb(cur, "kqv_out", il);
 
                 //ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
@@ -13325,7 +13887,7 @@ struct llm_build_t5_dec : public llm_graph_context {
                 //cb(cur, "kqv_out", il);
             }
 
-            if (il == n_layer - 1 && inp_out_ids) {
+            if (il == dec_n_layer - 1 && inp_out_ids) {
                 cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
                 inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
             }
@@ -13346,8 +13908,8 @@ struct llm_build_t5_dec : public llm_graph_context {
                         model.layers[il].ffn_gate, NULL, NULL,
                         model.layers[il].ffn_down, NULL, NULL,
                         NULL,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+                        model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ,
                         il);
                 cb(cur, "ffn_out", il);
             }
@@ -13394,7 +13956,7 @@ struct llm_build_jais : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13413,21 +13975,17 @@ struct llm_build_jais : public llm_graph_context {
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
                 cb(cur, "bqkv", il);
 
-                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
-                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
-                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
+                ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd));
+                ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd));
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -13492,7 +14050,7 @@ struct llm_build_chatglm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13526,6 +14084,7 @@ struct llm_build_chatglm : public llm_graph_context {
                     }
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
                 } else {
                     cur = build_lora_mm(model.layers[il].wqkv, cur);
                     cb(cur, "wqkv", il);
@@ -13535,11 +14094,9 @@ struct llm_build_chatglm : public llm_graph_context {
                     }
                     Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                     Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
                 }
 
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
@@ -13559,7 +14116,7 @@ struct llm_build_chatglm : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -13625,7 +14182,7 @@ struct llm_build_glm4 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13660,6 +14217,7 @@ struct llm_build_glm4 : public llm_graph_context {
                     }
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
                 } else {
                     cur = build_lora_mm(model.layers[il].wqkv, cur);
                     cb(cur, "wqkv", il);
@@ -13669,11 +14227,9 @@ struct llm_build_glm4 : public llm_graph_context {
                     }
                     Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
                     Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                    Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
                 }
 
-                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -13692,7 +14248,7 @@ struct llm_build_glm4 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -13775,7 +14331,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13841,7 +14397,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_transformer_layers - 1 && inp_out_ids) {
@@ -13935,7 +14491,7 @@ struct llm_build_nemotron : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13995,7 +14551,7 @@ struct llm_build_nemotron : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -14049,6 +14605,138 @@ struct llm_build_nemotron : public llm_graph_context {
     }
 };
 
+struct llm_build_nemotron_h : public llm_graph_context_mamba {
+    llm_build_nemotron_h(
+            const llama_model      & model,
+            const llm_graph_params & params) :
+        llm_graph_context_mamba(params) {
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp = build_inp_mem_hybrid();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            if (hparams.is_recurrent(il)) {
+                // ssm layer //
+                cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
+            } else if (hparams.n_ff(il) == 0) {
+                // attention layer //
+                cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il);
+            } else {
+                cur = build_ffn_layer(cur, model, il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // add residual
+            cur = ggml_add(ctx0, cur, inpSA);
+            cb(cur, "block_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    ggml_tensor * build_attention_layer(
+              ggml_tensor             * cur,
+              llm_graph_input_attn_kv * inp_attn,
+        const llama_model             & model,
+        const int64_t                   n_embd_head,
+        const int                       il) {
+
+        // compute Q and K and (optionally) RoPE them
+        ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+        cb(Qcur, "Qcur", il);
+        if (model.layers[il].bq) {
+            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            cb(Qcur, "Qcur", il);
+        }
+
+        ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+        cb(Kcur, "Kcur", il);
+        if (model.layers[il].bk) {
+            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+            cb(Kcur, "Kcur", il);
+        }
+
+        ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+        cb(Vcur, "Vcur", il);
+        if (model.layers[il].bv) {
+            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+            cb(Vcur, "Vcur", il);
+        }
+
+        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il),    n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+
+        cb(Qcur, "Qcur", il);
+        cb(Kcur, "Kcur", il);
+        cb(Vcur, "Vcur", il);
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        cur = build_attn(inp_attn,
+                model.layers[il].wo, model.layers[il].bo,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+        return cur;
+    }
+
+    ggml_tensor * build_ffn_layer(
+              ggml_tensor * cur,
+        const llama_model & model,
+        const int           il) {
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                NULL,                      NULL,                        NULL,
+                model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                NULL,
+                LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
+        cb(cur, "ffn_out", il);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        return cur;
+    }
+};
+
 struct llm_build_exaone : public llm_graph_context {
     llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14064,7 +14752,7 @@ struct llm_build_exaone : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -14126,7 +14814,7 @@ struct llm_build_exaone : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -14196,13 +14884,13 @@ struct llm_build_exaone4 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -14257,7 +14945,7 @@ struct llm_build_exaone4 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                 cb(cur, "attn_out", il);
             }
 
@@ -15085,7 +15773,7 @@ struct llm_build_granite : public llm_graph_context {
             inp_pos = build_inp_pos();
         }
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -15136,12 +15824,12 @@ struct llm_build_granite : public llm_graph_context {
     }
 
     ggml_tensor * build_attention_layer(
-              ggml_tensor                     * cur,
-              ggml_tensor                     * inp_pos,
-              llm_graph_input_attn_kv_unified * inp_attn,
-        const llama_model                     & model,
-        const int64_t                           n_embd_head,
-        const int                               il) {
+              ggml_tensor             * cur,
+              ggml_tensor             * inp_pos,
+              llm_graph_input_attn_kv * inp_attn,
+        const llama_model             & model,
+        const int64_t                 n_embd_head,
+        const int                     il) {
 
         // compute Q and K and (optionally) RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15192,7 +15880,7 @@ struct llm_build_granite : public llm_graph_context {
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         cur = build_attn(inp_attn,
                 model.layers[il].wo, model.layers[il].bo,
-                Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
         return cur;
     }
@@ -15355,12 +16043,12 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
     }
 
     ggml_tensor * build_attention_layer(
-              ggml_tensor                     * cur,
-              ggml_tensor                     * inp_pos,
-              llm_graph_input_attn_kv_unified * inp_attn,
-        const llama_model                     & model,
-        const int64_t                           n_embd_head,
-        const int                               il) {
+              ggml_tensor             * cur,
+              ggml_tensor             * inp_pos,
+              llm_graph_input_attn_kv * inp_attn,
+        const llama_model             & model,
+        const int64_t                 n_embd_head,
+        const int                     il) {
 
         // compute Q and K and (optionally) RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15411,7 +16099,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         cur = build_attn(inp_attn,
                 model.layers[il].wo, model.layers[il].bo,
-                Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
         return cur;
     }
@@ -15517,7 +16205,7 @@ struct llm_build_chameleon : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -15596,7 +16284,7 @@ struct llm_build_chameleon : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, nullptr,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -15848,7 +16536,7 @@ struct llm_build_plm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -15952,7 +16640,7 @@ struct llm_build_plm : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
+                        q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -16013,7 +16701,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16075,7 +16763,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -16162,7 +16850,7 @@ struct llm_build_dots1 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16215,7 +16903,7 @@ struct llm_build_dots1 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -16312,7 +17000,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -16370,7 +17058,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1) {
@@ -16442,7 +17130,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16503,7 +17191,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                 cb(cur, "attn_out", il);
             }
 
@@ -16656,7 +17344,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba {
 
             ggml_tensor * attn_out = build_attn(inp->get_attn(),
                     model.layers[il].wo, NULL,
-                    Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
             cb(attn_out, "attn_out", il);
 
             cur = build_norm(inpL,
@@ -16816,7 +17504,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
 
 private:
     ggml_tensor * build_plamo2_attn_layer(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * inp_pos,
             ggml_tensor * cur,
             const llama_model & model,
@@ -16838,16 +17526,14 @@ private:
             const int64_t k_offset = n_embd_head_q * n_head;
             const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;
 
-            ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
+            ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head,    n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
             ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
-            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv)));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv));
 
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
             cb(Vcur, "Vcur", il);
 
-            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens);
-
             Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
             cb(Qcur, "Qcur_normed", il);
 
@@ -16866,7 +17552,9 @@ private:
                     ext_factor, attn_factor, beta_fast, beta_slow
                     );
 
-            cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
+            cur = build_attn(inp,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
         }
 
         cb(cur, "attn_out", il);
@@ -16913,15 +17601,13 @@ private:
         cb(zx, "mamba_in_proj", il);
         // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
         zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
-        zx = ggml_cont(ctx0, zx);
-        zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
+        zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
         cb(zx, "mamba_in_proj_out", il);
 
         // split into z and x
         // => {head_dim * n_heads, n_seq_tokens, n_seqs}
         ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx));
-        x = ggml_cont(ctx0, x);
-        x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
+        x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
         // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
         cb(x, "mamba_x_split", il);
 
@@ -17051,7 +17737,7 @@ struct llm_build_arcee : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -17115,7 +17801,7 @@ struct llm_build_arcee : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
             }
 
@@ -17186,7 +17872,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
 
@@ -17260,7 +17946,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
             }
 
@@ -17347,7 +18033,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
 
@@ -17420,7 +18106,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
             }
 
@@ -17485,7 +18171,7 @@ struct llm_build_smollm3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -17550,7 +18236,7 @@ struct llm_build_smollm3 : public llm_graph_context {
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
                 cb(cur, "attn_out", il);
             }
 
@@ -17617,7 +18303,7 @@ struct llm_build_openai_moe_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -17672,9 +18358,9 @@ struct llm_build_openai_moe_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                cur = build_attn_with_sinks(inp_attn,
+                cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il);
+                        Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il);
 
                 cb(cur, "attn_out", il);
             }
@@ -17771,8 +18457,7 @@ struct llm_build_lfm2 : public llm_graph_context {
         cb(cur, "model.embedding_norm", -1);
         res->t_embd = cur;
 
-        // lm_head is tied with embeddings
-        cur = build_lora_mm(model.tok_embd, cur);
+        cur = build_lora_mm(model.output, cur);
         cb(cur, "lm_head", -1);
 
         res->t_logits = cur;
@@ -17799,10 +18484,10 @@ struct llm_build_lfm2 : public llm_graph_context {
         return cur;
     }
 
-    ggml_tensor * build_attn_block(ggml_tensor                     * cur,
-                                   ggml_tensor                     * inp_pos,
-                                   llm_graph_input_attn_kv_unified * inp_attn,
-                                   int                               il) const {
+    ggml_tensor * build_attn_block(ggml_tensor             * cur,
+                                   ggml_tensor             * inp_pos,
+                                   llm_graph_input_attn_kv * inp_attn,
+                                   int                     il) const {
         GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
         auto const n_embd_head = hparams.n_embd_head_v;
         auto const n_head_kv = hparams.n_head_kv(il);
@@ -17837,7 +18522,7 @@ struct llm_build_lfm2 : public llm_graph_context {
                 );
 
         cur = build_attn(inp_attn, model.layers[il].wo, NULL,
-                q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                q, k, v, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
 
         cb(cur, "model.layers.{}.self_attn.out_proj", il);
 
@@ -17914,6 +18599,137 @@ struct llm_build_lfm2 : public llm_graph_context {
     }
 };
 
+struct llm_build_seed_oss : public llm_graph_context {
+    llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 template <bool iswa>
 struct llm_build_smallthinker : public llm_graph_context{
     llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
@@ -17930,13 +18746,13 @@ struct llm_build_smallthinker : public llm_graph_context{
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -17981,7 +18797,7 @@ struct llm_build_smallthinker : public llm_graph_context{
 
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
@@ -18043,12 +18859,15 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
+        case LLM_ARCH_LLADA_MOE:
             {
                 res = nullptr;
             } break;
@@ -18059,14 +18878,31 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 if (llm_arch_is_recurrent(arch)) {
                     res = new llama_memory_recurrent(
                             *this,
-                            nullptr,
                             GGML_TYPE_F32,
                             GGML_TYPE_F32,
                             cparams.offload_kqv,
                             std::max((uint32_t) 1, cparams.n_seq_max),
-                            cparams.n_seq_max);
+                            cparams.n_seq_max,
+                            nullptr);
                 } else if (llm_arch_is_hybrid(arch)) {
-                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    // The main difference between hybrid architectures is the
+                    // layer filters, so pick the right one here
+                    llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
+                    llama_memory_hybrid::layer_filter_cb filter_recr = nullptr;
+                    if (arch == LLM_ARCH_FALCON_H1) {
+                        filter_attn = [&](int32_t) { return true; };
+                        filter_recr = [&](int32_t) { return true; };
+                    } else if (arch == LLM_ARCH_NEMOTRON_H) {
+                        filter_attn = [&](int32_t il) {
+                            return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
+                        };
+                        filter_recr = [&](int32_t il) {
+                            return hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
+                        };
+                    }
+
+                    const auto padding = llama_kv_cache::get_padding(cparams);
 
                     cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
 
@@ -18085,10 +18921,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* n_seq_max         */ cparams.n_seq_max,
                         /* offload           */ cparams.offload_kqv,
                         /* unified           */ cparams.kv_unified,
-                        /* filter_attn       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
-                        /* filter_recr       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
+                        /* filter_attn       */ std::move(filter_attn),
+                        /* filter_recr       */ std::move(filter_recr));
                 } else {
-                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    const auto padding = llama_kv_cache::get_padding(cparams);
 
                     uint32_t n_ctx_per_stream = cparams.n_ctx;
 
@@ -18105,10 +18941,22 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 
                     LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
+                    llama_memory_i::layer_reuse_cb reuse = nullptr;
+
+                    if (arch == LLM_ARCH_GEMMA3N) {
+                        reuse = [&](int32_t il) {
+                            if (il >= (int32_t) hparams.n_layer_kv_from_start) {
+                                return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
+                            }
+
+                            return -1;
+                        };
+                    }
+
                     if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
                         GGML_ASSERT(hparams.is_swa_any());
 
-                        res = new llama_kv_cache_unified_iswa(
+                        res = new llama_kv_cache_iswa(
                                 *this,
                                 params.type_k,
                                 params.type_v,
@@ -18119,13 +18967,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 n_ctx_per_stream,
                                 cparams.n_seq_max,
                                 cparams.n_ubatch,
-                                padding);
+                                padding,
+                                nullptr,
+                                reuse);
                     } else {
                         GGML_ASSERT(!hparams.is_swa_any());
 
-                        res = new llama_kv_cache_unified(
+                        res = new llama_kv_cache(
                                 *this,
-                                nullptr,
                                 params.type_k,
                                 params.type_v,
                                 !cparams.flash_attn,
@@ -18135,7 +18984,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                                 cparams.n_seq_max,
                                 padding,
                                 hparams.n_swa,
-                                hparams.swa_type);
+                                hparams.swa_type,
+                                nullptr,
+                                nullptr);
                     }
                 }
             }
@@ -18154,7 +19005,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_LLAMA4:
             {
-                llm = std::make_unique<llm_build_llama_iswa>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
+                    llm = std::make_unique<llm_build_llama>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_llama_iswa>(*this, params);
+                }
             } break;
         case LLM_ARCH_DECI:
             {
@@ -18182,6 +19037,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
             {
@@ -18221,6 +19077,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
                 llm = std::make_unique<llm_build_llada>(*this, params);
             }
             break;
+        case LLM_ARCH_LLADA_MOE:
+            {
+                llm = std::make_unique<llm_build_llada_moe>(*this, params);
+            }
+            break;
         case LLM_ARCH_QWEN2VL:
             {
                 llm = std::make_unique<llm_build_qwen2vl>(*this, params);
@@ -18294,6 +19155,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
             } break;
+        case LLM_ARCH_GEMMA_EMBEDDING:
+            {
+                llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 llm = std::make_unique<llm_build_starcoder2>(*this, params);
@@ -18329,7 +19194,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_OLMO2:
             {
-                llm = std::make_unique<llm_build_olmo2>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_olmo2<true>>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_olmo2<false>>(*this, params);
+                }
             } break;
         case LLM_ARCH_OLMOE:
             {
@@ -18398,6 +19267,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_nemotron>(*this, params);
             } break;
+        case LLM_ARCH_NEMOTRON_H:
+            {
+                llm = std::make_unique<llm_build_nemotron_h>(*this, params);
+            } break;
         case LLM_ARCH_EXAONE:
             {
                 llm = std::make_unique<llm_build_exaone>(*this, params);
@@ -18452,6 +19325,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_bailingmoe>(*this, params);
             } break;
+        case LLM_ARCH_SEED_OSS:
+            {
+                llm = std::make_unique<llm_build_seed_oss>(*this, params);
+            } break;
         case LLM_ARCH_DOTS1:
             {
                 llm = std::make_unique<llm_build_dots1>(*this, params);
@@ -18510,6 +19387,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     return llm->res->get_gf();
 }
 
+
 //
 // interface implementation
 //
@@ -18518,7 +19396,7 @@ llama_model_params llama_model_default_params() {
     llama_model_params result = {
         /*.devices                     =*/ nullptr,
         /*.tensor_buft_overrides       =*/ nullptr,
-        /*.n_gpu_layers                =*/ 0,
+        /*.n_gpu_layers                =*/ 999,
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
@@ -18532,11 +19410,6 @@ llama_model_params llama_model_default_params() {
         /*.use_extra_bufts             =*/ true,
     };
 
-#ifdef GGML_USE_METAL
-    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
-#endif
-
     return result;
 }
 
@@ -18628,6 +19501,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_RWKV7:
         case LLM_ARCH_ARWKV7:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        case LLM_ARCH_NEMOTRON_H:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -18667,6 +19541,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GROK:
         case LLM_ARCH_DBRX:
         case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V3:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_STABLELM:
@@ -18677,6 +19552,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN2MOE:
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3MOE:
+        case LLM_ARCH_LLADA_MOE:
         case LLM_ARCH_OLMO2:
         case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
@@ -18688,6 +19564,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
         case LLM_ARCH_GEMMA3N:
+        case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
@@ -18704,6 +19581,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_SMALLTHINKER:
         case LLM_ARCH_GLM4_MOE:
+        case LLM_ARCH_SEED_OSS:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
index 46f7d0480fabe580df5e9bc3df569942b48c4109..b1981978e3acf8753a96dceff561b1f6e5060cdf 100644 (file)
@@ -28,6 +28,7 @@ enum llm_type {
     LLM_TYPE_80M,
     LLM_TYPE_109M,
     LLM_TYPE_137M,
+    LLM_TYPE_140M,
     LLM_TYPE_160M,
     LLM_TYPE_190M,
     LLM_TYPE_220M,
@@ -36,13 +37,15 @@ enum llm_type {
     LLM_TYPE_270M,
     LLM_TYPE_335M,
     LLM_TYPE_350M,
+    LLM_TYPE_360M,
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
-    LLM_TYPE_537M,
+    LLM_TYPE_558M,
     LLM_TYPE_700M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
+    LLM_TYPE_950M,
     LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
@@ -76,9 +79,11 @@ enum llm_type {
     LLM_TYPE_32B,
     LLM_TYPE_34B,
     LLM_TYPE_35B,
+    LLM_TYPE_36B,
     LLM_TYPE_40B,
     LLM_TYPE_65B,
     LLM_TYPE_70B,
+    LLM_TYPE_120B,
     LLM_TYPE_142B,
     LLM_TYPE_236B,
     LLM_TYPE_290B,
index 1d0361cc16659d5d93a5b42c49bb211a52ee4f3d..97228b2a693241045d3888736ddc06776c8c2506 100644 (file)
@@ -725,7 +725,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // attention layers have a non-zero number of kv heads
         int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
         if (llama_model_has_encoder(&model)) {
-            n_attn_layer *= 3;
+            // now n_attn_layer is the number of attention layers in the encoder
+            // for each decoder block, there are 2 attention layers
+            n_attn_layer += 2 * model.hparams.dec_n_layer;
         }
         GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
     }
@@ -920,7 +922,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
             new_type = tensor->type;
             new_data = tensor->data;
             new_size = ggml_nbytes(tensor);
-            LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
+            LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0);
         } else {
             const int64_t nelements = ggml_nelements(tensor);
 
@@ -1037,8 +1039,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     }
     close_ofstream();
 
-    LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
-    LLAMA_LOG_INFO("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+    LLAMA_LOG_INFO("%s: model size  = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0);
+    LLAMA_LOG_INFO("%s: quant size  = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0);
 
     if (qs.n_fallback > 0) {
         LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
index bfbf5fa23011240c0dec57b390670ef1ff47079b..2186f827bf54307731d1fbb57e9b38b380415f94 100644 (file)
@@ -128,6 +128,89 @@ struct ring_buffer {
     std::vector<T> data;
 };
 
+// writes result in res, does not mutate cur
+static void llama_token_data_array_partial_sort(const llama_token_data_array & cur, int npartial, std::vector<llama_token_data> & res) {
+    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+
+    constexpr int   nbuckets     = 128;
+    constexpr float bucket_low   = -10.0f;
+    constexpr float bucket_high  =  10.0f;
+    constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+    constexpr float bucket_inter = -bucket_low * bucket_scale;
+
+    std::vector<int> bucket_idx;
+    std::vector<int> histo(nbuckets, 0);
+
+    std::vector<llama_token_data*> bucket_ptrs;
+
+    bucket_idx.reserve(cur.size);
+
+    for (int i = 0; i < (int)cur.size; ++i) {
+        const float val = cur.data[i].logit;
+        int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+        ib = std::max(0, std::min(nbuckets - 1, ib));
+        bucket_idx.push_back(ib);
+        ++histo[ib];
+    }
+    int nhave = 0;
+    int ib = nbuckets - 1;
+    for ( ; ib >= 0; --ib) {
+        nhave += histo[ib];
+        if (nhave >= npartial) {
+            break;
+        }
+    }
+    res.resize(nhave);
+    auto * ptr = res.data();
+    bucket_ptrs.reserve(nbuckets - ib);
+    for (int j = nbuckets - 1; j >= ib; --j) {
+        bucket_ptrs.push_back(ptr);
+        ptr += histo[j];
+    }
+    for (int i = 0; i < (int)cur.size; ++i) {
+        int j = bucket_idx[i];
+        if (j >= ib) {
+            *bucket_ptrs[nbuckets - 1 - j]++ = cur.data[i];
+        }
+    }
+
+    ptr = res.data();
+    int ndone = 0;
+    for (int j = nbuckets - 1; j > ib; --j) {
+        std::sort(ptr, ptr + histo[j], comp);
+        ptr += histo[j];
+        ndone += histo[j];
+    }
+    std::partial_sort(ptr, ptr + npartial - ndone, ptr + histo[ib], comp);
+}
+
+// reduces the size of cur_p to npartial, keeping only the top npartial elements
+static void llama_token_data_array_partial_sort_inplace(llama_token_data_array * cur_p, int npartial) {
+    static const auto comp = [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    };
+
+    if (npartial <= 128) {
+        std::partial_sort(cur_p->data, cur_p->data + npartial, cur_p->data + cur_p->size, comp);
+
+        cur_p->size = npartial;
+        cur_p->sorted = true;
+
+        return;
+    }
+
+    std::vector<llama_token_data> tmp;
+
+    llama_token_data_array_partial_sort(*cur_p, npartial, tmp);
+
+    std::copy(tmp.data(), tmp.data() + npartial, cur_p->data);
+
+    cur_p->size = npartial;
+    cur_p->sorted = true;
+}
+
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
     // iterator for the probabilities
 #ifdef __GNUC__
@@ -200,18 +283,21 @@ static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp)
     }
 }
 
-static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p, bool do_sort) {
     GGML_ASSERT(cur_p->size > 0);
 
-    // Sort the logits in descending order
-    if (!cur_p->sorted) {
-        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
-        cur_p->sorted = true;
+    // Sort the logits in descending order if requested
+    if (do_sort && !cur_p->sorted) {
+        llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
     }
 
     float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
     float cum_sum = 0.0f;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
@@ -226,7 +312,6 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
 }
 
 static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
-    // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast
     // if (k >= (int32_t)cur_p->size) {
     //     return;
     // }
@@ -239,64 +324,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
 
     // Sort scores in descending order
     if (!cur_p->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k <= 128) {
-            std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
-        } else {
-            constexpr int   nbuckets     = 128;
-            constexpr float bucket_low   = -10.0f;
-            constexpr float bucket_high  =  10.0f;
-            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucket_inter = -bucket_low * bucket_scale;
-
-            std::vector<int> bucket_idx(cur_p->size);
-            std::vector<int> histo(nbuckets, 0);
-
-            for (int i = 0; i < (int)cur_p->size; ++i) {
-                const float val = cur_p->data[i].logit;
-                int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets - 1, ib));
-                bucket_idx[i] = ib;
-                ++histo[ib];
-            }
-            int nhave = 0;
-            int ib = nbuckets - 1;
-            for ( ; ib >= 0; --ib) {
-                nhave += histo[ib];
-                if (nhave >= k) {
-                    break;
-                }
-            }
-            std::vector<llama_token_data> tmp_tokens(nhave);
-            auto * ptr = tmp_tokens.data();
-            std::vector<llama_token_data*> bucket_ptrs;
-            bucket_ptrs.reserve(nbuckets - ib);
-            for (int j = nbuckets - 1; j >= ib; --j) {
-                bucket_ptrs.push_back(ptr);
-                ptr += histo[j];
-            }
-            for (int i = 0; i < (int)cur_p->size; ++i) {
-                int j = bucket_idx[i];
-                if (j >= ib) {
-                    *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
-                }
-            }
-
-            ptr = tmp_tokens.data();
-            int ndone = 0;
-            for (int j = nbuckets - 1; j > ib; --j) {
-                std::sort(ptr, ptr + histo[j], comp);
-                ptr += histo[j];
-                ndone += histo[j];
-            }
-            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
-            std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-
-        }
-        cur_p->sorted = true;
+        llama_token_data_array_partial_sort_inplace(cur_p, k);
     }
 
     cur_p->size = k;
@@ -576,9 +604,73 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
 
-    llama_sampler_softmax_impl(cur_p);
+    // edge cases
+    if (cur_p->size == 0) {
+        cur_p->selected = -1;
+        return;
+    }
+
+    cur_p->selected = 0;
+
+    if (cur_p->size == 1) {
+        cur_p->data[0].p = 1.0f;
+        return;
+    }
+
+    // max logit for numerical stability
+    float max_l = cur_p->data[0].logit;
+    if (!cur_p->sorted) {
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            max_l = std::max(max_l, cur_p->data[i].logit);
+        }
+    }
+
+    // apply softmax to obtain the probabilities
+    double sum_cum = 0.0f;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
+        sum_cum += p;
+    }
+
+#if 1
+    // sample from the obtained probabilities and normalize the probs in a single pass
+    // this is ~3x faster on Mac with full gpt-oss vocab than the version below
+    //
+    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
+    const double rnd = dist(ctx->rng);
+
+          double sum_run = 0.0f;
+    const double sum_tgt = sum_cum*rnd;
+
+    bool found = false;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (!found) {
+            // accumulate probs until we reach the target sum
+            sum_run += cur_p->data[i].p;
+            if (sum_run >= sum_tgt) {
+                cur_p->selected = i;
+                found = true;
+            }
+        }
+
+        // normalize probs
+        cur_p->data[i].p /= sum_cum;
+    }
+
+    // fallback to the last token (don't think this can happen)
+    assert(found);
+    if (!found) {
+        cur_p->selected = cur_p->size - 1;
+    }
+#else
+    // for clarity, this is the same as above but does one pass for normalization and one extra pass for sampling
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= sum_cum;
+    }
 
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
+#endif
 }
 
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
@@ -626,32 +718,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     );
 }
 
-// softmax
-
-static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
-    return "softmax";
-}
-
-static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
-    llama_sampler_softmax_impl(cur_p);
-}
-
-static struct llama_sampler_i llama_sampler_softmax_i = {
-    /* .name   = */ llama_sampler_softmax_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_softmax_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ nullptr,
-    /* .free   = */ nullptr,
-};
-
-struct llama_sampler * llama_sampler_init_softmax() {
-    return llama_sampler_init(
-        /* .iface = */ &llama_sampler_softmax_i,
-        /* .ctx   = */ nullptr
-    );
-}
-
 // top-k
 
 struct llama_sampler_top_k {
@@ -663,7 +729,7 @@ static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl
 }
 
 static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+    auto * ctx = (llama_sampler_top_k *) smpl->ctx;
     llama_sampler_top_k_impl(cur_p, ctx->k);
 }
 
@@ -699,6 +765,8 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
 struct llama_sampler_top_p {
     const float  p;
     const size_t min_keep;
+
+    std::vector<llama_token_data> buf_sort;
 };
 
 static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
@@ -706,20 +774,35 @@ static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl
 }
 
 static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+    auto * ctx = (llama_sampler_top_p *) smpl->ctx;
 
     if (ctx->p >= 1.0f) {
         return;
     }
 
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, false);
+
+    size_t k = cur_p->size;
+    auto * pdata = cur_p->data;
+
+    auto & buf_sort = ctx->buf_sort;
+
+    // if not sorted, try adaptive top-k sorting
+    if (!cur_p->sorted && cur_p->size > 1024) {
+        k = std::min<size_t>(256, cur_p->size);
+        llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
+        pdata = buf_sort.data();
+    } else if (!cur_p->sorted) {
+        // small candidates -> sort inplace
+        llama_token_data_array_partial_sort_inplace(cur_p, k);
+    }
 
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
     size_t last_idx = cur_p->size;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
-        cum_sum += cur_p->data[i].p;
+        cum_sum += pdata[i].p;
 
         // Check if the running sum is at least p or if we have kept at least min_keep tokens
         // we set the last index to i+1 to indicate that the current iterate should be included in the set
@@ -727,9 +810,21 @@ static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_d
             last_idx = i + 1;
             break;
         }
+
+        // we exceeded the current top-k heuristic -> increase k and continue
+        if (!cur_p->sorted && i == k - 1) {
+            k = cur_p->size;
+            llama_token_data_array_partial_sort(*cur_p, k, buf_sort);
+            pdata = buf_sort.data();
+        }
     }
 
     // Resize the output vector to keep only the top-p tokens
+    if (!cur_p->sorted) {
+        std::copy(buf_sort.data(), buf_sort.data() + last_idx, cur_p->data);
+        cur_p->sorted = true;
+    }
+
     cur_p->size = last_idx;
 }
 
@@ -757,6 +852,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
         /* .ctx   = */ new llama_sampler_top_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
+            /* .buf_sort = */ {},
         }
     );
 }
@@ -773,7 +869,7 @@ static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl
 }
 
 static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
+    auto * ctx = (llama_sampler_min_p *) smpl->ctx;
 
     if (ctx->p <= 0.0f || !cur_p->size) {
         return;
@@ -799,7 +895,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
 
         // if we have enough values the operation was a success
         if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) {
-            memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+            std::copy(filtered_tokens.begin(), filtered_tokens.end(), cur_p->data);
             cur_p->size = filtered_tokens.size();
             min_p_applied = true;
         }
@@ -809,10 +905,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
     if (!min_p_applied) {
         // Sort the logits in descending order
         if (!cur_p->sorted) {
-            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
-            cur_p->sorted = true;
+            llama_token_data_array_partial_sort_inplace(cur_p, cur_p->size);
         }
 
         const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
@@ -869,7 +962,7 @@ static const char * llama_sampler_typical_name(const struct llama_sampler * /*sm
 }
 
 static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+    auto * ctx = (llama_sampler_typical *) smpl->ctx;
 
     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
@@ -878,7 +971,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token
     }
 
     // Compute the softmax of logits and calculate entropy
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);
 
     float entropy = 0.0f;
     for (size_t i = 0; i < cur_p->size; ++i) {
@@ -1012,7 +1105,7 @@ static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*s
 }
 
 static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
+    auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
     if (ctx->delta > 0) {
         const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
         const float max_temp = ctx->temp + ctx->delta;
@@ -1027,7 +1120,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
         // Calculate maximum possible entropy
         float max_entropy = -logf(1.0f / cur_p->size);
 
-        llama_sampler_softmax_impl(cur_p);
+        llama_sampler_softmax_impl(cur_p, true);
 
         // Calculate entropy of the softmax probabilities
         float entropy = 0.0f;
@@ -1121,7 +1214,7 @@ struct llama_sampler_xtc {
     const uint32_t seed;
     uint32_t       seed_cur;
 
-    std::mt19937   rng;
+    std::mt19937    rng;
 };
 
 static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@@ -1139,17 +1232,20 @@ static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data
 
     std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
     float chance = distribution(ctx->rng);
-    if (chance > ctx->probability) return;
+    if (chance > ctx->probability) {
+        return;
+    }
 
-    // in case it's not sorted/recalculated yet
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);
 
     int pos_last = 0;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
         if (cur_p->data[i].p >= ctx->threshold) {
             pos_last = i;
-        } else break;
+        } else {
+            break;
+        }
     }
 
     if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
@@ -1221,7 +1317,7 @@ struct llama_sampler_mirostat {
 
     float mu;
 
-    std::mt19937 rng;
+    std::mt19937    rng;
 };
 
 static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
@@ -1231,7 +1327,7 @@ static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*s
 static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
 
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);
 
     // Estimate s_hat using the most probable m tokens
     float s_hat = 0.0;
@@ -1250,7 +1346,8 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke
     float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
 
     llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
-    llama_sampler_softmax_impl(cur_p);
+
+    llama_sampler_softmax_impl(cur_p, true);
 
     const int idx = llama_sample_dist(cur_p, ctx->rng);
 
@@ -1336,7 +1433,7 @@ static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler *
 static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
 
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);
 
     // Truncate the words with surprise values greater than mu
     cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
@@ -1348,7 +1445,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
     }
 
     // Normalize the probabilities of the remaining words
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);
 
     const int idx = llama_sample_dist(cur_p, ctx->rng);
 
@@ -1540,7 +1637,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
                 trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
             }
             trigger_pattern += ")[\\s\\S]*";
-            auto trigger_pattern_c = trigger_pattern.c_str();
+            const auto * trigger_pattern_c = trigger_pattern.c_str();
             trigger_patterns = &trigger_pattern_c;
             num_trigger_patterns = 1;
         }
@@ -1748,7 +1845,7 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
 }
 
 static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
+    auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
 
     if (ctx->n <= 0.0f || cur_p->size <= 1) {
         return;
@@ -1780,13 +1877,14 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
     }
     float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
 
-    //apply mask
+    // apply mask
     for (size_t i = 0; i < cur_p->size; ++i) {
         if (cur_p->data[i].logit < max - (ctx->n * std)) {
             cur_p->data[i].logit = -INFINITY;
         }
     }
-    llama_sampler_softmax_impl(cur_p);
+
+    llama_sampler_softmax_impl(cur_p, true);
 }
 
 static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
@@ -1991,7 +2089,9 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
 
     {
         const int last = last_n_repeat - 1;
-        int rt = 0, lt = 0;
+
+        int rt = 0;
+        int lt = 0;
 
         for (int k = 1; k < last_n_repeat; ++k) {
             if (k > rt) {
@@ -2135,8 +2235,8 @@ static struct llama_sampler_i llama_sampler_dry_i = {
     /* .free   = */ llama_sampler_dry_free,
 };
 
-struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
-    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
+struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? n_ctx_train : std::max(dry_penalty_last_n, 0);
     std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
     const int MAX_CHAR_LEN = 40;
     const int MAX_SEQ_LEN = 20;
@@ -2169,7 +2269,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
-            /* .total_context_size     = */ context_size,
+            /* .total_context_size     = */ n_ctx_train,
             /* .dry_multiplier         = */ dry_multiplier,
             /* .dry_base               = */ dry_base,
             /* .dry_allowed_length     = */ dry_allowed_length,
@@ -2308,7 +2408,7 @@ static const char * llama_sampler_infill_name(const struct llama_sampler * /*smp
 static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_infill *) smpl->ctx;
 
-    llama_sampler_softmax_impl(cur_p);
+    llama_sampler_softmax_impl(cur_p, true);
 
 #if defined(GGML_DEBUG_SAMPLER_INFILL)
 #define LOG_DBG_CUR LLAMA_LOG_DEBUG
index de5d1681dff8544b50893f88812c89ecc8fa0e1b..8cb36661a0c968d15a1d6df1a262087b0803d370 100644 (file)
@@ -434,6 +434,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_GROK_2:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -1955,7 +1962,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
                 clean_spaces = false;
             } else if (
-                tokenizer_pre == "bailingmoe") {
+                tokenizer_pre == "bailingmoe" ||
+                tokenizer_pre == "llada-moe") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
                 clean_spaces = false;
             } else if (
@@ -1974,6 +1982,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "kimi-k2") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "grok-2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -2470,7 +2482,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         // set attributes by model/tokenizer/architecture name
         if (false
                 || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
-                || _contains_any(general_arch, {"nomic-bert-moe"})
+                || _contains_any(general_arch, {"nomic-bert-moe", "jina-bert-v3"})
            ) {
             if (token_to_id.count("<mask>") == 0) {
                 LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
index 61b8124216847b2eb9d84586c8aae8c382a1c589..0d2f28c36c80dd714a6a2d01981872583dbbc7cd 100644 (file)
@@ -47,6 +47,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
     LLAMA_VOCAB_PRE_TYPE_KIMI_K2        = 37,
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE  = 38,
+    LLAMA_VOCAB_PRE_TYPE_GROK_2         = 39,
 };
 
 struct LLM_KV;
index 34906cdb62844875bf572a2a1df6118a2a8aa885..fe5a7a835488c5a16c4219521fd20a472e7b202d 100644 (file)
 // interface implementation
 //
 
+const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type) {
+    switch (flash_attn_type) {
+        case LLAMA_FLASH_ATTN_TYPE_AUTO:
+            return "auto";
+        case LLAMA_FLASH_ATTN_TYPE_DISABLED:
+            return "disabled";
+        case LLAMA_FLASH_ATTN_TYPE_ENABLED:
+            return "enabled";
+    }
+    GGML_ABORT("fatal error");
+}
+
 struct llama_sampler_chain_params llama_sampler_chain_default_params() {
     struct llama_sampler_chain_params result = {
         /*.no_perf                     =*/ true,
@@ -47,6 +59,7 @@ bool llama_supports_mlock(void) {
 
 bool llama_supports_gpu_offload(void) {
     return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
+           ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU) != nullptr ||
            llama_supports_rpc();
 }
 
@@ -71,7 +84,9 @@ void llama_numa_init(enum ggml_numa_strategy numa) {
         GGML_ASSERT(dev && "CPU backend is not loaded");
         auto * reg = ggml_backend_dev_backend_reg(dev);
         auto * numa_init_fn = (decltype(ggml_numa_init) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_numa_init");
-        numa_init_fn(numa);
+        if (numa_init_fn) {
+            numa_init_fn(numa);
+        }
     }
 }
 
@@ -170,8 +185,13 @@ static struct llama_model * llama_model_load_from_file_impl(
             model->devices.push_back(*dev);
         }
     } else {
+        // default device selection
+
+        // build list of available devices
+        std::vector<ggml_backend_dev_t> gpus;
+        std::vector<ggml_backend_dev_t> igpus;
         std::vector<ggml_backend_dev_t> rpc_servers;
-        // use all available devices
+
         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
             ggml_backend_dev_t dev = ggml_backend_dev_get(i);
             switch (ggml_backend_dev_type(dev)) {
@@ -180,19 +200,51 @@ static struct llama_model * llama_model_load_from_file_impl(
                     // skip CPU backends since they are handled separately
                     break;
 
-                case GGML_BACKEND_DEVICE_TYPE_GPU:
+                case GGML_BACKEND_DEVICE_TYPE_GPU: {
                     ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
                     if (ggml_backend_reg_name(reg) == std::string("RPC")) {
                         rpc_servers.push_back(dev);
                     } else {
-                        model->devices.push_back(dev);
+                        // check if there is already a GPU with the same device id
+                        ggml_backend_dev_props props;
+                        ggml_backend_dev_get_props(dev, &props);
+                        auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) {
+                            ggml_backend_dev_props d_props;
+                            ggml_backend_dev_get_props(d, &d_props);
+                            if (props.device_id && d_props.device_id) {
+                                return strcmp(props.device_id, d_props.device_id) == 0;
+                            }
+                            return false;
+                        });
+
+                        if (it != gpus.end()) {
+                            LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n",
+                                    __func__,
+                                    ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
+                                    props.device_id ? props.device_id : "unknown id",
+                                    ggml_backend_dev_name(*it), ggml_backend_dev_description(*it));
+                        } else {
+                            gpus.push_back(dev);
+                        }
                     }
                     break;
+                }
+
+                case GGML_BACKEND_DEVICE_TYPE_IGPU:
+                    igpus.push_back(dev);
+                    break;
             }
         }
-        // add RPC servers at the front of the list
-        if (!rpc_servers.empty()) {
-            model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
+
+        // add RPC servers at the front of the list to minimize network transfers
+        model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
+
+        // add GPUs
+        model->devices.insert(model->devices.end(), gpus.begin(), gpus.end());
+
+        // add integrated GPUs only if no other devices were found
+        if (model->devices.empty()) {
+            model->devices.insert(model->devices.end(), igpus.begin(), igpus.end());
         }
     }
 
@@ -213,9 +265,12 @@ static struct llama_model * llama_model_load_from_file_impl(
     }
 
     for (auto * dev : model->devices) {
-        size_t free, total; // NOLINT
-        ggml_backend_dev_memory(dev, &free, &total);
-        LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
+        ggml_backend_dev_props props;
+        ggml_backend_dev_get_props(dev, &props);
+        LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__,
+                ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
+                props.device_id ? props.device_id : "unknown id",
+                props.memory_free/1024/1024);
     }
 
     const int status = llama_model_load(path_model, splits, *model, params);
index 135eaf1b655695e8e407dedd65ff2c6fe4db53dd..453190e852b51e87502e813c6d4f91d85cd1f888 100644 (file)
@@ -64,8 +64,6 @@ extern "C" {
 
     typedef struct llama_memory_i * llama_memory_t;
 
-    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
-
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
     typedef int32_t llama_seq_id;
@@ -181,6 +179,14 @@ extern "C" {
         LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
     };
 
+    enum llama_flash_attn_type {
+        LLAMA_FLASH_ATTN_TYPE_AUTO     = -1,
+        LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
+        LLAMA_FLASH_ATTN_TYPE_ENABLED  = 1,
+    };
+
+    LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
+
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -200,7 +206,7 @@ extern "C" {
         llama_token_data * data;
         size_t size;
         int64_t selected; // this is the index in the data array (i.e. not the token id)
-        bool sorted;
+        bool sorted;      // note: do not assume the data is sorted - always check this flag
     } llama_token_data_array;
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
@@ -305,6 +311,7 @@ extern "C" {
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
         enum llama_attention_type    attention_type;    // attention type to use for embeddings
+        enum llama_flash_attn_type   flash_attn_type;   // when to enable Flash Attention
 
         // ref: https://github.com/ggml-org/llama.cpp/pull/2054
         float    rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -314,7 +321,7 @@ extern "C" {
         float    yarn_beta_fast;   // YaRN low correction dim
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
-        float    defrag_thold;     // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
+        float    defrag_thold;     // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default)
 
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
@@ -331,7 +338,6 @@ extern "C" {
         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // use flash attention [EXPERIMENTAL]
         bool no_perf;     // measure performance timings
         bool op_offload;  // offload host tensor operations to device
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
@@ -469,8 +475,6 @@ extern "C" {
     LLAMA_API           llama_memory_t   llama_get_memory  (const struct llama_context * ctx);
     LLAMA_API  enum llama_pooling_type   llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
-    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
-
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
 
@@ -557,10 +561,32 @@ extern "C" {
             struct llama_model * model,
             const char * path_lora);
 
+    // Functions to access the adapter's GGUF metadata scalar values
+    // - The functions return the length of the string on success, or -1 on failure
+    // - The output string is always null-terminated and cleared on failure
+    // - When retrieving a string, an extra byte must be allocated to account for the null terminator
+    // - GGUF array values are not supported by these functions
+
+    // Get metadata value as a string by key name
+    LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size);
+
+    // Get the number of metadata key/value pairs
+    LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter);
+
+    // Get metadata key name by index
+    LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
+
+    // Get metadata value as a string by index
+    LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
+
     // Manually free a LoRA adapter
     // Note: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
 
+    // Get the invocation tokens if the current lora is an alora
+    LLAMA_API uint64_t            llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
+    LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens  (const struct llama_adapter_lora * adapter);
+
     // The following functions operate on a llama_context, hence the naming: llama_verb_...
 
     // Add a loaded LoRA adapter to given context
@@ -667,111 +693,6 @@ extern "C" {
     // Check if the memory supports shifting
     LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
 
-    //
-    // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
-    //
-
-    // Returns the number of tokens in the KV cache (slow, use only for debug)
-    // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
-               "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
-
-    // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
-               "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
-
-    // Clear the KV cache - both cell info is erased and KV data is zeroed
-    DEPRECATED(LLAMA_API void llama_kv_self_clear(
-                struct llama_context * ctx),
-            "Use llama_memory_clear() instead");
-
-    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
-    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
-    // seq_id < 0 : match any sequence
-    // p0 < 0     : [0,  p1]
-    // p1 < 0     : [p0, inf)
-    DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id,
-                       llama_pos   p0,
-                       llama_pos   p1),
-            "Use llama_memory_seq_rm() instead");
-
-    // Copy all tokens that belong to the specified sequence to another sequence
-    // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
-    // p0 < 0 : [0,  p1]
-    // p1 < 0 : [p0, inf)
-    DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id_src,
-                    llama_seq_id   seq_id_dst,
-                       llama_pos   p0,
-                       llama_pos   p1),
-            "Use llama_memory_seq_cp() instead");
-
-    // Removes all tokens that do not belong to the specified sequence
-    DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id),
-            "Use llama_memory_seq_keep() instead");
-
-    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
-    // If the KV cache is RoPEd, the KV data is updated accordingly:
-    //   - lazily on next llama_decode()
-    // p0 < 0 : [0,  p1]
-    // p1 < 0 : [p0, inf)
-    DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id,
-                       llama_pos   p0,
-                       llama_pos   p1,
-                       llama_pos   delta),
-            "Use llama_memory_seq_add() instead");
-
-    // Integer division of the positions by factor of `d > 1`
-    // If the KV cache is RoPEd, the KV data is updated accordingly:
-    //   - lazily on next llama_decode()
-    // p0 < 0 : [0,  p1]
-    // p1 < 0 : [p0, inf)
-    DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id,
-                       llama_pos   p0,
-                       llama_pos   p1,
-                             int   d),
-            "Use llama_memory_seq_div() instead");
-
-    // Returns the smallest position present in the KV cache for the specified sequence
-    // This is typically non-zero only for SWA caches
-    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
-    // Return -1 if the sequence is empty
-    DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id),
-            "Use llama_memory_seq_pos_min() instead");
-
-    // Returns the largest position present in the KV cache for the specified sequence
-    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
-    // Return -1 if the sequence is empty
-    DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
-            struct llama_context * ctx,
-                    llama_seq_id   seq_id),
-            "Use llama_memory_seq_pos_max() instead");
-
-    // Defragment the KV cache
-    // This will be applied:
-    //   - lazily on next llama_decode()
-    DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
-            "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
-
-    // Check if the context supports KV cache shifting
-    DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
-            "use llama_memory_can_shift() instead");
-
-    // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
-            "simply remove this call, updates are applied lazily on the next llama_decode()");
-
     //
     // State / sessions
     //
@@ -1239,11 +1160,6 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
     LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);
 
-    /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void),
-        "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
-
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     /// Setting k <= 0 makes this a noop
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k      (int32_t k);
index b4219c29446355fa40fc287c05a8db601046851d..239c56902d442682440ab87b3a7435439f011120 100644 (file)
@@ -340,9 +340,10 @@ int main(int argc, char ** argv) {
     llama_context_params lcparams = llama_context_default_params();
 
     // tune these to your liking
-    lcparams.n_ctx      = 2048;
-    lcparams.n_threads  = params.n_threads;
-    lcparams.flash_attn = params.flash_attn;
+    lcparams.n_ctx     = 2048;
+    lcparams.n_threads = params.n_threads;
+
+    lcparams.flash_attn_type = params.flash_attn ? LLAMA_FLASH_ATTN_TYPE_AUTO : LLAMA_FLASH_ATTN_TYPE_DISABLED;
 
     struct llama_context * ctx_llama = llama_init_from_model(model_llama, lcparams);