]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Wed, 18 Jun 2025 07:22:47 +0000 (10:22 +0300)
committerGeorgi Gerganov <redacted>
Wed, 18 Jun 2025 09:40:34 +0000 (12:40 +0300)
ggml-ci

26 files changed:
examples/talk-llama/llama-arch.cpp
examples/talk-llama/llama-arch.h
examples/talk-llama/llama-batch.cpp
examples/talk-llama/llama-batch.h
examples/talk-llama/llama-chat.cpp
examples/talk-llama/llama-chat.h
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-context.h
examples/talk-llama/llama-cparams.cpp
examples/talk-llama/llama-cparams.h
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-kv-cache-recurrent.cpp
examples/talk-llama/llama-kv-cache-recurrent.h
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
examples/talk-llama/llama-kv-cache-unified-iswa.h
examples/talk-llama/llama-kv-cache-unified.cpp
examples/talk-llama/llama-kv-cache-unified.h
examples/talk-llama/llama-kv-cells.h
examples/talk-llama/llama-memory.h
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-quant.cpp
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h

index 43fa60a8070b765f09dff5ef73124bb342cb0d52..de8d289cf967e989cbd3fcca85639b1b11c4ba32 100644 (file)
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BERT,             "bert"             },
     { LLM_ARCH_NOMIC_BERT,       "nomic-bert"       },
     { LLM_ARCH_NOMIC_BERT_MOE,   "nomic-bert-moe"   },
+    { LLM_ARCH_NEO_BERT,         "neo-bert"         },
     { LLM_ARCH_JINA_BERT_V2,     "jina-bert-v2"     },
     { LLM_ARCH_BLOOM,            "bloom"            },
     { LLM_ARCH_STABLELM,         "stablelm"         },
@@ -72,6 +73,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
     { LLM_ARCH_PLM,              "plm"              },
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
+    { LLM_ARCH_DOTS1,            "dots1"            },
+    { LLM_ARCH_ARCEE,            "arcee"            },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -243,6 +246,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_ARCEE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_LLAMA4,
         {
@@ -494,6 +515,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_NEO_BERT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
+            { LLM_TENSOR_CLS,             "cls" },
+            { LLM_TENSOR_CLS_OUT,         "cls.output" },
+        },
+    },
     {
         LLM_ARCH_JINA_BERT_V2,
         {
@@ -1555,6 +1591,34 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_DOTS1,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
+        }
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
index f3825528aefdb93cf9f571fb479de36fe5aadee4..3e8a61da3c13e38fc1711e003d2f2b6ba3f59393 100644 (file)
@@ -24,6 +24,7 @@ enum llm_arch {
     LLM_ARCH_BERT,
     LLM_ARCH_NOMIC_BERT,
     LLM_ARCH_NOMIC_BERT_MOE,
+    LLM_ARCH_NEO_BERT,
     LLM_ARCH_JINA_BERT_V2,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
@@ -76,6 +77,8 @@ enum llm_arch {
     LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
+    LLM_ARCH_DOTS1,
+    LLM_ARCH_ARCEE,
     LLM_ARCH_UNKNOWN,
 };
 
index 6a19a243118d344bfd9f33a881a356dc74929138..8b6d14fe8813c3d0874a4ad2aaca71c9e6f0aa4b 100644 (file)
@@ -1,8 +1,14 @@
 #include "llama-batch.h"
 
+#include "llama-impl.h"
+#include "llama-cparams.h"
+#include "llama-vocab.h"
+#include "llama-memory.h"
+
 #include <cassert>
 #include <cstring>
 #include <algorithm>
+#include <sstream>
 
 llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
     // clear empty sequences
@@ -105,12 +111,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
             ubatch.seq_id = batch->seq_id + seq.offset;
         }
     }
-    if (logits_all) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.output[ubatch.n_tokens + i] = 1;
-            out_ids.push_back(ids[seq.offset + i]);
-        }
-    } else if (batch->logits) {
+    if (batch->logits) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 size_t id = ids[seq.offset + i];
@@ -197,11 +198,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
-    this->logits_all = logits_all;
 
     n_tokens = batch.n_tokens;
     ids.resize(n_tokens);
@@ -285,17 +285,56 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
             );
 }
 
-llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
-    batch = in_batch;
+llama_batch_allocr::llama_batch_allocr() {
+    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
+    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_SEQ);
+    seq_cpl.resize(LLAMA_MAX_SEQ);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_SEQ);
+    }
+}
+
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory,
+        bool embd_all) {
+    clear();
+
+    batch = batch_inp;
+
     GGML_ASSERT(batch.n_tokens > 0);
-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
+
+    //
+    // validate input batch
+    //
+
+    if (batch.token) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return false;
+            }
+        }
+    }
+
+    if (batch.seq_id) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
+                    return false;
+                }
+            }
         }
-        batch.pos = pos.data();
     }
+
+    //
+    // auto-generate missing fields
+    //
+
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
@@ -303,6 +342,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.n_seq_id = n_seq_id.data();
     }
+
     if (!batch.seq_id) {
         seq_id.resize(batch.n_tokens + 1);
         seq_id[batch.n_tokens] = NULL;
@@ -311,10 +351,221 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.seq_id = seq_id.data();
     }
+
+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_SEQ];
+        for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
+        }
+
+        batch.pos = pos.data();
+    }
+
     if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+        if (embd_all) {
+            // return the output for all tokens
+            output.resize(batch.n_tokens, true);
+        } else {
+            // return the output only for the last token
+            output.resize(batch.n_tokens, false);
+            output[output.size() - 1] = true;
+        }
+
+        batch.logits = output.data();
+    } else if (embd_all) {
+        bool warn = false;
+
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.logits[i] == 0) {
+                warn = true;
+            }
+        }
+
+        if (warn) {
+            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
+
+            output.resize(batch.n_tokens, true);
+            batch.logits = output.data();
+        }
+    }
+
+    //
+    // compute stats
+    //
+
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        n_outputs += batch.logits[i] != 0;
+    }
+
+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+
+            if (s > 0) {
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                //seq_cpl[s0][s1] = true;
+            }
+        }
+    }
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s:   n_tokens  = %d\n", __func__,          batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s:   token     = %p\n", __func__, (void *) batch.token);
+        LLAMA_LOG_DEBUG("%s:   embd      = %p\n", __func__, (void *) batch.embd);
+        LLAMA_LOG_DEBUG("%s:   pos       = %p\n", __func__, (void *) batch.pos);
+        LLAMA_LOG_DEBUG("%s:   n_seq_id  = %p\n", __func__, (void *) batch.n_seq_id);
+        LLAMA_LOG_DEBUG("%s:   seq_id    = %p\n", __func__, (void *) batch.seq_id);
+        LLAMA_LOG_DEBUG("%s:   logits    = %p\n", __func__, (void *) batch.logits);
+        LLAMA_LOG_DEBUG("%s:   n_outputs = %d\n", __func__, n_outputs);
+
+        if (debug > 1) {
+            int seq_id_max = 0;
+            for (int32_t i = 0; i < batch.n_tokens; ++i) {
+                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                        seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
+                    }
+                }
+            }
+            ++seq_id_max;
+
+            LLAMA_LOG_DEBUG("%s:   token     = [\n", __func__);
+            for (int32_t i = 0; i < batch.n_tokens; ++i) {
+                std::vector<int8_t> seq_id(seq_id_max);
+
+                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                    seq_id[batch.seq_id[i][s]] = 1;
+                }
+
+                std::stringstream ss;
+                for (int s = 0; s < seq_id_max; ++s) {
+                    if (seq_id[s]) {
+                        ss << s%10;
+                    } else {
+                        ss << ".";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s:  %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                        __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
+                        batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
+            }
+            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s:   seq       = [\n", __func__);
+            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+                if (seq_pos[s0].empty()) {
+                    continue;
+                }
+
+                std::stringstream ss;
+                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                    if (seq_cpl[s0][s1]) {
+                        ss << s1 << " ";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n",
+                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+            }
+            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
+        }
+    }
+
+    //
+    // consistency checks
+    //
+
+    for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+const llama_batch & llama_batch_allocr::get_batch() const {
+    return batch;
+}
+
+uint32_t llama_batch_allocr::get_n_outputs() const {
+    return n_outputs;
+}
+
+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
+void llama_batch_allocr::clear() {
+    n_outputs = 0;
+
+    batch = {};
+    pos.clear();
+    n_seq_id.clear();
+    seq_id.clear();
+    output.clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
     }
 }
 
