]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Tue, 1 Jul 2025 09:21:09 +0000 (12:21 +0300)
committerGeorgi Gerganov <redacted>
Tue, 1 Jul 2025 14:54:53 +0000 (17:54 +0300)
24 files changed:
examples/talk-llama/llama-arch.cpp
examples/talk-llama/llama-arch.h
examples/talk-llama/llama-batch.cpp
examples/talk-llama/llama-chat.cpp
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-context.h
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-hparams.h
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
examples/talk-llama/llama-kv-cache-unified-iswa.h
examples/talk-llama/llama-kv-cache-unified.cpp
examples/talk-llama/llama-kv-cache-unified.h
examples/talk-llama/llama-kv-cells.h
examples/talk-llama/llama-memory-hybrid.cpp
examples/talk-llama/llama-memory-hybrid.h
examples/talk-llama/llama-memory-recurrent.cpp
examples/talk-llama/llama-memory-recurrent.h
examples/talk-llama/llama-memory.cpp
examples/talk-llama/llama-memory.h
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-quant.cpp
examples/talk-llama/llama.h

index 8dadef204f9d71f45039c4402aedcb5e923683ad..aa21108a4bd79009723842c43d8324b4b0ff07f7 100644 (file)
@@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA,            "gemma"            },
     { LLM_ARCH_GEMMA2,           "gemma2"           },
     { LLM_ARCH_GEMMA3,           "gemma3"           },
+    { LLM_ARCH_GEMMA3N,          "gemma3n"          },
     { LLM_ARCH_STARCODER2,       "starcoder2"       },
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_XVERSE,           "xverse"           },
@@ -75,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
     { LLM_ARCH_DOTS1,            "dots1"            },
     { LLM_ARCH_ARCEE,            "arcee"            },