index b8260b94fd2d0aaf301347fdf70af299251556ef..a555c157234be82933b0bf7b35b43c4fed2593e4 100644 (file)
@@ -4,6 +4,7 @@
 
 #include <array>
 #include <vector>
+#include <set>
 
 // very similar to llama_batch,
 // but has more metadata about sequences
@@ -18,8 +19,8 @@ struct llama_ubatch {
     llama_token  *  token;    // [n_tokens]
     float        *  embd;     // [n_embd, n_tokens]
     llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs] // TODO: remove, should belong to only 1 sequence
-    llama_seq_id ** seq_id;   // [n_seqs] // TODO: become llama_seq_id * seq_id;
+    int32_t      *  n_seq_id; // [n_seqs]
+    llama_seq_id ** seq_id;   // [n_seqs]
     int8_t       *  output;   // [n_tokens]
 };
 
@@ -39,8 +40,6 @@ struct llama_sbatch {
 
     size_t n_embd;
 
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
     // sorted indices into the batch
     std::vector<int64_t> ids;
     // batch indices of the output
@@ -76,19 +75,45 @@ struct llama_sbatch {
     llama_ubatch split_seq(size_t n_ubatch);
 
     llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-// temporary allocate memory for the input batch if needed
-struct llama_batch_allocr {
-    struct llama_batch batch;
+// a helper for sanitizing and fulfilling a batch
+class llama_batch_allocr {
+public:
+    llama_batch_allocr();
+
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
+    bool init(
+            const llama_batch & batch_inp,
+            const llama_vocab & vocab,
+            const llama_memory_i * memory,
+            bool embd_all);
+
+    const llama_batch & get_batch() const;
+
+    uint32_t get_n_outputs() const;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+
+private:
+    void clear();
+
+    llama_batch batch;
+
+    uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
+
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         output;
+
+    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
-    // optionally fulfill the batch returned by llama_batch_get_one
-    llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
+    int debug;
 };
index d12743e6b9a0cf1bc12ebd698b71cef0a92db62b..bc4fa05a74ef470796099125835dcc32b0ad52d9 100644 (file)
@@ -183,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_BAILING;
     } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
+    } else if (tmpl_contains("<|endofuserprompt|>")) {
+        return LLM_CHAT_TEMPLATE_DOTS1;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -643,6 +645,21 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "Assistant:";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
+        // dots.llm1.inst (DOTS1)
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|system|>" << message->content << "<|endofsystem|>";
+            } else if (role == "user") {
+                ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
+            } else {
+                ss << "<|response|>" << message->content << "<|endofresponse|>";
+            }
+        }
+        if (add_ass) {
+            ss << "<|response|>";
+        }
     } else {
         // template not supported
         return -1;
index db24ade21e2ad76671732d3c6b625ea76357120e..38800010ae48b5da3474fb47ee4c490d693db68f 100644 (file)
@@ -43,6 +43,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_BAILING,
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
+    LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
index b130b484bcf6fd4c2c38a1b64de13962efe55be6..f56a58e9b6ec6208c92e0e38c4a879c5af44937c 100644 (file)
@@ -1,6 +1,7 @@
 #include "llama-context.h"
 
 #include "llama-impl.h"
+#include "llama-batch.h"
 #include "llama-io.h"
 #include "llama-memory.h"
 #include "llama-mmap.h"
@@ -18,7 +19,8 @@
 llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
-    model(model) {
+    model(model),
+    batch_allocr(std::make_unique<llama_batch_allocr>()) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
@@ -27,8 +29,8 @@ llama_context::llama_context(
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
-    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
-        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    if (cparams.n_seq_max > LLAMA_MAX_SEQ) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
     }
 
     cparams.n_threads        = params.n_threads;
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
-    int32_t j = -1;
+    int64_t j = -1;
 
     try {
         if (logits == nullptr) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
         }
         if (j >= n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
         return logits + j*model.vocab.n_tokens();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
-    int32_t j = -1;
+    int64_t j = -1;
 
     try {
         if (embd == nullptr) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
         }
         if (j >= n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
         return embd + j*model.hparams.n_embd;
@@ -719,52 +721,41 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
     return res;
 }
 
-int llama_context::encode(llama_batch & inp_batch) {
-    if (inp_batch.n_tokens == 0) {
+int llama_context::encode(const llama_batch & batch_inp) {
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
-    // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
 
-    const llama_batch & batch = batch_allocr.batch;
-    const int32_t n_tokens = batch.n_tokens;
+    const llama_batch & batch = batch_allocr->get_batch();
 
-    const auto & hparams = model.hparams;
+    const uint32_t n_tokens = batch.n_tokens;
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (int32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                throw -1;
-            }
-        }
-    }
-
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
+    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
     if (t_compute_start_us == 0) {
         t_compute_start_us = ggml_time_us();
     }
 
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
     n_queued_tokens += n_tokens;
 
+    const auto & hparams = model.hparams;
+
     const int64_t n_embd = hparams.n_embd;
 
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -774,7 +765,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         return -2;
     };
 
-    for (int32_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         output_ids[i] = i;
     }
 
@@ -830,7 +821,8 @@ int llama_context::encode(llama_batch & inp_batch) {
 
                     GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
-                    for (int32_t i = 0; i < n_tokens; i++) {
+                    // TODO: fix indexing [UBATCH_IDX]
+                    for (uint32_t i = 0; i < n_tokens; i++) {
                         const llama_seq_id seq_id = ubatch.seq_id[i][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
@@ -845,6 +837,7 @@ int llama_context::encode(llama_batch & inp_batch) {
                     auto & embd_seq_out = embd_seq;
                     const uint32_t n_cls_out = hparams.n_cls_out;
 
+                    // TODO: fix indexing [UBATCH_IDX]
                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
@@ -878,10 +871,10 @@ int llama_context::encode(llama_batch & inp_batch) {
 
         // remember the sequence ids used during the encoding - needed for cross attention later
         cross.seq_ids_enc.resize(n_tokens);
-        for (int32_t i = 0; i < n_tokens; i++) {
+        for (uint32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();
-            for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
-                llama_seq_id seq_id = ubatch.seq_id[i][s];
+            for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                llama_seq_id seq_id = batch.seq_id[i][s];
                 cross.seq_ids_enc[i].insert(seq_id);
             }
         }
@@ -890,51 +883,45 @@ int llama_context::encode(llama_batch & inp_batch) {
     return 0;
 }
 
-int llama_context::decode(llama_batch & inp_batch) {
+int llama_context::decode(const llama_batch & batch_inp) {
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
-        return encode(inp_batch);
+        return encode(batch_inp);
     }
 
-    if (inp_batch.n_tokens == 0) {
+    if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
-    if (!inp_batch.pos) {
-        if (inp_batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return -1;
-        }
-    }
+    // when computing embeddings, all tokens are output
+    const bool embd_all = cparams.embeddings;
 
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }
 
-    const llama_batch & batch = batch_allocr.batch;
+    const llama_batch & batch = batch_allocr->get_batch();
 
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd  = hparams.n_embd;
 
-    const int64_t n_tokens_all = batch.n_tokens;
-    const int64_t n_embd       = hparams.n_embd;
+    const uint32_t n_tokens_all = batch.n_tokens;
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    // TODO: move the validation to the llama_batch_allocr
-    if (batch.token) {
-        for (int64_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
+    const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
 
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                return -1;
-            }
+    if (embd_all) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
         }
     }
 
@@ -947,25 +934,9 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs_all += batch.logits[i] != 0;
-        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
-
     bool did_optimize = false;
 
     // handle any pending defrags/shifts
@@ -974,7 +945,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
         if (!mstate) {
             return -2;
         }
@@ -1018,7 +989,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
         return -2;
     };
 
@@ -1027,7 +998,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     do {
         const auto & ubatch = mstate->get_ubatch();
 
-        // count the outputs in this u_batch
+        // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;
 
@@ -1052,18 +1023,19 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            llama_pos pos_min[LLAMA_MAX_SEQ];
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
             }
 
+            // TODO: fix sequence indexing
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];
 
                 pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]);
             }
 
-            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
                 if (pos_min[s] == std::numeric_limits<llama_pos>::max()) {
                     continue;
                 }
@@ -1086,7 +1058,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
+        auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1170,14 +1142,14 @@ int llama_context::decode(llama_batch & inp_batch) {
     n_outputs = n_outputs_all;
 
     // set output mappings
-    {
+    if (n_outputs > 0) {
         bool sorted_output = true;
 
         auto & out_ids = mstate->out_ids();
 
-        GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
 
-        for (int64_t i = 0; i < n_outputs_all; ++i) {
+        for (int64_t i = 0; i < n_outputs; ++i) {
             int64_t out_id = out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
@@ -1189,20 +1161,22 @@ int llama_context::decode(llama_batch & inp_batch) {
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
             const uint32_t n_vocab = model.vocab.n_tokens();
-            const uint32_t n_embd  = model.hparams.n_embd;
+            const uint64_t n_embd  = model.hparams.n_embd;
 
             GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
             // TODO: is there something more efficient which also minimizes swaps?
             // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-            for (int32_t i = 0; i < n_outputs - 1; ++i) {
-                int32_t j_min = i;
-                for (int32_t j = i + 1; j < n_outputs; ++j) {
+            for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+                uint32_t j_min = i;
+                for (uint32_t j = i + 1; j < n_outputs; ++j) {
                     if (out_ids[j] < out_ids[j_min]) {
                         j_min = j;
                     }
                 }
-                if (j_min == i) { continue; }
+                if (j_min == i) {
+                    continue;
+                }
                 std::swap(out_ids[i], out_ids[j_min]);
                 if (logits_size > 0) {
                     for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1215,8 +1189,10 @@ int llama_context::decode(llama_batch & inp_batch) {
                     }
                 }
             }
+
             std::fill(output_ids.begin(), output_ids.end(), -1);
-            for (int32_t i = 0; i < n_outputs; ++i) {
+
+            for (uint32_t i = 0; i < n_outputs; ++i) {
                 output_ids[out_ids[i]] = i;
             }
         }
@@ -1236,7 +1212,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 // output
 //
 
-int32_t llama_context::output_reserve(int32_t n_outputs) {
+uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
@@ -1246,9 +1222,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_vocab = vocab.n_tokens();
     const auto n_embd  = hparams.n_embd;
 
-    // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
-    bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    bool has_logits = true;
+    bool has_embd   = cparams.embeddings;
 
     // TODO: hacky enc-dec support
     if (model.arch == LLM_ARCH_T5) {
@@ -1302,8 +1277,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
 
-    this->n_outputs     = 0;
-    this->n_outputs_max = n_outputs_max;
+    this->n_outputs = 0;
 
     return n_outputs_max;
 }
@@ -1332,7 +1306,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
-        n_tokens = (n_tokens / n_seqs) * n_seqs;
+        n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
         n_outputs = std::min(n_outputs, n_tokens);
 
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -1794,14 +1768,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
 
         std::vector<int32_t> w_output_pos;
 
-        GGML_ASSERT(n_outputs <= n_outputs_max);
-
         w_output_pos.resize(n_outputs);
 
         // build a more compact representation of the output ids
         for (size_t i = 0; i < n_batch(); ++i) {
             // map an output id to a position in the batch
-            int32_t pos = output_ids[i];
+            int64_t pos = output_ids[i];
             if (pos >= 0) {
                 GGML_ASSERT(pos < n_outputs);
                 w_output_pos[pos] = i;
@@ -2071,14 +2043,11 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
         embd_seq.clear();
 
-        int64_t n_outputs_all = n_tokens_all;
+        uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
@@ -2086,7 +2055,7 @@ void llama_context::opt_epoch_iter(
 
         // reserve output buffer
         if (output_reserve(n_outputs_all) < n_outputs_all) {
-            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
             GGML_ABORT("TODO: handle this error");
         };
 
index 2e0da8c83bd59b5eca04b3c6d6fdfdf7f1a8fc30..040f03ae42e65b242f05f7e7f49a6cea577eb6ba 100644 (file)
@@ -1,7 +1,6 @@
 #pragma once
 
 #include "llama.h"
-#include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
@@ -13,6 +12,7 @@
 #include <vector>
 
 struct llama_model;
+class llama_batch_allocr;
 
 class llama_io_read_i;
 class llama_io_write_i;
@@ -102,8 +102,8 @@ struct llama_context {
             llama_memory_state_i * mstate,
                      ggml_status & ret);
 
-    int encode(llama_batch & inp_batch);
-    int decode(llama_batch & inp_batch);
+    int encode(const llama_batch & batch_inp);
+    int decode(const llama_batch & batch_inp);
 
     //
     // state save/load
@@ -181,7 +181,7 @@ private:
 
     // Make sure enough space is available for outputs.
     // Returns max number of outputs for which space was reserved.
-    int32_t output_reserve(int32_t n_outputs);
+    uint32_t output_reserve(int32_t n_outputs);
 
     //
     // graph
@@ -246,8 +246,10 @@ private:
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
-    int32_t n_outputs     = 0; // number of actually-used outputs in the current ubatch or last logical batch
-    int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers
+    // reuse the batch_allocr to avoid unnecessary memory allocations
+    std::unique_ptr<llama_batch_allocr> batch_allocr;
+
+    uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
index f7b36590fe3e3f0a9dc9dfee9e7f749d2610aee7..a3e7a37ee36d78fd6af7d5e3aba8762ecbb636b6 100644 (file)
@@ -1,5 +1,5 @@
 #include "llama-cparams.h"
 
 size_t llama_max_parallel_sequences(void) {
-    return LLAMA_MAX_PARALLEL_SEQUENCES;
+    return LLAMA_MAX_SEQ;
 }
index 2871031ef09619bbce252126c1ddd13fb4681dcd..118615d5bd2d59f4e640c661592cf91ee0d3fda3 100644 (file)
@@ -4,7 +4,7 @@
 
 #include <cstdint>
 
-#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+#define LLAMA_MAX_SEQ 64
 
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
index 27c9ab74be1125e6b7811c75a0a5ae92bf6be3a0..337fb5cb0df3634d00284c64429dce9141c641a9 100644 (file)
@@ -139,6 +139,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
 
         std::vector<uint64_t> sum(n_tokens, 0);
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
@@ -156,6 +157,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
             }
         }
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
@@ -180,6 +182,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_tokens * ggml_element_size(cls));
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
@@ -210,6 +213,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         std::vector<int> last_pos(n_tokens, -1);
         std::vector<int> last_row(n_tokens, -1);
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
@@ -250,22 +254,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    const int64_t n_kv = kv_state->get_n_kv();
-
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
-
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_mask(i);
-        }
-    }
-}
-
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -299,6 +287,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                                 const int32_t ti = s0*n_seq_tokens + i;
                                 float f = -INFINITY;
 
+                                // TODO: fix indexing [UBATCH_IDX]
                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                     if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
                                         if (hparams.use_alibi) {
@@ -338,6 +327,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                                 const int32_t ti = s0*n_seq_tokens + i;
                                 float f = -INFINITY;
 
+                                // TODO: fix indexing [UBATCH_IDX]
                                 for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                                     if (ubatch->seq_id[s0][s] == seq_id) {
                                         if (hparams.use_alibi) {
@@ -393,6 +383,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
             for (int j = 0; j < n_tokens; ++j) {
                 for (int i = 0; i < n_enc; ++i) {
                     float f = -INFINITY;
+                    // TODO: fix indexing [UBATCH_IDX]
                     for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
                         const llama_seq_id seq_id = ubatch->seq_id[j][s];
                         if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
@@ -650,6 +641,7 @@ ggml_tensor * llm_graph_context::build_ffn(
             {
                 // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
                 int64_t split_point = cur->ne[0] / 2;
+                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
                 ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
                 ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
 
@@ -663,7 +655,7 @@ ggml_tensor * llm_graph_context::build_ffn(
             {
                 // Split into two equal parts
                 int64_t split_point = cur->ne[0] / 2;
-                // TODO: these conts should not be needed
+                // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
                 ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
                 ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
 
@@ -986,23 +978,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1455,43 +1430,53 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
          ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
-             int32_t   n_state,
-             int32_t   n_seqs) const {
+             int32_t   state_size,
+             int32_t   n_seqs,
+                bool   avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx0, states, state_copy);
+    ggml_tensor * output_states;
 
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx0, states, state_mask);
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
 
-    // copy states which won't be changed further (between n_seqs and n_kv)
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx0, s,      n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
 
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return output_states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_cgraph * gf,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1502,8 +1487,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_copy_mask_state(
-            gf, token_shift_all, state_copy, state_mask,
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1578,23 +1563,30 @@ void llm_graph_context::build_pooling(
                 ggml_tensor * inp_cls = build_inp_cls();
                 inp = ggml_get_rows(ctx0, inp, inp_cls);
 
-                if (cls != nullptr && cls_b != nullptr) {
+                if (cls) {
                     // classification head
                     // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
+                    cur = ggml_mul_mat(ctx0, cls, inp);
+                    if (cls_b) {
+                        cur = ggml_add(ctx0, cur, cls_b);
+                    }
                     cur = ggml_tanh(ctx0, cur);
 
                     // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
                     // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
                     if (cls_out) {
-                        GGML_ASSERT(cls_out_b != nullptr);
-                        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                        cur = ggml_mul_mat(ctx0, cls_out, cur);
+                        if (cls_out_b) {
+                            cur = ggml_add(ctx0, cur, cls_out_b);
+                        }
                     }
                 } else if (cls_out) {
                     // Single layer classification head (direct projection)
                     // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
-                    GGML_ASSERT(cls_out_b != nullptr);
-                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
+                    cur = ggml_mul_mat(ctx0, cls_out, inp);
+                    if (cls_out_b) {
+                        cur = ggml_add(ctx0, cur, cls_out_b);
+                    }
                 } else {
                     GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
                 }
index 28da6a5228bdcbb928d2d16c8da1e2a6bce7b027..87813119b1a3cf5656bec8b837df18628e7902a5 100644 (file)
@@ -200,18 +200,6 @@ public:
     const llama_kv_cache_recurrent_state * kv_state;
 };
 
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_mask() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_mask; // F32 [1, n_kv]
-
-    const llama_kv_cache_recurrent_state * kv_state;
-};
-
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -390,7 +378,7 @@ struct llm_graph_params {
     const llama_memory_state_i * mstate;
     const llama_cross          * cross;
 
-    int32_t n_outputs;
+    uint32_t n_outputs;
 
     const llm_graph_cb & cb;
 };
@@ -424,8 +412,8 @@ struct llm_graph_context {
     const float norm_eps;
     const float norm_rms_eps;
 
-    const int32_t n_tokens;
-    const int32_t n_outputs;
+    const int64_t n_tokens;
+    const int64_t n_outputs;
     const int32_t n_ctx_orig; // yarn
 
     const enum llama_pooling_type pooling_type;
@@ -521,7 +509,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
-    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -606,18 +593,17 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_copy_mask_state(
+    ggml_tensor * build_recurrent_state(
              ggml_cgraph * gf,
              ggml_tensor * s,
              ggml_tensor * state_copy,
-             ggml_tensor * state_mask,
-                 int32_t   n_state,
-                 int32_t   n_seqs) const;
+                 int32_t   state_size,
+                 int32_t   n_seqs,
+                    bool   avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
              ggml_cgraph * gf,
              ggml_tensor * state_copy,
-             ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const;
 
index f5c6dcd66ce9e8a519d7158a28b9cd5fcd29568b..8f6f120f682b769730113d3c4bf835284209a239 100644 (file)
@@ -359,18 +359,16 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
-    GGML_UNUSED(embd_pooled);
-
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
 
     while (sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
 
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
+        if (embd_all) {
+            // if all tokens are output, split by sequence
             ubatch = sbatch.split_seq(n_ubatch);
         } else {
             ubatch = sbatch.split_equal(n_ubatch);
@@ -406,21 +404,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 
     bool success = true;
 
-    // TODO: here we have to verify that all ubatches can fit in the cells
-    //       however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
-    //         during the compute of each ubatch. to reproduce, uncomment the following loop and run:
-    //
-    //           $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
-    //
-    //       recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
-    //
-    GGML_UNUSED(ubatches);
-    //for (const auto & ubatch : ubatches) {
-    //    if (!find_slot(ubatch)) {
-    //        success = false;
-    //        break;
-    //    }
-    //}
+    for (const auto & ubatch : ubatches) {
+        if (!find_slot(ubatch)) {
+            success = false;
+            break;
+        }
+    }
 
     // restore the original state
     cells = std::move(org_cells);
@@ -431,14 +420,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 }
 
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seqs = ubatch.n_seqs;
 
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*n_tokens) {
+    if (head > used + 2*n_seqs) {
         head = 0;
     }
 
@@ -534,16 +522,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                 empty_cell.src = orig_cell.src;
                 orig_cell.seq_id.erase(seq_id);
                 empty_cell.seq_id.insert(seq_id); // will be overwritten
+                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
             }
             seq_meta.tail = next_empty_cell;
             // find next empty cell
             if (s + 1 < n_seqs) {
-                next_empty_cell += 1;
                 for (uint32_t i = 0; i < size; ++i) {
+                    next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
                     kv_cell & cell = cells[next_empty_cell];
                     if (cell.is_empty()) { break; }
-                    next_empty_cell += 1;
                 }
             }
         }
@@ -553,8 +541,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        int32_t dst_id = s + min;
-        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        const int32_t dst_id = s + min;
+        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
             kv_cell & dst_cell = cells[dst_id];
             kv_cell & src_cell = cells[src_id];
@@ -563,12 +551,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.src, src_cell.src);
             std::swap(dst_cell.seq_id, src_cell.seq_id);
 
-            // swap tails (assuming they NEVER overlap)
-            for (const llama_seq_id seq_id : src_cell.seq_id) {
-                cells[seq_id].tail = src_id;
-            }
-            for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                cells[seq_id].tail = dst_id;
+            // swap tails
+            for (uint32_t i = 0; i < size; ++i) {
+                int32_t & tail = cells[i].tail;
+                if (tail == src_id) {
+                    tail = dst_id;
+                } else if (tail == dst_id) {
+                    tail = src_id;
+                }
             }
         }
     }
@@ -576,7 +566,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-        int32_t cell_id = s + min;
+        const int32_t cell_id = s + min;
         kv_cell & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -594,6 +584,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         }
     }
 
+    // Find first cell without src refs, to use as the zero-ed state
+    {
+        // TODO: bake-in src refcounts in the cell metadata
+        std::vector<int32_t> refcounts(size, 0);
+        for (size_t i = 0; i < size; ++i) {
+            const int32_t src = cells[i].src;
+            if (src >= 0) {
+                refcounts[src] += 1;
+            }
+        }
+
+        rs_z = -1;
+        for (int i = min; i <= max; ++i) {
+            if (refcounts[i] == 0) {
+                rs_z = i;
+                break;
+            }
+        }
+
+        for (int i = min; i <= max; ++i) {
+            if (cells[i].src < 0) {
+                GGML_ASSERT(rs_z >= 0);
+                cells[i].src0 = rs_z;
+            } else {
+                // Stage the source ids for all used cells to allow correct seq_* behavior
+                // and still make these values available when setting the inputs
+                cells[i].src0 = cells[i].src;
+            }
+            cells[i].src = i; // avoid moving or clearing twice
+        }
+    }
+
     // allow getting the range of used cells, from head to head + n
     head = min;
     n    = max - min + 1;
@@ -605,47 +627,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 }
 
 bool llama_kv_cache_recurrent::get_can_shift() const {
-    return false;
-}
-
-int32_t llama_kv_cache_recurrent::s_copy(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    // prevent out-of-bound sources
-    if (cell.src < 0 || (uint32_t) cell.src >= size) {
-        cell.src = cell_id;
-    }
-
-    int32_t res = cell.src;
-
-    // TODO: do not mutate the KV cache
-    // ensure copy only happens once
-    if (cell.src != (int32_t) cell_id) {
-        cell.src = cell_id;
-    }
-
-    return res;
-}
-
-float llama_kv_cache_recurrent::s_mask(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    float res = (float) (cell.src >= 0);
-
-    // only clear once
-    if (cell.src < 0) {
-        cell.src = cell_id;
-    }
-
-    return res;
+    // shifting the pos is trivial for recurrent models
+    return true;
 }
 
 size_t llama_kv_cache_recurrent::total_size() const {
@@ -1111,6 +1094,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
     return is_full ? 0 : kv->head;
 }
 
+int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : kv->rs_z;
+}
+
 uint32_t llama_kv_cache_recurrent_state::get_size() const {
     return kv->size;
 }
@@ -1124,9 +1111,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
 }
 
 int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
-    return kv->s_copy(i);
-}
-
-float llama_kv_cache_recurrent_state::s_mask(int i) const {
-    return kv->s_mask(i);
+    return  kv->cells[i + kv->head].src0;
 }
index d1da1225655fa4b18d0ae055df131ebff2c9d4cc..f9b01a6513393fa4fea94eeeda9f51cbc89e5377 100644 (file)
@@ -32,8 +32,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
@@ -57,10 +56,6 @@ public:
 
     bool get_can_shift() const override;
 
-    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
-    int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
-
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -73,10 +68,14 @@ public:
     // computed before each graph build
     uint32_t n = 0;
 
+    // first zero-ed state
+    int32_t rs_z = -1;
+
     // TODO: optimize for recurrent state needs
     struct kv_cell {
         llama_pos pos  = -1;
-        int32_t   src  = -1; // used to copy states
+        int32_t   src  = -1; // used to know where states should be copied from
+        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
         int32_t   tail = -1;
 
         std::set<llama_seq_id> seq_id;
@@ -157,13 +156,13 @@ public:
 
     uint32_t get_n_kv() const;
     uint32_t get_head() const;
+    int32_t  get_rs_z() const;
     uint32_t get_size() const;
 
     ggml_tensor * get_k_l(int32_t il) const;
     ggml_tensor * get_v_l(int32_t il) const;
 
     int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
 
 private:
     const llama_memory_status status;
index 28d18265476497e42b93087b8c336a360d6e0d13..a4a4c2b1b859de2c2b4be3904f8ecdf0c2e27d9c 100644 (file)
@@ -95,36 +95,69 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
-    GGML_UNUSED(embd_pooled);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    GGML_UNUSED(embd_all);
 
-    // TODO: if we fail with split_simple, we should attempt different splitting strategies
-    //       but to do that properly, we first have to refactor the batches to be more flexible
+    // first try simple split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        std::vector<llama_ubatch> ubatches;
 
-    std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);
 
-    while (sbatch.n_tokens > 0) {
-        auto ubatch = sbatch.split_simple(n_ubatch);
+            ubatches.push_back(ubatch);
+        }
 
-        ubatches.push_back(ubatch);
-    }
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
 
-    auto heads_base = kv_base->prepare(ubatches);
-    if (heads_base.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
 
-    auto heads_swa = kv_swa->prepare(ubatches);
-    if (heads_swa.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+
+        std::vector<llama_ubatch> ubatches;
 
-    assert(heads_base.size() == heads_swa.size());
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
+
+            ubatches.push_back(ubatch);
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(
-            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
index 3dbf33ed7b960d3985804e84a63eccddbb308198..6e941e1a41b88b31f0838c4834ae6eda830a3455 100644 (file)
@@ -34,8 +34,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
index 3566d5fd4d72bdf5a7e8a968f18197e7743815ae..3b37679859d392481c4e56bb1004acbb4004ee76 100644 (file)
@@ -127,6 +127,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
+
+    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
+    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
 }
 
 void llama_kv_cache_unified::clear(bool data) {
@@ -307,24 +310,27 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) {
-    GGML_UNUSED(embd_pooled);
+            bool embd_all) {
+    GGML_UNUSED(embd_all);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
-    }
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
 
-    auto heads = prepare(ubatches);
-    if (heads.empty()) {
-        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        auto heads = prepare(ubatches);
+        if (heads.empty()) {
+            break;
+        }
+
+        return std::make_unique<llama_kv_cache_unified_state>(
+                this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    } while (false);
 
-    return std::make_unique<llama_kv_cache_unified_state>(
-            this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
@@ -512,43 +518,68 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         head_cur = 0;
     }
 
-    // otherwise, one cell per token.
-
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
         return -1;
     }
 
-//#define FIND_SLOT_DEBUG 1
-#if FIND_SLOT_DEBUG
-    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
 
-    // for debugging
-    {
-        std::string ss;
-        if (n_swa > 0) {
+        if ((debug == 2 && n_swa > 0) || debug > 2) {
+            std::string ss;
             for (uint32_t i = 0; i < cells.size(); ++i) {
                 if (cells.is_empty(i)) {
                     ss += '.';
                 } else {
-                    ss += std::to_string(cells.seq_get(i));
+                    assert(cells.seq_count(i) >= 1);
+
+                    if (cells.seq_count(i) == 1) {
+                        ss += std::to_string(cells.seq_get(i));
+                    } else {
+                        ss += 'M';
+                    }
                 }
                 if (i%256 == 255) {
+                    ss += " *";
                     ss += '\n';
                 }
             }
+            LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
         }
-        LLAMA_LOG_WARN("\n%s\n", ss.c_str());
-    }
 
-    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-        if (cells.seq_pos_min(s) < 0) {
-            continue;
+        if ((debug == 2 && n_swa > 0) || debug > 2) {
+            std::string ss;
+            for (uint32_t i = 0; i < cells.size(); ++i) {
+                std::string cur;
+                if (cells.is_empty(i)) {
+                    cur = '.';
+                } else {
+                    cur = std::to_string(cells.pos_get(i));
+                }
+                const int n = cur.size();
+                for (int j = 0; j < 5 - n; ++j) {
+                    cur += ' ';
+                }
+                ss += cur;
+                if (i%256 == 255) {
+                    ss += " *";
+                }
+                if (i%64 == 63) {
+                    ss += '\n';
+                }
+            }
+            LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
         }
 
-        LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (cells.seq_pos_min(s) < 0) {
+                continue;
+            }
+
+            LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
+        }
     }
-#endif
 
     uint32_t n_tested = 0;
 
@@ -559,21 +590,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }
 
-        // keep track of what the minimum sequence positions would be if we accept the ubatch
-        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-            seq_pos_min[s] = cells.seq_pos_min(s);
-        }
-
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_pos    pos    = ubatch.pos[i];
-            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+            //const llama_pos    pos    = ubatch.pos[i];
+            //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
             // can we use this cell? either:
             //  - the cell is empty
             //  - the cell is occupied only by one sequence:
-            //    - mask causally, if the sequence is the same as the one we are inserting
+            //    - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //    - mask SWA, using current max pos for that sequence in the cache
             //                always insert in the cell with minimum pos
             bool can_use = cells.is_empty(head_cur + i);
@@ -581,21 +606,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             if (!can_use && cells.seq_count(head_cur + i) == 1) {
                 const llama_pos pos_cell = cells.pos_get(head_cur + i);
 
-                // causal mask
-                if (cells.seq_has(head_cur + i, seq_id)) {
-                    can_use = pos_cell >= pos;
-                }
+                // (disabled) causal mask
+                // note: it's better to purge any "future" tokens beforehand
+                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //    can_use = pos_cell >= pos;
+                //}
 
                 if (!can_use) {
                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
 
                     // SWA mask
-                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                    if (pos_cell == seq_pos_min[seq_id_cell] &&
-                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        seq_pos_min[seq_id_cell]++;
+                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                         can_use = true;
                     }
                 }
@@ -623,18 +644,58 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            cells.rm(head_cur + i);
-        }
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s:   n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+    }
+
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;
+
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);
+
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos    pos    = cells.pos_get(head_cur + idx);
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+                cells.rm(head_cur + idx);
+            }
+
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
+
+            // TODO: fix indexing [UBATCH_IDX]
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
         }
     }
 
+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
@@ -731,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens     = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs       = ubatch->n_seqs;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const auto n_kv = dst->ne[0];
+    const int64_t n_kv = dst->ne[0];
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -752,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //      xxxxx-----
     //      xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;
+
+                const llama_pos p1 = ubatch->pos[idx];
 
                 for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
@@ -787,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                         f = -INFINITY;
                     }
 
-                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
                 }
             }
         }
 
         // mask padded tokens
         if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
@@ -1447,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         seq_rm(dest_seq_id, -1, -1);
 
         llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 
-        batch.n_tokens = cell_count;
+        ubatch.n_tokens = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs = 1;
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1469,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 io.read_to(&seq_id, sizeof(seq_id));
             }
 
-            batch.pos[i]      = pos;
-            batch.n_seq_id[i] = n_seq_id;
-            batch.seq_id[i]   = &dest_seq_id;
+            ubatch.pos[i]      = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i]   = &dest_seq_id;
         }
 
-        const auto head_cur = find_slot(batch);
+        const auto head_cur = find_slot(ubatch);
         if (head_cur < 0) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        apply_ubatch(head_cur, batch);
+        apply_ubatch(head_cur, ubatch);
 
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
@@ -1488,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur)                  == batch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
         GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
     } else {
@@ -1674,7 +1739,7 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
         llama_context * lctx,
         bool do_shift,
         defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
-    if (!do_shift && dinfo.empty()) {
+    if (!do_shift && this->dinfo.empty()) {
         status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
 }
index 49f410ef6ecabf11aa267b0d0568f66934a41a05..d96571d952b81db3e47b8653018b79b5e3235325 100644 (file)
@@ -59,8 +59,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_all) override;
 
     llama_memory_state_ptr init_full() override;
 
@@ -158,6 +157,8 @@ private:
     // SWA
     const uint32_t n_swa = 0;
 
+    int debug = 0;
+
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     std::vector<ggml_context_ptr>        ctxs;
index acf30aebec69b071bc0b36223771e71ce0a5f9fd..1d4e70f4d321249882287e0bf6b1f56f1c8110dc 100644 (file)
@@ -23,7 +23,7 @@ public:
 
         used.clear();
 
-        for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
             seq_pos[s].clear();
         }
     }