+    { LLM_ARCH_ERNIE4_5,         "ernie4_5"         },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -932,6 +934,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3N,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,          "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,            "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,               "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,          "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,               "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,          "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,               "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,             "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,       "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,             "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,             "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,             "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,               "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,        "blk.%d.post_ffw_norm" },
+            { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
+            { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
+            { LLM_TENSOR_PER_LAYER_PROJ_NORM,  "per_layer_proj_norm" },
+            { LLM_TENSOR_ALTUP_UNEMBD_PROJ,    "altup_unembd_proj" },
+            { LLM_TENSOR_ALTUP_PROJ,           "altup_proj" },
+            { LLM_TENSOR_PER_LAYER_INP_GATE,   "blk.%d.inp_gate" },
+            { LLM_TENSOR_PER_LAYER_PROJ,       "blk.%d.proj" },
+            { LLM_TENSOR_PER_LAYER_POST_NORM,  "blk.%d.post_norm" },
+            { LLM_TENSOR_ALTUP_CORRECT_COEF,   "blk.%d.altup_correct_coef" },
+            { LLM_TENSOR_ALTUP_CORRECT_SCALE,  "blk.%d.altup_correct_scale" },
+            { LLM_TENSOR_ALTUP_PREDICT_COEF,   "blk.%d.altup_predict_coef" },
+            { LLM_TENSOR_ALTUP_ROUTER,         "blk.%d.altup_router" },
+            { LLM_TENSOR_ALTUP_ROUTER_NORM,    "blk.%d.altup_router_norm" },
+            { LLM_TENSOR_LAUREL_L,             "blk.%d.laurel_l" },
+            { LLM_TENSOR_LAUREL_R,             "blk.%d.laurel_r" },
+            { LLM_TENSOR_LAUREL_POST_NORM,     "blk.%d.laurel_post_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -1621,6 +1659,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
         }
     },
+    {
+        LLM_ARCH_ERNIE4_5,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1749,6 +1804,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_GATE_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    // altup / laurel (gemma 3n)
+    {LLM_TENSOR_PER_LAYER_TOKEN_EMBD,       {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_PER_LAYER_MODEL_PROJ,       {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ_NORM,        {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PROJ,                 {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_UNEMBD_PROJ,          {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_INP_GATE,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_POST_NORM,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_CORRECT_COEF,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_CORRECT_SCALE,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PREDICT_COEF,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LAUREL_L,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_R,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_POST_NORM,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
index 5b0230c15067817a4500c903259f7e97a1213db3..0771ec3ebadcdc970f2ff3e1c08afbe83d9f1319 100644 (file)
@@ -46,6 +46,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
+    LLM_ARCH_GEMMA3N,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -79,6 +80,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_ERNIE4_5,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -269,6 +271,22 @@ enum llm_tensor {
     LLM_TENSOR_LAYER_OUT_NORM,
     LLM_TENSOR_POST_ATTN_NORM,
     LLM_TENSOR_POST_MLP_NORM,
+    LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
+    LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
+    LLM_TENSOR_PER_LAYER_INP_GATE,   // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ,       // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ_NORM,  // gemma3n
+    LLM_TENSOR_PER_LAYER_POST_NORM,  // gemma3n
+    LLM_TENSOR_ALTUP_PROJ,           // gemma3n
+    LLM_TENSOR_ALTUP_UNEMBD_PROJ,    // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_COEF,   // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_SCALE,  // gemma3n
+    LLM_TENSOR_ALTUP_PREDICT_COEF,   // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER,         // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER_NORM,    // gemma3n
+    LLM_TENSOR_LAUREL_L,             // gemma3n
+    LLM_TENSOR_LAUREL_R,             // gemma3n
+    LLM_TENSOR_LAUREL_POST_NORM,     // gemma3n
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
index b3c996e18ab41183ad1ebcb8a7c2773b32e52ee5..91b1d6078a2529e4c31c43a5295f26c35b70e090 100644 (file)
@@ -244,22 +244,35 @@ bool llama_batch_allocr::init(
             continue;
         }
 
-        if (memory) {
+        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+        if (p0 >= 0) {
+            bool ok = true;
+
             if (batch.token) {
-                if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                if (seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             } else {
                 assert(batch.embd);
 
                 // for embeddings (typically used as vision input), we allow them to have repeating positions
                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
-                if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             }
+
+            if (!ok) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
index 0839cad3ee6db5cea3b883542cee7bb0c9a09a1e..5d317f4ee62ebc0f60d83a805e137de1a24a0ff4 100644 (file)
@@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << "User: " << message->content << "\n\nAssistant:";
-            } else {
-                ss << message->content << "\n\n";
+        for (size_t i = 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "system") {
+                ss << "System: " << trim(chat[i]->content) << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << trim(chat[i]->content) << "\n\n";
+                if (i == chat.size() - 1) {
+                    ss << "Assistant:";
+                }
+            } else if (role == "assistant") {
+                ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
index 5a18a4fb3939a10082f2ecb2a4f17ec9815a1320..06e93b19cbf4087b24284ea3a473f12441c31b6a 100644 (file)
@@ -280,8 +280,8 @@ llama_context::llama_context(
 
         // simulate full KV cache
 
-        const auto mstate = memory->init_full();
-        if (!mstate) {
+        const auto mctx = memory->init_full();
+        if (!mctx) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 
@@ -289,7 +289,7 @@ llama_context::llama_context(
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -300,7 +300,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mstate.get());
+            auto * gf = graph_reserve(1, 1, 1, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,7 +311,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
         optimize |= memory_force_optimize;
         memory_force_optimize = false;
 
-        const auto mstate = memory->init_update(this, optimize);
-        switch (mstate->get_status()) {
+        const auto mctx = memory->init_update(this, optimize);
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                     // noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
-        if (!mstate->apply()) {
+        if (!mctx->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
     }
 
     // if the memory module did any computation, we have to reserve a new worst-case graph
     {
-        const auto mstate = memory->init_full();
-        if (!mstate) {
-            throw std::runtime_error("failed to initialize memory state");
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory context");
         }
 
         const uint32_t n_seqs   = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
         }
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
-    if (mstate && !mstate->apply()) {
-        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_state_ptr mstate;
+    llama_memory_context_ptr mctx;
 
     while (true) {
-        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        if (!mstate) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
             return -2;
         }
 
-        switch (mstate->get_status()) {
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
+                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
 
                     return -2;
                 }
@@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = mstate->get_ubatch();
+        const auto & ubatch = mctx->get_ubatch();
 
         // count the outputs in this ubatch
         {
@@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
             }
 
-            // TODO: fix sequence indexing
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];
 
@@ -1126,7 +1125,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (mstate->next());
+    } while (mctx->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1292,7 +1291,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1312,7 +1311,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
 
     this->n_outputs = save_n_outputs;
 
@@ -1333,11 +1332,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-                    ggml_context * ctx,
-                     ggml_cgraph * gf,
-              const llama_ubatch & ubatch,
-                  llm_graph_type   gtype,
-      const llama_memory_state_i * mstate) {
+                      ggml_context * ctx,
+                       ggml_cgraph * gf,
+                const llama_ubatch & ubatch,
+                    llm_graph_type   gtype,
+      const llama_memory_context_i * mctx) {
     return model.build_graph(
             {
                 /*.ctx         =*/ ctx,
@@ -1349,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
                 /*.backend_cpu =*/ backend_cpu,
                 /*.cvec        =*/ &cvec,
                 /*.loras       =*/ &loras,
-                /*.mstate      =*/ mstate,
+                /*.mctx        =*/ mctx,
                 /*.cross       =*/ &cross,
                 /*.n_outputs   =*/ n_outputs,
                 /*.cb          =*/ graph_get_cb(),
@@ -2042,8 +2041,8 @@ void llama_context::opt_epoch_iter(
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
-        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+        if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2056,17 +2055,17 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
-            const auto & ubatch = mstate->get_ubatch();
+            const auto & ubatch = mctx->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            if (!mstate->apply()) {
-                LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+            if (!mctx->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
                 break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2101,7 +2100,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        } while (mstate->next());
+        } while (mctx->next());
     }
 }
 
index 7d300c14572e9ed2b2009f4a11f69c36eafc3303..9ce05715a8c0306312bd03f516ae38f1f79cb2f6 100644 (file)
@@ -18,7 +18,7 @@ class llama_io_read_i;
 class llama_io_write_i;
 
 struct llama_memory_i;
-struct llama_memory_state_i;
+struct llama_memory_context_i;
 
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
                 int32_t   il_end);
 
     // process a single ubatch with a specific graph type
-    // if memory_state is provided, it will be applied first to the context's memory
+    // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
     llm_graph_result_ptr process_ubatch(
-              const llama_ubatch & ubatch,
-                  llm_graph_type   gtype,
-            llama_memory_state_i * mstate,
-                     ggml_status & ret);
+                const llama_ubatch & ubatch,
+                    llm_graph_type   gtype,
+            llama_memory_context_i * mctx,
+                       ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
 private:
     llm_graph_result_ptr graph_build(
-                    ggml_context * ctx,
-                     ggml_cgraph * gf,
-              const llama_ubatch & ubatch,
-                  llm_graph_type   gtype,
-      const llama_memory_state_i * mstate);
+                      ggml_context * ctx,
+                       ggml_cgraph * gf,
+                const llama_ubatch & ubatch,
+                    llm_graph_type   gtype,
+      const llama_memory_context_i * mctx);
 
     llm_graph_cb graph_get_cb() const;
 
index 7e162c555220439e9d6cdf204deedca4ba104471..010300df6098eb443ec0497ff903c9d02d44ccad 100644 (file)
@@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        kv_state->set_input_pos_bucket(pos_bucket, ubatch);
+        mctx->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_rs = mem_state->get_n_rs();
+    const int64_t n_rs = mctx->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->s_copy(i);
+            data[i] = mctx->s_copy(i);
         }
     }
 }
@@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
     if (self_kq_mask_swa) {
-        kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -345,11 +345,17 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mem_state->get_state_recr()->s_copy(i);
+            data[i] = mctx->get_recr()->s_copy(i);
         }
     }
 }
 
+void llm_graph_input_one::set_input(const llama_ubatch *) {
+    GGML_ASSERT(one && ggml_nelements(one) == 1);
+    float f_one = 1.0f;
+    ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
+}
+
 //
 // llm_graph_context
 //
@@ -389,7 +395,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     backend_cpu      (params.backend_cpu),
     cvec             (params.cvec),
     loras            (params.loras),
-    mstate           (params.mstate),
+    mctx             (params.mctx),
     cross            (params.cross),
     cb_func          (params.cb),
     res              (std::make_unique<llm_graph_result>()) {
@@ -554,12 +560,20 @@ ggml_tensor * llm_graph_context::build_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_swiglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_swiglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_silu", il);
             } break;
         case LLM_FFN_GELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_geglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_geglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_gelu", il);
                 if (act_scales != NULL) {
@@ -568,7 +582,11 @@ ggml_tensor * llm_graph_context::build_ffn(
                 }
             } break;
         case LLM_FFN_RELU:
-            {
+            if (gate && type_gate == LLM_FFN_PAR) {
+                cur = ggml_reglu_split(ctx0, cur, tmp);
+                cb(cur, "ffn_reglu", il);
+                type_gate = LLM_FFN_SEQ;
+            } else {
                 cur = ggml_relu(ctx0, cur);
                 cb(cur, "ffn_relu", il);
             } break;
@@ -582,32 +600,19 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
         case LLM_FFN_SWIGLU:
             {
-                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
-                int64_t split_point = cur->ne[0] / 2;
-                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_silu(ctx0, x0);
-                cb(cur, "ffn_silu", il);
-
-                cur = ggml_mul(ctx0, x0, x1);
-                cb(cur, "ffn_mul", il);
+                cur = ggml_swiglu(ctx0, cur);
+                cb(cur, "ffn_swiglu", il);
             } break;
         case LLM_FFN_GEGLU:
             {
-                // Split into two equal parts
-                int64_t split_point = cur->ne[0] / 2;
-                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
-                x0 = ggml_gelu(ctx0, x0);
-                cb(x0, "ffn_gelu", il);
-
-                cur = ggml_mul(ctx0, x0, x1);
+                cur = ggml_geglu(ctx0, cur);
                 cb(cur, "ffn_geglu", il);
             } break;
+        case LLM_FFN_REGLU:
+            {
+                cur = ggml_reglu(ctx0, cur);
+                cb(cur, "ffn_reglu", il);
+            } break;
     }
 
     if (gate && type_gate == LLM_FFN_PAR) {
@@ -737,12 +742,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            {
+            if (gate_exps) {
+                cur = ggml_swiglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_swiglu", il);
+            } else {
                 cur = ggml_silu(ctx0, cur);
                 cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
-            {
+            if (gate_exps) {
+                cur = ggml_geglu_split(ctx0, cur, up);
+                cb(cur, "ffn_moe_geglu", il);
+            } else {
                 cur = ggml_gelu(ctx0, cur);
                 cb(cur, "ffn_moe_gelu", il);
             } break;
@@ -750,11 +761,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             GGML_ABORT("fatal error");
     }
 
-    if (gate_exps) {
-        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
-        cb(cur, "ffn_moe_gate_par", il);
-    }
-
     experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
@@ -950,11 +956,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
-    const auto n_kv = kv_state->get_n_kv();
+    const auto n_kv = mctx_cur->get_n_kv();
 
     auto & cur = inp->pos_bucket;
 
@@ -982,14 +988,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -999,7 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
 
         inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
@@ -1183,14 +1189,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = kv_state->get_n_kv();
+        const auto n_kv = mctx_cur->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1220,19 +1226,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1267,26 +1273,35 @@ ggml_tensor * llm_graph_context::build_attn(
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+    if (k_cur) {
+        ggml_build_forward_expand(gf, k_cur);
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, v_cur);
+    }
+
+    const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     const bool is_swa = hparams.is_swa(il);
 
-    const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
+    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
-    // store to KV cache
-    {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    // optionally store to KV cache
+    if (k_cur) {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1379,19 +1394,19 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ -1412,12 +1427,12 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
     {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
+        const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1429,7 +1444,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
+        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1485,11 +1500,11 @@ ggml_tensor * llm_graph_context::build_rs(
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = kv_state->get_n_rs();
+    const auto n_rs = mctx_cur->get_n_rs();
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
@@ -1504,9 +1519,9 @@ ggml_tensor * llm_graph_context::build_rs(
             int32_t   state_size,
             int32_t   n_seqs,
                bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1516,9 +1531,9 @@ ggml_tensor * llm_graph_context::build_rs(
             int32_t   state_size,
             int32_t   n_seqs,
                bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
-    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1526,13 +1541,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
            ggml_cgraph * gf,
     const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs  = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
+    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
             inp, gf, token_shift_all,
@@ -1547,19 +1562,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    const auto kv_head = kv_state->get_head();
+    const auto kv_head = mctx_cur->get_head();
 
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
+        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
     );
 }
 
index 9e62fa60720d78b9f3aded7ff29bf1d4212197bb..ceddb6021f11479deca09e0f96f148adb611129a 100644 (file)
@@ -17,12 +17,12 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-struct llama_memory_state_i;
+struct llama_memory_context_i;
 
-class llama_kv_cache_unified_state;
-class llama_kv_cache_unified_iswa_state;
-class llama_memory_recurrent_state;
-class llama_memory_hybrid_state;
+class llama_kv_cache_unified_context;
+class llama_kv_cache_unified_iswa_context;
+class llama_memory_recurrent_context;
+class llama_memory_hybrid_context;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -38,6 +38,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
+    LLM_FFN_REGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -136,7 +137,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
+            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -144,7 +145,8 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-    const llama_kv_cache_unified_state * kv_state;
+
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,14 +193,14 @@ public:
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_memory_recurrent_state * mem_state;
+    const llama_memory_recurrent_context * mctx;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -238,10 +240,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_state * kv_state) :
+            const llama_kv_cache_unified_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -255,7 +257,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_state * kv_state;
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -263,10 +265,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_state * kv_state) :
+            const llama_kv_cache_unified_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        kv_state(kv_state) {
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -283,7 +285,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa_state * kv_state;
+    const llama_kv_cache_unified_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -306,10 +308,10 @@ public:
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_memory_hybrid_state * mem_state) :
+            const llama_memory_hybrid_context * mctx) :
         hparams(hparams),
         cparams(cparams),
-        mem_state(mem_state) {
+        mctx(mctx) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
@@ -325,7 +327,18 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_memory_hybrid_state * mem_state;
+    const llama_memory_hybrid_context * mctx;
+};
+
+// TODO: remove this when ggml_scale_add is implemented
+class llm_graph_input_one : public llm_graph_input_i {
+public:
+    llm_graph_input_one() {}
+    virtual ~llm_graph_input_one() = default;
+
+    void set_input(const llama_ubatch *) override;
+
+    ggml_tensor * one = nullptr; // F32
 };
 
 //
@@ -401,10 +414,10 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec   * cvec;
-    const llama_adapter_loras  * loras;
-    const llama_memory_state_i * mstate;
-    const llama_cross          * cross;
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
 
     uint32_t n_outputs;
 
@@ -453,16 +466,17 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec   * cvec;
-    const llama_adapter_loras  * loras;
-    const llama_memory_state_i * mstate;
-    const llama_cross          * cross;
+    const llama_adapter_cvec     * cvec;
+    const llama_adapter_loras    * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross            * cross;
 
     const llm_graph_cb & cb_func;
 
     std::unique_ptr<llm_graph_result> res;
 
     llm_graph_context(const llm_graph_params & params);
+    virtual ~llm_graph_context() = default;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -588,14 +602,15 @@ struct llm_graph_context {
 
     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
+    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
index 7b315a9a74b1da6669ba64eb2c6efe3ab88c8d81..e85afe145a922a57d74b887d343857c35cff5ced 100644 (file)
@@ -143,6 +143,12 @@ struct llama_hparams {
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale       = 0.1;
 
+    // gemma3n altup
+    uint32_t n_altup      = 4; // altup_num_inputs
+    uint32_t i_altup_act  = 0; // altup_active_idx
+    uint32_t laurel_rank  = 64;
+    uint32_t n_embd_altup = 256;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
index 0ced340dec6c5c787f767163d4d3849785f31d61..d1f839b63aaf55fd61bd6a422f722ceca4adaac4 100644 (file)
@@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
@@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
@@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 
         assert(heads_base.size() == heads_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 }
 
 //
-// llama_kv_cache_unified_iswa_state
+// llama_kv_cache_unified_iswa_context
 //
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv) :
-    state_base(kv->get_base()->init_full()),
-    state_swa (kv->get_swa ()->init_full()),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(kv->get_base()->init_full()),
+    ctx_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
         bool optimize) :
-    state_base(kv->get_base()->init_update(lctx, optimize)),
-    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(kv->get_base()->init_update(lctx, optimize)),
+    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
-    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa),  this->ubatches)),
-    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa),  this->ubatches)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
+llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
 
-bool llama_kv_cache_unified_iswa_state::next() {
+bool llama_kv_cache_unified_iswa_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    state_base->next();
-    state_swa ->next();
+    ctx_base->next();
+    ctx_swa ->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_state::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+bool llama_kv_cache_unified_iswa_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 
-    res = res & state_base->apply();
-    res = res & state_swa ->apply();
+    res = res & ctx_base->apply();
+    res = res & ctx_swa ->apply();
 
     return res;
 }
 
-llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
+llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
 }
 
-const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa()  const {
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa()  const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
 }
index 071041585db38588217547ef556470ddbae2fdc1..46c1ed614f2f0166960d86035396d26a76520420 100644 (file)
@@ -31,14 +31,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -72,32 +72,32 @@ private:
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
 
-class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
+class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
     // used for errors
-    llama_kv_cache_unified_iswa_state(llama_memory_status status);
+    llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
-    // used to create a full-cache state
-    llama_kv_cache_unified_iswa_state(
+    // used to create a full-cache context
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv);
 
-    // used to create an update state
-    llama_kv_cache_unified_iswa_state(
+    // used to create an update context
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
             llama_context * lctx,
             bool optimize);
 
-    // used to create a state from a batch
-    llama_kv_cache_unified_iswa_state(
+    // used to create a batch processing context from a batch
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
             std::vector<uint32_t> heads_base,
             std::vector<uint32_t> heads_swa,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_iswa_state();
+    virtual ~llama_kv_cache_unified_iswa_context();
 
     //
-    // llama_memory_state_i
+    // llama_memory_context_i
     //
 
     bool next()  override;
@@ -107,11 +107,11 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_iswa_state specific API
+    // llama_kv_cache_unified_iswa_context specific API
     //
 
-    const llama_kv_cache_unified_state * get_base() const;
-    const llama_kv_cache_unified_state * get_swa()  const;
+    const llama_kv_cache_unified_context * get_base() const;
+    const llama_kv_cache_unified_context * get_swa()  const;
 
 private:
     //llama_kv_cache_unified_iswa * kv;
@@ -121,8 +121,8 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    const llama_memory_state_ptr state_base;
-    const llama_memory_state_ptr state_swa;
+    const llama_memory_context_ptr ctx_base;
+    const llama_memory_context_ptr ctx_swa;
 
     const llama_memory_status status;
 };
index 6897b797153dbe894b8d66cf9f626463efe47a9c..7f7b162ffd7cefd524ab9676110d0c61051574ef 100644 (file)
@@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    auto n_layer_cache = hparams.n_layer;
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        n_layer_cache = 20;
+    }
+
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
-    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+    for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
             continue;
@@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         layers.push_back({ il, k, v });
     }
 
+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+
+        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+                continue;
+            }
+
+            const bool     is_swa   = hparams.is_swa(il);
+            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+
+            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+            map_layer_ids[il] = map_layer_ids[il_reuse];
+
+            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+        }
+    }
+
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
     for (auto it : ctx_map) {
         auto * buft = it.first;
@@ -307,7 +333,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     return cells.seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified::init_batch(
+llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) {
@@ -332,18 +358,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             break;
         }
 
-        return std::make_unique<llama_kv_cache_unified_state>(
+        return std::make_unique<llama_kv_cache_unified_context>(
                 this, std::move(heads), std::move(ubatches));
     } while (false);
 
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_state>(this);
+llama_memory_context_ptr llama_kv_cache_unified::init_full() {
+    return std::make_unique<llama_kv_cache_unified_context>(this);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
     bool do_shift = get_has_shift();
 
     defrag_info dinfo;
@@ -373,7 +399,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
         }
     }
 
-    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
 llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1710,18 +1736,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 }
 
 //
-// llama_kv_cache_unified_state
+// llama_kv_cache_unified_context
 //
 
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
     head = 0;
 }
 
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_context * lctx,
         bool do_shift,
@@ -1731,15 +1757,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
     }
 }
 
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_kv_cache_unified::ubatch_heads heads,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
 }
 
-llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
+llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 
-bool llama_kv_cache_unified_state::next() {
+bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     if (++i_next >= ubatches.size()) {
@@ -1749,8 +1775,8 @@ bool llama_kv_cache_unified_state::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_state::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+bool llama_kv_cache_unified_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
 
     // no ubatches -> this is a KV cache update
     if (ubatches.empty()) {
@@ -1767,45 +1793,45 @@ bool llama_kv_cache_unified_state::apply() {
     return true;
 }
 
-llama_memory_status llama_kv_cache_unified_state::get_status() const {
+llama_memory_status llama_kv_cache_unified_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-uint32_t llama_kv_cache_unified_state::get_n_kv() const {
+uint32_t llama_kv_cache_unified_context::get_n_kv() const {
     return n_kv;
 }
 
-ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
     return kv->cpy_k(ctx, k_cur, il, head);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
     return kv->cpy_v(ctx, v_cur, il, head);
 }
 
-void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
+void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
 
-void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_pos_bucket(dst, ubatch);
 }
 
index 1560640045c82d9267bc59b472eaf7b25fd98c3f..4c53f1273ab88326cb21f6bc25fd999935a8b761 100644 (file)
@@ -56,14 +56,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -208,36 +208,36 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
-class llama_kv_cache_unified_state : public llama_memory_state_i {
+class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
     using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
     using defrag_info  = llama_kv_cache_unified::defrag_info;
 
     // used for errors
-    llama_kv_cache_unified_state(llama_memory_status status);
+    llama_kv_cache_unified_context(llama_memory_status status);
 
-    // used to create a full-cache state
-    llama_kv_cache_unified_state(
+    // used to create a full-cache context
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv);
 
-    // used to create an update state
-    llama_kv_cache_unified_state(
+    // used to create an update context
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
             llama_context * lctx,
             bool do_shift,
             defrag_info dinfo);
 
-    // used to create a decode state from a batch
-    llama_kv_cache_unified_state(
+    // used to create a batch procesing context from a batch
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
             ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_state();
+    virtual ~llama_kv_cache_unified_context();
 
     //
-    // llama_memory_state_i
+    // llama_memory_context_i
     //
 
     bool next()  override;
@@ -247,7 +247,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_state specific API
+    // llama_kv_cache_unified_context specific API
     //
 
     uint32_t get_n_kv() const;
@@ -272,7 +272,7 @@ private:
     llama_context * lctx;
 
     //
-    // update state
+    // update context
     //
 
     bool do_shift = false;
@@ -280,7 +280,7 @@ private:
     defrag_info dinfo;
 
     //
-    // batch processing state
+    // batch processing context
     //
 
     // the index of the next ubatch to process
index 349e9032e2484b71b55c5ae14664d0d498d3d651..c95d635948b5d61190cbaa5a128d339cac64910c 100644 (file)
@@ -7,6 +7,7 @@
 #include <cassert>
 #include <vector>
 #include <set>
+#include <map>
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -164,7 +165,7 @@ public:
         assert(seq_id >= 0);
 
         seq[i].reset(seq_id);
-        seq_pos[seq_id].erase(pos[i]);
+        seq_pos_dec(seq_id, pos[i]);
 
         if (seq[i].none()) {
             pos[i] = -1;
@@ -187,7 +188,7 @@ public:
             seq[i].reset();
 
             seq[i].set(seq_id);
-            seq_pos[seq_id].insert(pos[i]);
+            seq_pos_inc(seq_id, pos[i]);
 
             return false;
         }
@@ -232,7 +233,7 @@ public:
         assert(!seq[i].test(seq_id));
 
         seq[i].set(seq_id);
-        seq_pos[seq_id].insert(pos[i]);
+        seq_pos_inc(seq_id, pos[i]);
     }
 
     // return the sequence id of this cell
@@ -259,7 +260,9 @@ public:
             return -1;
         }
 
-        return *seq_pos[seq_id].begin();
+        assert(seq_pos[seq_id].begin()->second > 0);
+
+        return seq_pos[seq_id].begin()->first;
     }
 
     // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ public:
             return -1;
         }
 
-        return *seq_pos[seq_id].rbegin();
+        assert(seq_pos[seq_id].rbegin()->second > 0);
+
+        return seq_pos[seq_id].rbegin()->first;
     }
 
     // note: call only if the cell is not empty