@@ -240,7 +240,7 @@ public:
     llama_seq_id seq_get(uint32_t i) const {
         assert(seq[i].count() == 1);
 
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
                 return s;
             }
@@ -253,7 +253,7 @@ public:
     // return -1 if the sequence is not present
     llama_pos seq_pos_min(llama_seq_id seq_id) const {
         assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
 
         if (seq_pos[seq_id].empty()) {
             return -1;
@@ -266,7 +266,7 @@ public:
     // return -1 if the sequence is not present
     llama_pos seq_pos_max(llama_seq_id seq_id) const {
         assert(seq_id >= 0);
-        assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
+        assert(seq_id < LLAMA_MAX_SEQ);
 
         if (seq_pos[seq_id].empty()) {
             return -1;
@@ -384,20 +384,20 @@ private:
     //
     std::vector<llama_pos> shift;
 
-    using bits_t = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
+    using bits_t = std::bitset<LLAMA_MAX_SEQ>;
 
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<bits_t> seq;
 
     // the set seq_pos[s] tells us which positions are currently present for sequence s
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
+    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
 
     // helper functions for updating `seq_pos`, once cell at a time:
 
     // remove cell i
     void seq_pos_rm(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
                 seq_pos[s].erase(pos[i]);
             }
@@ -406,7 +406,7 @@ private:
 
     // add cell i
     void seq_pos_add(uint32_t i) {
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
                 seq_pos[s].insert(pos[i]);
             }
index 991aae781ba57003d2d975994f848ff69735a931..24668f861b976243bb8bb59fd149e615c71cb87a 100644 (file)
@@ -73,8 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
+            bool embd_all) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;
index c41ee24507fca47f7cab1ca353c4edfd4ec7bd9a..a5eb122f998d85cbcd309953dd71518271578ceb 100644 (file)
@@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_40B:           return "40B";
         case LLM_TYPE_65B:           return "65B";
         case LLM_TYPE_70B:           return "70B";
+        case LLM_TYPE_142B:          return "142B";
         case LLM_TYPE_236B:          return "236B";
         case LLM_TYPE_290B:          return "290B";
         case LLM_TYPE_314B:          return "314B";
@@ -598,6 +599,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.use_kq_norm = false;
                 }
             } break;
+        case LLM_ARCH_ARCEE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                // Arcee uses the same structure as Llama
+                switch (hparams.n_layer) {
+                    case 36: type = LLM_TYPE_4B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -738,6 +749,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     }
                 }
             } break;