@@ -389,17 +394,36 @@ private:
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<seq_set_t> seq;
 
-    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
+    // if the position p is not present, seq_pos[s][p] is not set
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
+    //
+    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
+    //  - during performing a cache reuse via (rm + add)
+    //  - some vision models have input embeddings with repeating positions
+    //
+    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
 
     // helper functions for updating `seq_pos`, once cell at a time:
 
+    void seq_pos_dec(llama_seq_id s, llama_pos p) {
+        auto it = seq_pos[s].find(p);
+        assert(it != seq_pos[s].end());
+
+        if (--it->second == 0) {
+            seq_pos[s].erase(it);
+        }
+    }
+
+    void seq_pos_inc(llama_seq_id s, llama_pos p) {
+        seq_pos[s][p]++;
+    }
+
     // remove cell i
     void seq_pos_rm(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].erase(pos[i]);
+                seq_pos_dec(s, pos[i]);
             }
         }
     }
@@ -408,7 +432,7 @@ private:
     void seq_pos_add(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].insert(pos[i]);
+                seq_pos_inc(s, pos[i]);
             }
         }
     }
index 1b16686819eff8582dc7ab6eea4334cb304f69f8..67cbf955482354a1097c5454107acc4afe4c8c36 100644 (file)
@@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max
     )) {}
 
-llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     do {
         balloc.split_reset();
 
@@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball
 
         // prepare the recurrent batches first
         if (!mem_recr->prepare(ubatches)) {
-            // TODO: will the recurrent cache be in an undefined state at this point?
+            // TODO: will the recurrent cache be in an undefined context at this point?
             LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
         }
 
         // prepare the attention cache
         auto heads_attn = mem_attn->prepare(ubatches);
         if (heads_attn.empty()) {
             LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-            return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
         }
 
-        return std::make_unique<llama_memory_hybrid_state>(
+        return std::make_unique<llama_memory_hybrid_context>(
                 this, std::move(heads_attn), std::move(ubatches));
     } while(false);
 
-    return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_memory_hybrid::init_full() {
-    return std::make_unique<llama_memory_hybrid_state>(this);
+llama_memory_context_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_context>(this);
 }
 
-llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
+llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
 }
 
 bool llama_memory_hybrid::get_can_shift() const {
@@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
     return mem_recr.get();
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
-    state_attn(mem->get_mem_attn()->init_full()),
-    state_recr(mem->get_mem_recr()->init_full()),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
+    ctx_attn(mem->get_mem_attn()->init_full()),
+    ctx_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(
+llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
               llama_context * lctx,
                        bool   optimize) :
-    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
-    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+    ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-llama_memory_hybrid_state::llama_memory_hybrid_state(
+llama_memory_hybrid_context::llama_memory_hybrid_context(
               llama_memory_hybrid * mem,
             std::vector<uint32_t>   heads_attn,
         std::vector<llama_ubatch>   ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
-    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(),                        this->ubatches)),
-    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                        this->ubatches)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
-bool llama_memory_hybrid_state::next() {
+bool llama_memory_hybrid_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    state_attn->next();
-    state_recr->next();
+    ctx_attn->next();
+    ctx_recr->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
     return true;
 }
 
-bool llama_memory_hybrid_state::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+bool llama_memory_hybrid_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 
-    res = res & state_attn->apply();
-    res = res & state_recr->apply();
+    res = res & ctx_attn->apply();
+    res = res & ctx_recr->apply();
 
     return res;
 }
 
-llama_memory_status llama_memory_hybrid_state::get_status() const {
+llama_memory_status llama_memory_hybrid_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
+const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
-    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
+const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
 }
 
-const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
-    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
+const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
+    return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
 }
index 4d27ab896aa05177b8711b7fad6a1e39b13653e6..f0c2420e9a2df5bc04093c07e972a837b8ccd1aa 100644 (file)
@@ -49,14 +49,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -90,27 +90,27 @@ private:
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };
 
-class llama_memory_hybrid_state : public llama_memory_state_i {
+class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
     // init failure
-    explicit llama_memory_hybrid_state(llama_memory_status status);
+    explicit llama_memory_hybrid_context(llama_memory_status status);
 
     // init full
-    explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
+    explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
 
     // init update
-    explicit llama_memory_hybrid_state(
+    explicit llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
               llama_context * lctx,
                        bool   optimize);
 
     // init success
-    llama_memory_hybrid_state(
+    llama_memory_hybrid_context(
               llama_memory_hybrid * mem,
             std::vector<uint32_t>   heads_attn,
         std::vector<llama_ubatch>   ubatches);
 
-    ~llama_memory_hybrid_state() = default;
+    ~llama_memory_hybrid_context() = default;
 
     bool next()  override;
     bool apply() override;
@@ -119,11 +119,11 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_memory_hybrid_state
+    // llama_memory_hybrid_context
     //
 
-    const llama_kv_cache_unified_state * get_state_attn() const;
-    const llama_memory_recurrent_state * get_state_recr() const;
+    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_memory_recurrent_context * get_recr() const;
 
 private:
     // the index of the next ubatch to process
@@ -131,8 +131,8 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    const llama_memory_state_ptr state_attn;
-    const llama_memory_state_ptr state_recr;
+    const llama_memory_context_ptr ctx_attn;
+    const llama_memory_context_ptr ctx_recr;
 
     const llama_memory_status status;
 };
index b064da0084c5295f3a28b93c9dbfb1a15452f6f1..6ed84057ccfe25b36006faee916fbcd7218e95c9 100644 (file)
@@ -362,42 +362,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    std::vector<llama_ubatch> ubatches;
+llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    do {
+        balloc.split_reset();
 
-    while (true) {
-        llama_ubatch ubatch;
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            llama_ubatch ubatch;
 
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        if (ubatch.n_tokens == 0) {
+        if (!prepare(ubatches)) {
             break;
         }
 
-        ubatches.push_back(std::move(ubatch)); // NOLINT
-    }
-
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    } while (false);
 
-    return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_state_ptr llama_memory_recurrent::init_full() {
-    return std::make_unique<llama_memory_recurrent_state>(this);
+llama_memory_context_ptr llama_memory_recurrent::init_full() {
+    return std::make_unique<llama_memory_recurrent_context>(this);
 }
 
-llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
+llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
     GGML_UNUSED(lctx);
     GGML_UNUSED(optimize);
 
-    return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
 }
 
 bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1040,22 +1045,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
 }
 
 //
-// llama_memory_recurrent_state
+// llama_memory_recurrent_context
 //
 
-llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
+llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
 
-llama_memory_recurrent_state::llama_memory_recurrent_state(
+llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }
 
-llama_memory_recurrent_state::llama_memory_recurrent_state(
+llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
 
-llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
+llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
 
-bool llama_memory_recurrent_state::next() {
+bool llama_memory_recurrent_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     if (++i_next >= ubatches.size()) {
@@ -1065,48 +1070,56 @@ bool llama_memory_recurrent_state::next() {
     return true;
 }
 
-bool llama_memory_recurrent_state::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+bool llama_memory_recurrent_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
+
+    // no ubatches -> this is an update
+    if (ubatches.empty()) {
+        // recurrent cache never performs updates
+        assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);
+
+        return true;
+    }
 
     mem->find_slot(ubatches[i_next]);
 
     return true;
 }
 
-llama_memory_status llama_memory_recurrent_state::get_status() const {
+llama_memory_status llama_memory_recurrent_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
+const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-uint32_t llama_memory_recurrent_state::get_n_rs() const {
+uint32_t llama_memory_recurrent_context::get_n_rs() const {
     return is_full ? mem->size : mem->n;
 }
 
-uint32_t llama_memory_recurrent_state::get_head() const {
+uint32_t llama_memory_recurrent_context::get_head() const {
     return is_full ? 0 : mem->head;
 }
 
-int32_t llama_memory_recurrent_state::get_rs_z() const {
+int32_t llama_memory_recurrent_context::get_rs_z() const {
     return is_full ? 0 : mem->rs_z;
 }
 
-uint32_t llama_memory_recurrent_state::get_size() const {
+uint32_t llama_memory_recurrent_context::get_size() const {
     return mem->size;
 }
 
-ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
+ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
     return mem->r_l[il];
 }
 
-ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
+ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
     return mem->s_l[il];
 }
 
-int32_t llama_memory_recurrent_state::s_copy(int i) const {
+int32_t llama_memory_recurrent_context::s_copy(int i) const {
     return  mem->cells[i + mem->head].src0;
 }
index be58dae7cfe33184cfe1e08bc9c95c1b9e1ffa49..4d094f9a05788cb3fa18b29bc188de1926578225 100644 (file)
@@ -11,8 +11,8 @@
 // llama_memory_recurrent
 //
 
-// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
-//       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
+// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
+//       see the implementation of llama_kv_cache_unified_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:
 
@@ -34,14 +34,14 @@ public:
     // llama_memory_i
     //
 
-    llama_memory_state_ptr init_batch(
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
-    llama_memory_state_ptr init_full() override;
+    llama_memory_context_ptr init_full() override;
 
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     void clear(bool data) override;
 
@@ -125,24 +125,24 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
-class llama_memory_recurrent_state : public llama_memory_state_i {
+class llama_memory_recurrent_context : public llama_memory_context_i {
 public:
     // used for errors
-    llama_memory_recurrent_state(llama_memory_status status);
+    llama_memory_recurrent_context(llama_memory_status status);
 
-    // used to create a full-cache state
-    llama_memory_recurrent_state(
+    // used to create a full-cache or update context
+    llama_memory_recurrent_context(
             llama_memory_recurrent * mem);
 
-    // used to create a state from a batch
-    llama_memory_recurrent_state(
+    // used to create a batch processing context from a batch
+    llama_memory_recurrent_context(
             llama_memory_recurrent * mem,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_memory_recurrent_state();
+    virtual ~llama_memory_recurrent_context();
 
     //
-    // llama_memory_state_i
+    // llama_memory_context_i
     //
 
     bool next()  override;
@@ -152,7 +152,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_memory_recurrent_state specific API
+    // llama_memory_recurrent_context specific API
     //
 
     uint32_t get_n_rs() const;
index f1107672c6476411b04521db02379255328e7728..ca6844c32a76748cb278d90e1898dfa788dd1e3f 100644 (file)
@@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
     // if either status has an update, then the combined status has an update
     return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
 }
+
+bool llama_memory_status_is_fail(llama_memory_status status) {
+    switch (status) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                return false;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return true;
+            }
+    }
+
+    return false;
+}
index d2ef0c2a3b4aafaae17a2bede37b97340f21b64a..e8ba336e8525d16b2cd277eb53a60c4c36ecbc39 100644 (file)
@@ -3,7 +3,6 @@
 #include "llama.h"
 
 #include <memory>
-#include <vector>
 
 struct llama_ubatch;
 
@@ -28,23 +27,24 @@ enum llama_memory_status {
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };
 
-// helper function for combining the status of two memory states
+// helper function for combining the status of two memory contexts
 // useful for implementing hybrid memory types (e.g. iSWA)
 llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
 
-// the interface for managing the memory state during batch processing
+// helper function for checking if a memory status indicates a failure
+bool llama_memory_status_is_fail(llama_memory_status status);
+
+// the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
-//   - llama_kv_cache_unified_state
-//   - llama_kv_cache_unified_iswa_state
+//   - llama_kv_cache_unified_context
+//   - llama_kv_cache_unified_iswa_context
 //   ...
 //
-// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
-//
-// TODO: rename to llama_memory_context_i ?
-struct llama_memory_state_i {
-    virtual ~llama_memory_state_i() = default;
+// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
+struct llama_memory_context_i {
+    virtual ~llama_memory_context_i() = default;
 
-    // consume the current ubatch from the state and proceed to the next one
+    // consume the current ubatch from the context and proceed to the next one
     // return false if we are done
     virtual bool next() = 0;
 
@@ -55,11 +55,11 @@ struct llama_memory_state_i {
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;
 
-    // get the status of the memory state - used for error handling and checking if any updates would be applied
+    // get the status of the memory context - used for error handling and checking if any updates would be applied
     virtual llama_memory_status get_status() const = 0;
 };
 
-using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
+using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
 
 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
@@ -67,19 +67,19 @@ struct llama_memory_i {
     virtual ~llama_memory_i() = default;
 
     // split the input batch into a set of ubatches and verify that they can fit into the cache
-    // return a state object containing the ubatches and KV cache state required to process them
-    // check the llama_memory_state_i::get_status() for the result
-    virtual llama_memory_state_ptr init_batch(
+    // return a context object containing the ubatches and memory state required to process them
+    // check the llama_memory_context_i::get_status() for the result
+    virtual llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
-    virtual llama_memory_state_ptr init_full() = 0;
+    virtual llama_memory_context_ptr init_full() = 0;
 
     // prepare for any pending memory updates, such as shifts, defrags, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
-    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
+    virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
     // getters
     virtual bool get_can_shift() const = 0;
index e2c82017f689018106e09e9ef499f9b18f8c5352..b15bf73c2a29ad4160a777f818c2f044f698770e 100644 (file)
@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_475M:          return "475M";
         case LLM_TYPE_770M:          return "770M";
         case LLM_TYPE_780M:          return "780M";
+        case LLM_TYPE_0_3B:          return "0.3B";
         case LLM_TYPE_0_5B:          return "0.5B";
         case LLM_TYPE_0_6B:          return "0.6B";
         case LLM_TYPE_1B:            return "1B";
@@ -103,6 +104,8 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_17B_128E:      return "17Bx128E (Maverick)";
         case LLM_TYPE_30B_A3B:       return "30B.A3B";
         case LLM_TYPE_235B_A22B:     return "235B.A22B";
+        case LLM_TYPE_E2B:           return "E2B";
+        case LLM_TYPE_E4B:           return "E4B";
         default:                     return "?B";
     }
 }
@@ -1017,6 +1020,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
+        case LLM_ARCH_GEMMA3N:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                hparams.set_swa_pattern(5);
+
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+                hparams.f_attention_scale         = 1.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 30: type = LLM_TYPE_E2B; break;
+                    case 35: type = LLM_TYPE_E4B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1484,6 +1505,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 18: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -2950,6 +2979,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_GEMMA3N:
+                {
+                    const int64_t n_altup      = hparams.n_altup;
+                    const int64_t laurel_rank  = hparams.laurel_rank;
+                    const int64_t n_embd_altup = hparams.n_embd_altup;
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    tok_embd           = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,           "weight"), {n_embd, n_vocab}, 0);
+                    tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
+
+                    altup_proj           = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ,           "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+                    altup_unembd_proj    = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ,    "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+                    per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
+                    per_layer_proj_norm  = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM,  "weight"), {n_embd_altup}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM,    "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm    = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM,    "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        // altup & laurel
+                        layer.per_layer_inp_gate   = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE,  "weight", i), {n_embd, n_embd_altup}, 0);
+                        layer.per_layer_proj       = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ,      "weight", i), {n_embd_altup, n_embd}, 0);
+                        layer.per_layer_post_norm  = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
+                        layer.altup_correct_coef   = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF,  "weight", i), {n_altup, n_altup}, 0);
+                        layer.altup_correct_scale  = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
+                        layer.altup_predict_coef   = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF,  "weight", i), {n_altup, n_altup * n_altup}, 0);
+                        layer.altup_router         = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER,        "weight", i), {n_embd, n_altup}, 0);
+                        layer.altup_router_norm    = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM,   "weight", i), {n_embd}, 0);
+                        layer.laurel_l             = create_tensor(tn(LLM_TENSOR_LAUREL_L,            "weight", i), {n_embd, laurel_rank}, 0);
+                        layer.laurel_r             = create_tensor(tn(LLM_TENSOR_LAUREL_R,            "weight", i), {laurel_rank, n_embd}, 0);
+                        layer.laurel_post_norm     = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM,    "weight", i), {n_embd}, 0);
+                    }
+                } break;
             case LLM_ARCH_STARCODER2:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4268,6 +4353,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
 
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_ERNIE4_5:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
@@ -8980,6 +9099,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
     }
 };
 