+        case LLM_ARCH_NEO_BERT:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,            hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);
+
+                if (hparams.n_layer == 28) {
+                    type = LLM_TYPE_250M;
+                }
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1444,6 +1465,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DOTS1:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,         hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,        hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,         hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
+                switch (hparams.n_layer) {
+                    case 62: type = LLM_TYPE_142B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -2187,6 +2222,32 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_NEO_BERT:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
+
+                    cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+                    cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         TENSOR_NOT_REQUIRED);
+
+                    cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+                    cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {hparams.n_cls_out},         TENSOR_NOT_REQUIRED);
+
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff*2}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    }
+                } break;
             case LLM_ARCH_JINA_BERT_V2:
                 {
                     tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -2224,8 +2285,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
                         layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0);
 
                         layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
                         layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
@@ -4123,6 +4184,89 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
                     }
                 } break;
+            case LLM_ARCH_DOTS1:
+                {
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_ARCEE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -6043,7 +6187,7 @@ struct llm_build_bert : public llm_graph_context {
                         model.layers[il].ffn_gate, NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
-                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                        model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
                 cb(cur, "ffn_out", il);
             } else {
                 cur = build_ffn(cur,
@@ -6074,6 +6218,117 @@ struct llm_build_bert : public llm_graph_context {
     }
 };
 
+struct llm_build_neo_bert : public llm_graph_context {
+    llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // construct input embeddings (token, type, position)
+        inpL = build_inp_embd(model.tok_embd);
+        cb(inpL, "inp_embd", -1);
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        // iterate layers
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * cur = inpL;
+
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            // pre-norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+
+            // self-attention
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
+
+            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // RoPE
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
+
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // re-add the layer input
+            cur = ggml_add(ctx0, cur, inpL);
+
+            ggml_tensor * ffn_inp = cur;
+            cb(ffn_inp, "ffn_inp", il);
+
+            // pre-norm
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,
+                    NULL, NULL, NULL, NULL, NULL,
+                    model.layers[il].ffn_down,
+                    NULL, NULL, NULL,
+                    LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+
+            // attentions bypass the intermediate layer
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm_enc, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_embd", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_bloom : public llm_graph_context {
     llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -8857,7 +9112,6 @@ struct llm_build_mamba : public llm_graph_context {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -8866,8 +9120,7 @@ struct llm_build_mamba : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
-            //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
-            cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
+            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -8908,7 +9161,6 @@ struct llm_build_mamba : public llm_graph_context {
              ggml_cgraph * gf,
              ggml_tensor * cur,
              ggml_tensor * state_copy,
-             ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const {
         const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -8935,12 +9187,12 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = build_copy_mask_state(
-                gf, conv_states_all, state_copy, state_mask,
+        ggml_tensor * conv = build_recurrent_state(
+                gf, conv_states_all, state_copy,
                 hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_copy_mask_state(
-                gf, ssm_states_all, state_copy, state_mask,
+        ggml_tensor * ssm = build_recurrent_state(
+                gf, ssm_states_all, state_copy,
                 hparams.n_embd_v_s(), n_seqs);
         ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
@@ -11656,7 +11908,6 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int   il) const {
         const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -11780,8 +12031,8 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
         }
 
-        ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_state->get_v_l(il), state_copy, state_mask,
+        ggml_tensor * wkv_state = build_recurrent_state(
+                gf, kv_state->get_v_l(il), state_copy,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output;
@@ -11837,7 +12088,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -11848,7 +12098,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
@@ -11864,7 +12114,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -11935,7 +12185,6 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -11946,7 +12195,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
@@ -11959,7 +12208,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12051,7 +12300,6 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
@@ -12134,8 +12382,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
-        ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_state->get_v_l(il), state_copy, state_mask,
+        ggml_tensor * wkv_state = build_recurrent_state(
+                gf, kv_state->get_v_l(il), state_copy,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12193,7 +12441,6 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12204,7 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
@@ -12220,7 +12467,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -12287,7 +12534,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12298,7 +12544,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
@@ -12311,7 +12557,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13203,6 +13449,291 @@ struct llm_build_bailingmoe : public llm_graph_context {
     }
 };
 
+struct llm_build_dots1 : public llm_graph_context {
+    llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                ggml_tensor * moe_out =
+                    build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
+                            true, hparams.expert_weights_scale,
+                            (llama_expert_gating_func_type) hparams.expert_gating_func,
+                            il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_arcee : public llm_graph_context {
+    llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // ARCEE uses relu^2 instead of silu
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -13211,6 +13742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
+        case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
             {
                 res = nullptr;
@@ -13319,6 +13851,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_bert>(*this, params, gf);
             } break;
+        case LLM_ARCH_NEO_BERT:
+            {
+                llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 llm = std::make_unique<llm_build_bloom>(*this, params, gf);
@@ -13541,6 +14077,14 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
             } break;
+        case LLM_ARCH_DOTS1:
+            {
+                llm = std::make_unique<llm_build_dots1>(*this, params, gf);
+            } break;
+        case LLM_ARCH_ARCEE:
+            {
+                llm = std::make_unique<llm_build_arcee>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -13690,6 +14234,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
+        case LLM_ARCH_NEO_BERT:
+        case LLM_ARCH_ARCEE:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
@@ -13723,6 +14269,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
         case LLM_ARCH_MINICPM3:
+        case LLM_ARCH_DOTS1:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
index 18b714620bbcf899e23828eb9653fcb80b728e49..06e6c687943cc23e615bd1f49f773347d2b4247b 100644 (file)
@@ -73,6 +73,7 @@ enum llm_type {
     LLM_TYPE_40B,
     LLM_TYPE_65B,
     LLM_TYPE_70B,
+    LLM_TYPE_142B,
     LLM_TYPE_236B,
     LLM_TYPE_290B,
     LLM_TYPE_314B,
index 159b1307a4c5d70ba49d476743350feaf0f7a231..8cf45732fd6d4817cdafd6fd3de6c2bea421fec4 100644 (file)
@@ -585,7 +585,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
             if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
                 gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
             } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
-                gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64);
+                // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
+                gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)abs(o.val_i64));
             } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
                 gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
             } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
index ba2e1864ec0050ba58777b417ada98113809e3cd..dd2251ef3cbefa7bc6d88fecab69cec4bc70b7ad 100644 (file)
@@ -9,16 +9,16 @@
 
 #include <algorithm>
 #include <cassert>
+#include <cctype>
 #include <cfloat>
-#include <climits>
 #include <cstdarg>
 #include <cstring>
 #include <forward_list>
+#include <limits>
 #include <map>
 #include <queue>
 #include <set>
 #include <unordered_map>
-#include <cctype>
 
 //
 // helpers
@@ -1987,6 +1987,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eom_id|>"
                     || t.first == "<EOT>"
                     || t.first == "_<EOT>"
+                    || t.first == "<|end_of_text|>"
                ) {
                 special_eog_ids.insert(t.second);
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2572,6 +2573,10 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
     // copy piece chars to output text buffer
     // skip up to 'lstrip' leading spaces before copying
     auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+        if (size >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+            GGML_ABORT("invalid token size: %zu exceeds int32_t limit", size);
+        }
+
         for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
             token++;
             size--;
@@ -2768,26 +2773,26 @@ void llama_vocab::impl::print_info() const {
     LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
 
     // special tokens
-    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token[special_bos_id].text.c_str() );  }
-    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token[special_eos_id].text.c_str() );  }
-    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token[special_eot_id].text.c_str() );  }
-    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token[special_eom_id].text.c_str() );  }
-    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token[special_unk_id].text.c_str() );  }
-    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token[special_sep_id].text.c_str() );  }
-    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token[special_pad_id].text.c_str() );  }
-    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token[special_mask_id].text.c_str() ); }
-
-    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token[linefeed_id].text.c_str() ); }
-
-    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token[special_fim_pre_id].text.c_str() ); }
-    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token[special_fim_suf_id].text.c_str() ); }
-    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token[special_fim_mid_id].text.c_str() ); }
-    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token[special_fim_pad_id].text.c_str() ); }
-    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token[special_fim_rep_id].text.c_str() ); }
-    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token[special_fim_sep_id].text.c_str() ); }
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
 
     for (const auto & id : special_eog_ids) {
-        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token[id].text.c_str() );
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
     }
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
index 2f06e0f8ce12d2d309f5c61cdc8219cec27d06b5..34906cdb62844875bf572a2a1df6118a2a8aa885 100644 (file)
@@ -198,14 +198,18 @@ static struct llama_model * llama_model_load_from_file_impl(
 
     // if using single GPU mode, remove all except the main GPU
     if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
-        if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
-            LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
-            llama_model_free(model);
-            return nullptr;
+        if (params.main_gpu < 0) {
+            model->devices.clear();
+        } else {
+            if (params.main_gpu >= (int)model->devices.size()) {
+                LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size());
+                llama_model_free(model);
+                return nullptr;
+            }
+            ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
+            model->devices.clear();
+            model->devices.push_back(main_gpu);
         }