+struct llm_build_gemma3n_iswa : public llm_graph_context {
+    const llama_model & model;
+    ggml_cgraph * gf;
+
+    const int64_t n_embd_head;
+    const int64_t n_embd_altup;
+    const int64_t n_altup;
+    const int     i_altup_act;
+    const int     n_layer_kv = 20; // number of layers having KV [KV_REUSE]
+    const int     n_layer_sparsity = 10; // number of layers using activation sparsity
+    const float   f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
+
+    ggml_tensor * one; // containing single element 1.0f
+
+    llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+            : llm_graph_context(params),
+              model(model),
+              gf(gf),
+              n_embd_head(model.hparams.n_embd_head_k),
+              n_embd_altup(model.hparams.n_embd_altup),
+              n_altup(model.hparams.n_altup),
+              i_altup_act(model.hparams.i_altup_act) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // TODO: remove this when ggml_scale_add is implemented
+        one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        {
+            auto inp = std::make_unique<llm_graph_input_one>();
+            inp->one = one;
+            res->add_input(std::move(inp));
+        }
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // TODO: is causal == true correct? might need some changes
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+
+        // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
+        ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
+
+        // inpL now has only 1 altup, project it to the rest of the altups
+        // these "added" altups will be concat to the last dim of inpL
+        {
+            ggml_tensor * target_magnitude = calc_magnitude(inpL);
+            ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
+            ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
+            ggml_tensor * new_magnitude = calc_magnitude(altup_added);
+            altup_added = ggml_div(ctx0,
+                                ggml_mul(ctx0, altup_added, target_magnitude),
+                                new_magnitude);
+            inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
+            cb(inpL, "inp_stacked", -1);
+        }
+
+        // inpL now has shape:          [n_embd,       n_tokens, n_altup]
+        // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
+
+        for (int il = 0; il < n_layer; ++il) {
+            // this block is made to be closely resemble Gemma3p5DecoderLayer on python code
+            const bool has_kv = (il < n_layer_kv);
+
+            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+            ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
+            ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
+
+            // predicted value will go through self-attention and laurel
+            ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
+            cur = active_prediction;
+            cb(cur, "active_prediction", il);
+
+            // norm
+            cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // laurel
+            ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
+
+            // self-attention
+            if (has_kv) {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
+
+                cb(Qcur, "Qcur_normed", il);
+                cb(Kcur, "Kcur_normed", il);
+                cb(Vcur, "Vcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur_pos", il);
+                cb(Kcur, "Kcur_pos", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+            } else {
+                // no KV layers
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur_pos", il);
+
+                cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, NULL,
+                    Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
+            cb(cur, "attn_gated", il);
+
+            ggml_tensor * attn_laurel = ggml_scale(ctx0,
+                                            ggml_add(ctx0, cur, laurel_out),
+                                            1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
+            cb(attn_laurel, "attn_laurel", il);
+
+            cur = build_norm(attn_laurel,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up,   cur);
+                ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
+
+                if (il < n_layer_sparsity) {
+                    // apply activation sparsity
+                    gate_proj = gaussian_topk(gate_proj);
+                }
+                gate_proj = ggml_gelu(ctx0, gate_proj);
+
+                cur = ggml_mul(ctx0, up_proj, gate_proj);
+                cur = build_lora_mm(model.layers[il].ffn_down, cur);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", il);
+
+            ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
+            cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
+
+            ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
+
+            ggml_tensor * first_prediction; // [n_embd, n_tokens]
+            {
+                first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
+                first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
+                first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
+                first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
+                cb(first_prediction, "first_prediction_gated", il);
+                ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
+                first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
+                cb(first_prediction, "first_prediction_scaled", il);
+
+                first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
+                first_prediction = build_norm(first_prediction,
+                        model.layers[il].per_layer_post_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(first_prediction, "first_prediction_out", il);
+            }
+
+            // equivalent to python code: corrected_predictions[1:] += first_prediction
+            {
+                ggml_tensor * slice_first = view_2d_slice(corrected, 0);
+                ggml_tensor * slice_rest  = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
+                                                    ggml_row_size(corrected->type, n_embd),
+                                                    ggml_row_size(corrected->type, n_embd*n_tokens),
+                                                    n_embd*n_tokens*ggml_element_size(corrected));
+                ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
+                corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
+            }
+
+            cur = corrected; // [n_embd, n_tokens, n_altup]
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL; // [n_embd, n_tokens, n_altup]
+
+        // cur now has multiple altup(s), we want to merge them back to 1 altup
+        {
+            ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
+            // do a view to skip the first slice (active altup)
+            ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
+                                                    ggml_row_size(cur->type, n_embd),
+                                                    ggml_row_size(cur->type, n_embd*n_tokens),
+                                                    n_embd*n_tokens*ggml_element_size(cur));
+            ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
+            ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
+            altup_unembd = ggml_div(ctx0,
+                                ggml_mul(ctx0, altup_unembd, target_magnitude),
+                                new_magnitude);
+            cb(altup_unembd, "altup_unembd", -1);
+
+            // equivalent to torch.mean(hidden_states, dim=0)
+            cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
+            for (int i = 0; i < n_altup - 1; ++i) {
+                cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
+            }
+            cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
+            cb(cur, "unembd_merged", -1);
+        }
+
+        // cur now has shape: [n_embd, n_tokens]
+
+        // TODO: move this to right after the last KV layer
+        {
+            // skip computing output for unused tokens
+            ggml_tensor * inp_out_ids = build_inp_out_ids();
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+        }
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        {
+            // final logit soft-capping
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+            cur = ggml_tanh(ctx0, cur);
+            cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    ggml_tensor * calc_magnitude(ggml_tensor * x) {
+        return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
+    }
+
+    // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
+    ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
+        GGML_ASSERT(idx < (int)x->ne[2]);
+        return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
+                            ggml_row_size(x->type, x->ne[0]),
+                            idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
+    }
+
+    // equivalent to get_per_layer_inputs() in python code
+    // output shape: [n_embd_altup, n_layer, n_tokens]
+    ggml_tensor * get_per_layer_inputs() {
+        auto inp = std::make_unique<llm_graph_input_embd>();
+        ggml_tensor * inp_per_layer;
+        if (ubatch.token) {
+            inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+            ggml_set_input(inp->tokens);
+            res->t_tokens = inp->tokens;
+            inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
+            inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
+            inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
+            cb(inp_per_layer, "inp_per_layer_selected", -1);
+        } else {
+            GGML_ABORT("TODO: support embd input");
+        }
+        res->add_input(std::move(inp));
+        return inp_per_layer;
+    }
+
+    // equivalent to project_per_layer_inputs() in python code
+    // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
+    // output shape: [n_embd_altup, n_tokens, n_layer]
+    ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
+        const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
+        const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);
+
+        ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
+        per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
+        per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
+        per_layer_proj = build_norm(per_layer_proj,
+                                    model.per_layer_proj_norm, NULL,
+                                    LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
+        cb(per_layer_proj, "per_layer_proj", -1);
+
+        inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
+        inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
+        cb(inp_per_layer, "inp_per_layer", -1);
+
+        // permute to shape: [n_embd_altup, n_tokens, n_layer]
+        inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
+        return inp_per_layer;
+    }
+
+    // input cur shape: [n_altup, n_tokens]
+    // output    shape: [n_altup, n_tokens]
+    ggml_tensor * laurel(ggml_tensor * cur, int il) {
+        ggml_tensor * tmp = cur;
+        tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
+        tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
+        tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
+        tmp = ggml_add(ctx0, tmp, cur);
+        cb(tmp, "laurel_out", il);
+        return tmp;
+    }
+
+    // input x shape: [n_embd, n_tokens]
+    // output  shape: [n_embd, n_tokens]
+    ggml_tensor * gaussian_topk(ggml_tensor * x) {
+        ggml_tensor * mean = ggml_mean(ctx0, x);
+        ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0,
+            ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
+            1.0f / (float)(x->ne[0] - 1)
+        ));
+        ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
+        return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
+    }
+
+    //
+    // altup functions
+    //
+
+    // equivalent to compute_router_modalities() in python code
+    // input x shape: [n_embd,  n_tokens]
+    // output  shape: [n_altup, n_tokens]
+    ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
+        ggml_tensor * router_inputs = build_norm(x,
+            model.layers[il].altup_router_norm, NULL,
+            LLM_NORM_RMS, il);
+
+        // router_input_scale
+        router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
+
+        ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
+        return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
+    }
+
+    // input cur shape: [n_embd, n_tokens, n_altup]
+    // output    shape: [n_embd, n_tokens, n_altup]
+    ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
+        ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
+        ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
+        cb(modalities, "modalities", il);
+
+        ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
+        cb(all_coefs, "all_coefs", il);
+        // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor)
+        all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
+
+        // permute to [n_altup, n_embd, n_tokens]
+        ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
+        ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]
+
+        // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
+        predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
+        predictions = ggml_add(ctx0, predictions, cur);
+        cb(predictions, "predictions", il);
+
+        return predictions;
+    }
+
+    // input predictions       shape: [n_embd, n_tokens, n_altup]
+    // input activated         shape: [n_embd, n_tokens]
+    // output                  shape: [n_embd, n_tokens, n_altup]
+    ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
+        ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
+        cb(modalities, "modalities", il);
+
+        ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
+        ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
+        cb(innovation, "innovation", il);
+
+        ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
+        all_coefs = ggml_add(ctx0, all_coefs, one);
+        cb(all_coefs, "all_coefs", il);
+        all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
+        all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
+
+        innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
+        ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
+        corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
+        cb(corrected, "corrected", il);
+
+        return corrected;
+    }
+};
+
 // TODO: move up next to build_starcoder
 struct llm_build_starcoder2 : public llm_graph_context {
     llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
@@ -9171,9 +9726,9 @@ struct llm_build_mamba : public llm_graph_context {
                ggml_tensor * cur,
         const llama_ubatch & ubatch,
                        int   il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();
 
         const int64_t d_conv  = hparams.ssm_d_conv;
         const int64_t d_inner = hparams.ssm_d_inner;
@@ -9191,8 +9746,8 @@ struct llm_build_mamba : public llm_graph_context {
         GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
-        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
-        ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
+        ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+        ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
         // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_rs(
@@ -11916,7 +12471,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * x_prev,
             const llama_ubatch & ubatch,
             int   il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -11926,7 +12481,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         const auto n_head = n_embd / head_size;
         const auto n_head_kv = hparams.n_head_kv(il);
 
-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();
 
         const auto & layer = model.layers[il];
 
@@ -12038,7 +12593,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
         }
 
         ggml_tensor * wkv_state = build_rs(
-                inp, gf, kv_state->get_s_l(il),
+                inp, gf, mctx_cur->get_s_l(il),
                 hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output;
@@ -12057,9 +12612,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_s_l(il),
+                        mctx_cur->get_s_l(il),
                         hparams.n_embd_s() * n_seqs,
-                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
+                        hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
                         )
                     )
                 );
@@ -12313,7 +12868,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
-        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+        const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12322,7 +12877,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         const auto head_count = n_embd / head_size;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
 
-        const auto kv_head = kv_state->get_head();
+        const auto kv_head = mctx_cur->get_head();
 
         const auto & layer = model.layers[il];
 
@@ -12393,7 +12948,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
         ggml_tensor * wkv_state = build_rs(
-                inp, gf, kv_state->get_s_l(il),
+                inp, gf, mctx_cur->get_s_l(il),
                 hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12407,9 +12962,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_s_l(il),
+                        mctx_cur->get_s_l(il),
                         hparams.n_embd_s() * n_seqs,
-                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
+                        hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
                         )
                     )
                 );
@@ -13613,6 +14168,136 @@ struct llm_build_dots1 : public llm_graph_context {
     }
 };
 
+struct llm_build_ernie4_5 : public llm_graph_context {
+    llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            {
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_arcee : public llm_graph_context {
     llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -13974,6 +14659,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
             } break;
+        case LLM_ARCH_GEMMA3N:
+            {
+                llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
@@ -14119,6 +14808,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_arcee>(*this, params, gf);
             } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -14270,6 +14963,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_BAILINGMOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_ARCEE:
+        case LLM_ARCH_ERNIE4_5:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
@@ -14295,6 +14989,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
+        case LLM_ARCH_GEMMA3N:
         case LLM_ARCH_STARCODER2:
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
@@ -14377,7 +15072,7 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
         // do not extend this list unless absolutely necessary
         // Mistral-Small-2503 does not have built-in chat template
         llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
-        if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
+        if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
             return "mistral-v7-tekken";
         }
 
index 06e6c687943cc23e615bd1f49f773347d2b4247b..a958c5997a11b8e3447635ab635c31af50149fd7 100644 (file)
@@ -39,6 +39,7 @@ enum llm_type {
     LLM_TYPE_475M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
+    LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
     LLM_TYPE_1B,
@@ -95,6 +96,8 @@ enum llm_type {
     LLM_TYPE_17B_128E, // llama4 Maverick
     LLM_TYPE_30B_A3B,
     LLM_TYPE_235B_A22B,
+    LLM_TYPE_E2B,
+    LLM_TYPE_E4B,
 };
 
 std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
@@ -316,6 +319,19 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_scale   = nullptr;
     struct ggml_tensor * ffn_down_scale = nullptr;
 
+    // altup & laurel
+    struct ggml_tensor * per_layer_inp_gate   = nullptr;
+    struct ggml_tensor * per_layer_proj       = nullptr;
+    struct ggml_tensor * per_layer_post_norm  = nullptr;
+    struct ggml_tensor * altup_correct_coef   = nullptr;
+    struct ggml_tensor * altup_correct_scale  = nullptr;
+    struct ggml_tensor * altup_predict_coef   = nullptr;
+    struct ggml_tensor * altup_router         = nullptr;
+    struct ggml_tensor * altup_router_norm    = nullptr;
+    struct ggml_tensor * laurel_l             = nullptr;
+    struct ggml_tensor * laurel_r             = nullptr;
+    struct ggml_tensor * laurel_post_norm     = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
@@ -354,6 +370,13 @@ struct llama_model {
     struct ggml_tensor * conv1d   = nullptr;
     struct ggml_tensor * conv1d_b = nullptr;
 
+    // gemma3n altup
+    struct ggml_tensor * tok_embd_per_layer   = nullptr;
+    struct ggml_tensor * altup_proj           = nullptr;
+    struct ggml_tensor * altup_unembd_proj    = nullptr;
+    struct ggml_tensor * per_layer_model_proj = nullptr;
+    struct ggml_tensor * per_layer_proj_norm  = nullptr;
+
     std::vector<llama_layer> layers;
 
     llama_model_params params;
index 8cf45732fd6d4817cdafd6fd3de6c2bea421fec4..f4b5713d7dd9aefe1f26759bd15d33a6e9487fe8 100644 (file)
@@ -1,5 +1,4 @@
 #include "llama-quant.h"
-
 #include "llama-impl.h"
 #include "llama-model.h"
 #include "llama-model-loader.h"
@@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }
 
+static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
+    if (prune.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const int blk = std::stoi(match[1]);
+        std::string new_name = orig_name;
+
+        if (mapped.count(blk)) {
+            // Already mapped, do nothing
+        } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
+            mapped[blk] = "";
+        } else if (blk < prune.front()) {
+            mapped[blk] = std::to_string(blk);
+            next_id = blk + 1;
+        } else {
+            mapped[blk] = std::to_string(next_id);
+            ++next_id;
+        }
+
+        return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
+    }
+
+    return orig_name;
+}
+
+static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
+    if (mapped.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const std::string blk(match[1]);
+        std::string new_name = orig_name;
+
+        for (const auto & p : mapped) {
+            if (p.second == blk) {
+                LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
+                return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
+            }
+        }
+        GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
+    }
+
+    return orig_name;
+}
+
 struct quantize_state_impl {
     const llama_model                 & model;
     const llama_model_quantize_params * params;
@@ -174,7 +223,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
                 new_type = GGML_TYPE_Q6_K;
             }
         }