-        ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
-        model->devices.clear();
-        model->devices.push_back(main_gpu);
     }
 
     for (auto * dev : model->devices) {
index 015a57898e22d2411296108fbc7446713f1dc2c9..635508b10f2ff1a2820ca98b15d26992b539f495 100644 (file)
@@ -243,18 +243,21 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
-    // Input data for llama_decode
+    // Input data for llama_encode/llama_decode
     // A llama_batch object can contain input about one or many sequences
     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
     //
     // - token  : the token ids of the input (used when embd is NULL)
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
-    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
+    //            (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
     // - seq_id : the sequence to which the respective token belongs
     //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
-    //            (if set to NULL, only the logits for last token will be returned)
+    //            (if set to NULL:
+    //               - if embeddings: all tokens are output
+    //               - if not:        only the last token is output
+    //            )
     //
     typedef struct llama_batch {
         int32_t n_tokens;
@@ -262,8 +265,8 @@ extern "C" {
         llama_token  *  token;
         float        *  embd;
         llama_pos    *  pos;
-        int32_t      *  n_seq_id; // TODO: remove, should belong to only 1 sequence
-        llama_seq_id ** seq_id;   // TODO: become llama_seq_id * seq_id;
+        int32_t      *  n_seq_id;
+        llama_seq_id ** seq_id;
         int8_t       *  logits;   // TODO: rename this to "output"
     } llama_batch;
 
@@ -961,8 +964,8 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
-    // Set whether the model is in embeddings mode or not
-    // If true, embeddings will be returned but logits will not
+    // Set whether the context outputs embeddings or not
+    // TODO: rename to avoid confusion with llama_get_embeddings()
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
 
     // Set whether to use causal attention or not