-    } else if (name == "token_embd.weight") {
+    } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
         if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
             new_type = qs.params->token_embedding_type;
         } else {
@@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     gguf_context_ptr ctx_out { gguf_init_empty() };
 
+    std::vector<int> prune_list = {};
+    if (params->prune_layers) {
+        prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
+    }
+
     // copy the KV pairs from the input file
     gguf_set_kv     (ctx_out.get(), ml.meta.get());
     gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         }
     }
 
+    std::map<int, std::string> mapped;
+    int blk_id = 0;
+    int pruned_attention_w = 0;
+
     // make a list of weights
     std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
     tensors.reserve(ml.weights_map.size());
     for (const auto & it : ml.weights_map) {
+        const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
+        if (remapped_name.empty()) {
+            if (it.first.find("attn_v.weight") != std::string::npos ||
+                it.first.find("attn_qkv.weight") != std::string::npos ||
+                it.first.find("attn_kv_b.weight") != std::string::npos) {
+                    pruned_attention_w++;
+            }
+            LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
+            continue;
+        } else if (remapped_name != it.first) {
+            ggml_set_name(it.second.tensor, remapped_name.c_str());
+            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
+        }
         tensors.push_back(&it.second);
     }
+    if (!prune_list.empty()) {
+        gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
+    }
 
     // keep_split requires that the weights are sorted by split index
     if (params->keep_split) {
@@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
     }
 
     size_t total_size_org = 0;
@@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         for (size_t i = 0; i < ctx_outs.size(); ++i) {
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
-            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
+            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
         }
     }
 
@@ -756,6 +830,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
 
+        // these are very small (e.g. 4x4)
+        quantize &= name.find("altup")  == std::string::npos;
+        quantize &= name.find("laurel") == std::string::npos;
+
+        // these are not too big so keep them as it is
+        quantize &= name.find("per_layer_model_proj") == std::string::npos;
+
         // do not quantize positional embeddings and token types (BERT)
         quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD,    "weight");
         quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
@@ -832,7 +913,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 
             const float * imatrix = nullptr;
             if (imatrix_data) {
-                auto it = imatrix_data->find(tensor->name);
+                auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
                 if (it == imatrix_data->end()) {
                     LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
                 } else {
@@ -947,6 +1028,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /*.imatrix                     =*/ nullptr,
         /*.kv_overrides                =*/ nullptr,
         /*.tensor_type                 =*/ nullptr,
+        /*.prune_layers                =*/ nullptr
     };
 
     return result;
index b04720bee59ef644922638d210632c7621d7edae..3eda9bc68608c0f0b55bca1a74512d4e28876d80 100644 (file)
@@ -390,6 +390,7 @@ extern "C" {
         void * imatrix;                       // pointer to importance matrix data
         void * kv_overrides;                  // pointer to vector containing overrides
         void * tensor_types;                  // pointer to vector containing tensor types
+        void * prune_layers;                  // pointer to vector containing layer indices to prune
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
@@ -943,12 +944,14 @@ extern "C" {
     // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the memory state is restored to the state before this call
+    // Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
+    //   To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
+    // Upon other return values, the memory state is restored to the state before this call
     //    0 - success
     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    //    2 - aborted
+    //    2 - aborted     (processed ubatches will remain in the context's memory)
     //   -1 - invalid input batch
-    // < -1 - error
+    // < -1 - fatal error (processed ubatches will remain in the context's memory)
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
               struct llama_batch   batch);