]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Mon, 12 Jan 2026 12:48:26 +0000 (14:48 +0200)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
43 files changed:
examples/talk-llama/llama-arch.cpp
examples/talk-llama/llama-arch.h
examples/talk-llama/llama-chat.cpp
examples/talk-llama/llama-chat.h
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-context.h
examples/talk-llama/llama-grammar.cpp
examples/talk-llama/llama-grammar.h
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-hparams.cpp
examples/talk-llama/llama-hparams.h
examples/talk-llama/llama-mmap.cpp
examples/talk-llama/llama-mmap.h
examples/talk-llama/llama-model-loader.cpp
examples/talk-llama/llama-model-loader.h
examples/talk-llama/llama-model-saver.cpp
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-quant.cpp
examples/talk-llama/llama-sampling.cpp
examples/talk-llama/llama-sampling.h
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama-vocab.h
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h
examples/talk-llama/models/afmoe.cpp
examples/talk-llama/models/bert.cpp
examples/talk-llama/models/cogvlm.cpp
examples/talk-llama/models/cohere2-iswa.cpp
examples/talk-llama/models/deepseek2.cpp
examples/talk-llama/models/gemma-embedding.cpp
examples/talk-llama/models/gemma2-iswa.cpp
examples/talk-llama/models/gemma3.cpp
examples/talk-llama/models/gemma3n-iswa.cpp
examples/talk-llama/models/llama-iswa.cpp
examples/talk-llama/models/maincoder.cpp [new file with mode: 0644]
examples/talk-llama/models/models.h
examples/talk-llama/models/modern-bert.cpp
examples/talk-llama/models/openai-moe-iswa.cpp
examples/talk-llama/models/qwen3next.cpp
examples/talk-llama/models/smallthinker.cpp
examples/talk-llama/unicode.cpp

index 94a6807eac81920f80c7060beb0b49967650bbc2..f736ee670506b6335b8a1107854a789962b73fa4 100644 (file)
@@ -118,6 +118,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MISTRAL3,         "mistral3"         },
     { LLM_ARCH_MIMO2,            "mimo2"           },
     { LLM_ARCH_LLAMA_EMBED,      "llama-embed"      },
+    { LLM_ARCH_MAINCODER,        "maincoder"        },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -151,6 +152,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_VOCAB_SIZE,                        "%s.vocab_size"                        },
     { LLM_KV_CONTEXT_LENGTH,                    "%s.context_length"                    },
     { LLM_KV_EMBEDDING_LENGTH,                  "%s.embedding_length"                  },
+    { LLM_KV_EMBEDDING_LENGTH_OUT,              "%s.embedding_length_out"              },
     { LLM_KV_FEATURES_LENGTH,                   "%s.features_length"                   },
     { LLM_KV_BLOCK_COUNT,                       "%s.block_count"                       },
     { LLM_KV_LEADING_DENSE_BLOCK_COUNT,         "%s.leading_dense_block_count"         },
@@ -948,6 +950,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_ATTN_K_NORM,
                 LLM_TENSOR_ATTN_V,
                 LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_ATTN_QKV,
+                LLM_TENSOR_ATTN_GATE,
                 LLM_TENSOR_FFN_NORM,
                 LLM_TENSOR_FFN_GATE_INP,
                 LLM_TENSOR_FFN_GATE_EXPS,
@@ -2074,6 +2078,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_TOKEN_EMBD,
                 LLM_TENSOR_OUTPUT_NORM_LFM2,
                 LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_DENSE_2_OUT,
             };
         case LLM_ARCH_LFM2MOE:
             return {
@@ -2234,6 +2239,23 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
             return {
                 LLM_TENSOR_TOKEN_EMBD,
             };
+        case LLM_ARCH_MAINCODER:
+            return {
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_Q,
+                LLM_TENSOR_ATTN_Q_NORM,
+                LLM_TENSOR_ATTN_K,
+                LLM_TENSOR_ATTN_K_NORM,
+                LLM_TENSOR_ATTN_V,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_FFN_NORM,
+                LLM_TENSOR_FFN_GATE,
+                LLM_TENSOR_FFN_DOWN,
+                LLM_TENSOR_FFN_UP,
+            };
         default:
             GGML_ABORT("unknown architecture for tensor mapping");
     }
index 714ead402571cc1a42359894dc14207b9183136a..68ec6a18b185f0a4c944aa0ba32175bf31f6a1f8 100644 (file)
@@ -122,6 +122,7 @@ enum llm_arch {
     LLM_ARCH_MISTRAL3,
     LLM_ARCH_MIMO2,
     LLM_ARCH_LLAMA_EMBED,
+    LLM_ARCH_MAINCODER,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -155,6 +156,7 @@ enum llm_kv {
     LLM_KV_VOCAB_SIZE,
     LLM_KV_CONTEXT_LENGTH,
     LLM_KV_EMBEDDING_LENGTH,
+    LLM_KV_EMBEDDING_LENGTH_OUT,
     LLM_KV_FEATURES_LENGTH,
     LLM_KV_BLOCK_COUNT,
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
index fc6a6223cfe2f86b8d4bf7ff2b596e2fb6c945fb..b54ebbd155dbab1c510ce45300ab3d534fa48a86 100644 (file)
@@ -74,6 +74,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "seed_oss",          LLM_CHAT_TEMPLATE_SEED_OSS          },
     { "grok-2",            LLM_CHAT_TEMPLATE_GROK_2            },
     { "pangu-embedded",    LLM_CHAT_TEMPLATE_PANGU_EMBED       },
+    { "solar-open",        LLM_CHAT_TEMPLATE_SOLAR_OPEN        },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -216,6 +217,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_GROK_2;
     } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
         return LLM_CHAT_TEMPLATE_PANGU_EMBED;
+    } else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) {
+        return LLM_CHAT_TEMPLATE_SOLAR_OPEN;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -845,6 +848,14 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "[unused9]助手:";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>";
+        }
+        if (add_ass) {
+            ss << "<|begin|>assistant";
+        }
     } else {
         // template not supported
         return -1;
index 684efb4d67f45b84f3cf6d96ac5909e22af29395..e1f795249c88625dbee3aed5954a3cb98370ae3d 100644 (file)
@@ -54,6 +54,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_SEED_OSS,
     LLM_CHAT_TEMPLATE_GROK_2,
     LLM_CHAT_TEMPLATE_PANGU_EMBED,
+    LLM_CHAT_TEMPLATE_SOLAR_OPEN,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
index 34dfcd4724bdd177d6a8acc960b13ac282390a7d..f220010a1b4cce8b422942d3492c8aa20bb451ad 100644 (file)
@@ -60,6 +60,25 @@ llama_context::llama_context(
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
 
+    // Initialize backend samplers here so they are part of the sampling graph
+    // before the reserve passes run later in this function. This avoids a later
+    // re-reserve when graph nodes change.
+    if (params.samplers != nullptr && params.n_samplers > 0) {
+        for (size_t i = 0; i < params.n_samplers; ++i) {
+            const auto & config = params.samplers[i];
+
+            if (llama_sampler_chain_get(config.sampler, -1) == nullptr) {
+                throw std::runtime_error("the backend samplers must be of type llama_sampler_chain");
+            }
+
+            if (set_sampler(config.seq_id, config.sampler)) {
+                const int n_samplers = llama_sampler_chain_n(config.sampler);
+
+                LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
+            }
+        }
+    }
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;
@@ -231,7 +250,10 @@ llama_context::llama_context(
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs
-            if (output_reserve(params.n_seq_max) < params.n_seq_max) {
+            // Create a dummy batch for initialization.
+            llama_batch dummy_batch = {};
+            dummy_batch.n_tokens = 0;
+            if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
                 throw std::runtime_error("failed to reserve initial output buffer");
             }
 
@@ -456,6 +478,16 @@ llama_context::llama_context(
             LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
         }
     }
+
+    // Initialize the full vocabulary token ids for backend samplers.
+    {
+        const int n_vocab = model.vocab.n_tokens();
+
+        sampling.token_ids_full_vocab.resize(n_vocab);
+        for (int i = 0; i < n_vocab; ++i) {
+            sampling.token_ids_full_vocab[i] = i;
+        }
+    }
 }
 
 llama_context::~llama_context() {
@@ -616,6 +648,35 @@ float * llama_context::get_logits() {
     return logits;
 }
 
+int64_t llama_context::output_resolve_row(int32_t i) const {
+    int64_t j = -1;
+
+    // support negative indices (last output row)
+    if (i < 0) {
+        j = n_outputs + i;
+        if (j < 0) {
+            throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
+        }
+    } else if ((size_t) i >= output_ids.size()) {
+        throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
+    } else {
+        // use output_ids to translate the batch token index into a row number
+        // that holds this token's data.
+        j = output_ids[i];
+    }
+
+    if (j < 0) {
+        // the batch token was not configured to output anything
+        throw std::runtime_error(format("batch.logits[%d] != true", i));
+    }
+
+    if (j >= n_outputs) {
+        throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
+    }
+
+    return j;
+}
+
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
@@ -626,6 +687,7 @@ float * llama_context::get_logits_ith(int32_t i) {
             throw std::runtime_error("no logits");
         }
 
+        // TODO: use output_resolve_row()
         if (i < 0) {
             j = n_outputs + i;
             if (j < 0) {
@@ -662,6 +724,10 @@ float * llama_context::get_embeddings() {
     return embd;
 }
 
+llama_token * llama_context::get_sampled_tokens()  const{
+    return sampling.sampled;
+}
+
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
@@ -672,6 +738,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
             throw std::runtime_error("no embeddings");
         }
 
+        // TODO: use output_resolve_row()
         if (i < 0) {
             j = n_outputs + i;
             if (j < 0) {
@@ -691,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
-        return embd + j*model.hparams.n_embd;
+        const uint32_t n_embd_out = model.hparams.get_n_embd_out();
+        return embd + j*n_embd_out;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -711,6 +779,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
 
+llama_token llama_context::get_sampled_token_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.sampled == nullptr) {
+        return LLAMA_TOKEN_NULL;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
+        return sampling.sampled[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
+        return LLAMA_TOKEN_NULL;
+    }
+}
+
+float * llama_context::get_sampled_probs_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.probs == nullptr) {
+        return nullptr;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
+            return nullptr;
+        }
+        return sampling.probs + row*model.vocab.n_tokens();
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
+        return nullptr;
+    }
+}
+
+float * llama_context::get_sampled_logits_ith(int32_t idx) {
+    output_reorder();
+
+    if (sampling.logits == nullptr) {
+        return nullptr;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
+            return nullptr;
+        }
+        return sampling.logits + row*model.vocab.n_tokens();
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
+        return nullptr;
+    }
+}
+
+const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
+    output_reorder();
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if (sampling.candidates != nullptr &&
+            (size_t) row < sampling.candidates_count.size() &&
+            sampling.candidates_count[row] > 0) {
+            return sampling.candidates + row*model.vocab.n_tokens();
+        }
+    } catch (const std::exception & err) {
+        // fallback to full vocab list
+    }
+
+    return sampling.token_ids_full_vocab.data();
+}
+
+size_t llama_context::get_sampled_candidates_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.candidates == nullptr) {
+        return 0;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.candidates_count.size()) {
+            return 0;
+        }
+        return sampling.candidates_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::get_sampled_logits_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.logits == nullptr) {
+        return model.vocab.n_tokens();
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.logits_count.size()) {
+            return 0;
+        }
+        return sampling.logits_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::get_sampled_probs_count(int32_t idx) {
+    output_reorder();
+
+    if (sampling.probs == nullptr) {
+        return 0;
+    }
+
+    try {
+        const int64_t row = output_resolve_row(idx);
+        if ((size_t) row >= sampling.probs_count.size()) {
+            return 0;
+        }
+        return sampling.probs_count[row];
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what());
+        return 0;
+    }
+}
+
+
 void llama_context::attach_threadpool(
            ggml_threadpool_t threadpool,
            ggml_threadpool_t threadpool_batch) {
@@ -767,6 +965,42 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
 
+bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+    LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
+
+    const bool can_offload =
+        sampler &&
+        sampler->iface->backend_init &&
+        sampler->iface->backend_apply &&
+        llama_sampler_chain_n(sampler) > 0;
+
+    if (sampler && can_offload) {
+        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
+        auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
+        if (host_buft) {
+            buft = host_buft;
+        }
+
+        sampler->iface->backend_init(sampler, buft);
+
+        sampling.samplers[seq_id] = sampler;
+
+        return true;
+    }
+
+    if (sampler && !can_offload) {
+        LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
+
+        sampling.samplers.erase(seq_id);
+
+        return false;
+    }
+
+    sampling.samplers.erase(seq_id);
+
+    return true;
+}
+
 void llama_context::set_adapter_lora(
             llama_adapter_lora * adapter,
             float scale) {
@@ -907,7 +1141,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     n_queued_tokens += n_tokens;
 
     // reserve output buffer
-    if (output_reserve(n_tokens) < n_tokens) {
+    if (output_reserve(n_tokens, batch_inp) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
         return -2;
     };
@@ -961,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
                 {
                     // extract token embeddings
                     GGML_ASSERT(embd != nullptr);
+                    const uint32_t n_embd_out = hparams.get_n_embd_out();
 
-                    GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
+                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
                 } break;
             case LLAMA_POOLING_TYPE_MEAN:
             case LLAMA_POOLING_TYPE_CLS:
@@ -1031,6 +1266,112 @@ int llama_context::encode(const llama_batch & batch_inp) {
     return 0;
 }
 
+static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
+    std::map<llama_seq_id, uint32_t> seq_to_row;
+    // how many output tokens we have seen so far for this ubatch.
+    uint32_t local = 0;
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        // skip tokens that are not output.
+        if (!ubatch.output[i]) {
+            continue;
+        }
+
+        const llama_seq_id seq_id = ubatch.seq_id[i][0];
+        // row_offset is the number of output tokens before this ubatch.
+        seq_to_row[seq_id] = row_offset + local;
+        ++local;
+    }
+    return seq_to_row;
+}
+
+static void copy_tensor_async_ints(
+    const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+    llama_token * sampled,
+    size_t sampled_size,
+    const std::map<llama_seq_id, uint32_t> & seq_to_row,
+    ggml_backend_sched_t sched) {
+    if (sampled == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < sampled_size);
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
+    }
+}
+
+static void copy_tensor_async_floats(
+    const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+    float * dst,
+    size_t stride,
+    std::vector<uint32_t> & counts,
+    const std::map<llama_seq_id, uint32_t> & seq_to_row,
+    ggml_backend_sched_t sched) {
+    if (dst == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < counts.size());
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        float * row_ptr = dst + (size_t) row * stride;
+        ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+        // Update the actual number of logits/probabilities that were written for this row.
+        counts[row] = ggml_nelements(tensor);
+    }
+}
+
+static void copy_tensor_async_candidates(
+    const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
+    llama_token * dst,
+    size_t stride,
+    std::vector<uint32_t> & counts,
+    const std::map<llama_seq_id, uint32_t> & seq_to_row,
+    ggml_backend_sched_t sched) {
+    if (dst == nullptr) {
+        return;
+    }
+
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        auto it = seq_to_row.find(seq_id);
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
+        const uint32_t row = it->second;
+        GGML_ASSERT(row < counts.size());
+
+        GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
+
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        llama_token * row_ptr = dst + (size_t) row * stride;
+        ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
+
+        // Update the actual number of candidates that were written.
+        counts[row] = ggml_nelements(tensor);
+    }
+}
+
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
@@ -1051,9 +1392,36 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const int64_t n_embd  = hparams.n_embd_inp();
 
     // when computing embeddings, all tokens are output
-    const bool output_all = cparams.embeddings;
+    const bool output_all   = cparams.embeddings;
+    const bool has_samplers = !sampling.samplers.empty();
+
+    const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
+    // TODO: avoid this workaround in the future
+    if (has_samplers && batch_inp.logits) {
+        std::vector<int32_t> seq_output_count(n_seq_max, 0);
+
+        for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
+            if (batch_inp.logits[i] == 0) {
+                continue;
+            }
+
+            const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
+
+            for (int32_t s = 0; s < ns; ++s) {
+                const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
+
+                seq_output_count[seq_id]++;
+                if (seq_output_count[seq_id] > 1) {
+                    LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
+                            __func__, seq_id, seq_output_count[seq_id]);
+                    return -1;
+                }
+            }
+        }
+    }
+
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -1134,7 +1502,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }
 
     // reserve output buffer
-    if (output_reserve(n_outputs_all) < n_outputs_all) {
+    if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
         return -2;
     };
@@ -1207,7 +1575,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         // extract logits
-        if (t_logits && n_outputs > 0) {
+        // For multi-sequence batches that mix backend samplers and CPU sampler
+        // this is currently inefficient as we copy all logits even for the
+        // backend sampled tokens.
+        if (logits && t_logits && n_outputs > 0) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
             GGML_ASSERT(backend_res != nullptr);
             GGML_ASSERT(logits != nullptr);
@@ -1222,7 +1593,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         // extract embeddings
-        if (t_embd && n_outputs > 0) {
+        if (embd && t_embd && n_outputs > 0) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
             GGML_ASSERT(backend_embd != nullptr);
 
@@ -1231,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     {
                         // extract token embeddings
                         GGML_ASSERT(embd != nullptr);
-                        float * embd_out = embd + n_outputs_prev*n_embd;
+                        const uint32_t n_embd_out = hparams.get_n_embd_out();
+                        float * embd_out = embd + n_outputs_prev*n_embd_out;
 
                         if (n_outputs) {
                             GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
                         }
                     } break;
                 case LLAMA_POOLING_TYPE_MEAN:
@@ -1276,6 +1648,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
             }
         }
 
+        // This flag indicates whether a backend sampler has actually sampled a specific
+        // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
+        const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
+
+        if (has_samplers && has_sampled) {
+            const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
+            const auto stride = n_vocab;
+
+            // async copy the sampling data from the backend to the host
+            copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+
+            copy_tensor_async_floats    (res->t_sampled_logits, sampling.logits,     stride, sampling.logits_count,     seq_to_output_row, sched.get());
+            copy_tensor_async_floats    (res->t_sampled_probs,  sampling.probs,      stride, sampling.probs_count,      seq_to_output_row, sched.get());
+            copy_tensor_async_candidates(res->t_candidates,     sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
+        }
+
         n_outputs_prev += n_outputs;
     } while (mctx->next());
 
@@ -1339,15 +1727,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
 // output
 //
 
-uint32_t llama_context::output_reserve(int32_t n_outputs) {
+uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
     const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
 
-    const auto n_batch = cparams.n_batch;
-    const auto n_vocab = vocab.n_tokens();
-    const auto n_embd  = hparams.n_embd;
+    const auto n_batch    = cparams.n_batch;
+    const auto n_vocab    = vocab.n_tokens();
+    const auto n_embd_out = hparams.get_n_embd_out();
 
     bool has_logits = true;
     bool has_embd   = cparams.embeddings;
@@ -1358,8 +1746,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         has_embd   = true;
     }
 
-    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
+    // Check which sampling modes are needed for the current batch.
+    // TODO: avoid this branching by working with the worst-case
+    bool has_sampling = false;
+    bool cpu_logits   = false;
+
+    if (batch.logits) {
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            if (!batch.logits[i]) {
+                continue;
+            }
+            for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+                llama_seq_id seq_id = batch.seq_id[i][j];
+                if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
+                    has_sampling = true;
+                } else {
+                    cpu_logits = true;
+                }
+            }
+        }
+    } else {
+        // When batch.logits is nullptr (when loading state with a dummy batch),
+        // allocate CPU logits.
+        cpu_logits = true;
+    }
+
+    size_t backend_float_count = 0;
+    size_t backend_token_count = 0;
+
+    // Allocate CPU logits buffer only if needed by sequences in this batch
+    logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
+    embd_size   = has_embd ? n_embd_out*n_outputs_max : 0;
+
+    // TODO: avoid this branching by working with the worst-case
+    if (!has_sampling) {
+        sampling.logits_size     = 0;
+        sampling.probs_size      = 0;
+        sampling.sampled_size    = 0;
+        sampling.candidates_size = 0;
+    } else {
+        sampling.logits_size     = n_vocab*n_outputs_max;
+        sampling.probs_size      = n_vocab*n_outputs_max;
+        sampling.sampled_size    =         n_outputs_max;
+        sampling.candidates_size = n_vocab*n_outputs_max;
+
+        backend_float_count = sampling.logits_size  + sampling.probs_size;
+        backend_token_count = sampling.sampled_size + sampling.candidates_size;
+    }
 
     if (output_ids.empty()) {
         // init, never resized afterwards
@@ -1367,7 +1800,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     }
 
     const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
-    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
+    const size_t new_size  =
+        (logits_size + embd_size + backend_float_count) * sizeof(float) +
+        (                          backend_token_count) * sizeof(llama_token);
 
     // alloc only when more than the current capacity is required
     // TODO: also consider shrinking the buffer
@@ -1375,9 +1810,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         if (buf_output) {
 #ifndef NDEBUG
             // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
-            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+            LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
             synchronize();
+
+            // TODO: not needed?
             buf_output = nullptr;
             logits = nullptr;
             embd = nullptr;
@@ -1399,8 +1836,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 
     float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
 
-    logits = has_logits ? output_base               : nullptr;
-    embd   = has_embd   ? output_base + logits_size : nullptr;
+    logits = nullptr;
+    embd   = nullptr;
+
+    size_t offset = 0;
+    uint8_t * base = (uint8_t *) output_base;
+
+    logits = (has_logits && cpu_logits) ? output_base : nullptr;
+    offset += logits_size * sizeof(float);
+
+    embd = has_embd ? (float *) (base + offset) : nullptr;
+    offset += embd_size * sizeof(float);
+
+    sampling.logits     = nullptr;
+    sampling.probs      = nullptr;
+    sampling.sampled    = nullptr;
+    sampling.candidates = nullptr;
+
+    if (has_sampling) {
+        sampling.logits = (float *) (base + offset);
+        offset += sampling.logits_size * sizeof(float);
+
+        sampling.probs = (float *) (base + offset);
+        offset += sampling.probs_size * sizeof(float);
+
+        sampling.sampled = (llama_token *) (base + offset);
+        offset += sampling.sampled_size * sizeof(llama_token);
+
+        sampling.candidates = (llama_token *) (base + offset);
+        offset += sampling.candidates_size * sizeof(llama_token);
+
+        // The count vectors keep track of the actual number of logits/probs/candidates
+        // copied from the backend for each output row.
+
+        sampling.logits_count.resize(n_outputs_max);
+        sampling.probs_count.resize(n_outputs_max);
+        sampling.candidates_count.resize(n_outputs_max);
+
+        std::fill(sampling.logits_count.begin(),     sampling.logits_count.end(),     0);
+        std::fill(sampling.probs_count.begin(),      sampling.probs_count.end(),      0);
+        std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
+
+        std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+    }
 
     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1429,6 +1907,40 @@ void llama_context::output_reorder() {
                 std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
             }
         }
+
+        if (sampling.logits && sampling.logits_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.probs && sampling.probs_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.candidates && sampling.candidates_size > 0) {
+            for (uint64_t k = 0; k < n_vocab; ++k) {
+                std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+            }
+        }
+
+        if (sampling.sampled && sampling.sampled_size > 0) {
+            std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+        }
+
+        if (!sampling.logits_count.empty()) {
+            std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
+        }
+
+        if (!sampling.probs_count.empty()) {
+            std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
+        }
+
+        if (!sampling.candidates_count.empty()) {
+            std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
+        }
     }
 
     output_swaps.clear();
@@ -1458,7 +1970,7 @@ ggml_cgraph * llama_context::graph_reserve(
 
     if (n_tokens % n_seqs != 0) {
         n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs
-        n_outputs = std::min(n_outputs, n_tokens);
+        n_outputs = std::max(n_outputs, n_tokens);
 
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
@@ -1477,6 +1989,15 @@ ggml_cgraph * llama_context::graph_reserve(
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
+    // set one output token per sequence in order to activate all backend samplers
+    std::vector<llama_seq_id> seq_ids(n_seqs);
+    for (uint32_t i = 0; i < n_seqs; ++i) {
+        seq_ids[i] = i;
+        ubatch.n_seq_id[i] = 1;
+        ubatch.seq_id[i] = &seq_ids[i];
+        ubatch.output[i] = true;
+    }
+
     auto * res = gf_res_reserve.get();
 
     const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
@@ -1507,7 +2028,7 @@ llm_graph_params llama_context::graph_params(
                         llm_graph_result * res,
                       const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
-            llm_graph_type   gtype) const {
+                          llm_graph_type   gtype) const {
     return {
         /*.arch        =*/ model.arch,
         /*.hparams     =*/ model.hparams,
@@ -1520,6 +2041,7 @@ llm_graph_params llama_context::graph_params(
         /*.loras       =*/ &loras,
         /*.mctx        =*/ mctx,
         /*.cross       =*/ &cross,
+        /*.samplers    =*/ sampling.samplers,
         /*.n_outputs   =*/ n_outputs,
         /*.cb          =*/ graph_get_cb(),
         /*.res         =*/ res,
@@ -1975,6 +2497,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
+    // TODO: handle sampling buffers and samplers state ?
+    //       https://github.com/ggml-org/llama.cpp/pull/17004
+
     if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
         memory->state_write(io);
@@ -2007,7 +2532,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         auto n_outputs = this->n_outputs;
         io.read_to(&n_outputs, sizeof(n_outputs));
 
-        if (n_outputs > output_reserve(n_outputs)) {
+        // Create a dummy batch for state loading.
+        llama_batch dummy_batch = {};
+        dummy_batch.n_tokens = 0;
+        if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
             throw std::runtime_error("could not reserve outputs");
         }
 
@@ -2061,6 +2589,9 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         }
     }
 
+    // TODO: handle sampling buffers and samplers state ?
+    //       https://github.com/ggml-org/llama.cpp/pull/17004
+
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
 
@@ -2249,7 +2780,7 @@ void llama_context::opt_epoch_iter(
         }
 
         // reserve output buffer
-        if (output_reserve(n_outputs_all) < n_outputs_all) {
+        if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
             LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
             GGML_ABORT("TODO: handle this error");
         };
@@ -2394,6 +2925,8 @@ llama_context_params llama_context_default_params() {
         /*.op_offload                  =*/ true,
         /*.swa_full                    =*/ true,
         /*.kv_unified                  =*/ false,
+        /*.sampler                     =*/ nullptr,
+        /*.n_sampler                   =*/ 0,
     };
 
     return result;
@@ -2553,7 +3086,15 @@ float * llama_get_logits(llama_context * ctx) {
 float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
 
-    return ctx->get_logits_ith(i);
+    float * res = nullptr;
+
+    res = ctx->get_sampled_logits_ith(i);
+
+    if (!res) {
+        res = ctx->get_logits_ith(i);
+    }
+
+    return res;
 }
 
 float * llama_get_embeddings(llama_context * ctx) {
@@ -2574,6 +3115,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
+bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+    return ctx->set_sampler(seq_id, smpl);
+}
+
+llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_token_ith(i);
+}
+
+float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_probs_ith(i);
+}
+
+float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_sampled_logits_ith(i);
+}
+
+llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
+}
+
+uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
+}
+
+uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
+}
+
+uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
+}
+
 // llama adapter API
 
 int32_t llama_set_adapter_lora(
index c31101330e2db826f240c17679ee57c55222c920..b29edf4db21efc05d1027d3a65146bf612b11a24 100644 (file)
@@ -70,6 +70,18 @@ struct llama_context {
     float * get_embeddings_ith(int32_t i);
     float * get_embeddings_seq(llama_seq_id seq_id);
 
+    llama_token * get_sampled_tokens() const;
+    llama_token   get_sampled_token_ith(int32_t idx);
+
+    float * get_sampled_logits_ith(int32_t idx);
+    size_t  get_sampled_logits_count(int32_t idx);
+
+    float * get_sampled_probs_ith(int32_t idx);
+    size_t  get_sampled_probs_count(int32_t idx);
+
+    const llama_token * get_sampled_candidates_ith(int32_t idx);
+    size_t get_sampled_candidates_count(int32_t idx);
+
     void attach_threadpool(
             ggml_threadpool_t threadpool,
             ggml_threadpool_t threadpool_batch);
@@ -192,10 +204,13 @@ private:
 
     // Make sure enough space is available for outputs.
     // Returns max number of outputs for which space was reserved.
-    uint32_t output_reserve(int32_t n_outputs);
+    uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
 
     void output_reorder();
 
+    // map the output row index `i` to batch index
+    int64_t output_resolve_row(int32_t i) const;
+
     //
     // graph
     //
@@ -213,6 +228,8 @@ public:
     ggml_cgraph * graph_reserve(
         uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);
 
+    bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
+
 private:
     llm_graph_params graph_params(
                         llm_graph_result * res,
@@ -252,6 +269,31 @@ private:
     size_t  embd_size = 0; // capacity (of floats) for embeddings
     float * embd      = nullptr;
 
+    // TODO: simplify
+    struct sampling_info {
+        std::map<llama_seq_id, llama_sampler *> samplers;
+
+        float       * logits      = nullptr;
+        size_t        logits_size = 0;
+
+        llama_token * sampled      = nullptr;
+        size_t        sampled_size = 0;
+
+        float       * probs        = nullptr;
+        size_t        probs_size   = 0;
+
+        llama_token * candidates   = nullptr;
+        size_t        candidates_size = 0;
+
+        std::vector<uint32_t> logits_count;
+        std::vector<uint32_t> probs_count;
+        std::vector<uint32_t> candidates_count;
+
+        std::vector<llama_token> token_ids_full_vocab;
+    };
+
+    sampling_info sampling;
+
     // sequence embeddings output (map of [n_embd] vectors)
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;
index 75d5d750c3998a74a86153955961440bf81f190a..64ea2fd00a9ac90cf39722d6b4a247290ce5bb8c 100644 (file)
@@ -369,6 +369,44 @@ static void print_rule(
     fprintf(file, "\n");
 }
 
+//
+// Regex utilities
+//
+
+size_t llama_grammar_trigger_pattern::find(const std::string & input) const {
+    auto find_start_pos = [](const std::smatch & match) {
+        // get from the first matched capturing group to the end of the string
+        size_t start = std::string::npos;
+        for (auto i = 1u; i < match.size(); i++) {
+            if (match.length(i) > 0) {
+                start = match.position(i);
+                break;
+            }
+        }
+        if (start == std::string::npos) {
+            start = match.position(0);
+        }
+        return start;
+    };
+
+    if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') {
+        // match against the entire input
+        std::smatch match;
+        if (std::regex_match(input, match, regex)) {
+            return find_start_pos(match);
+        }
+    }
+
+    // search anywhere
+    std::smatch match;
+    if (std::regex_search(input, match, regex)) {
+        return find_start_pos(match);
+    }
+
+    return std::string::npos;
+}
+
+
 //
 // implementation
 //
@@ -1312,21 +1350,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
             grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
             grammar.trigger_buffer += piece;
 
-            std::smatch match;
             for (const auto & trigger_pattern : grammar.trigger_patterns) {
-                if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
+                auto start = trigger_pattern.find(grammar.trigger_buffer);
+                if (start != std::string::npos) {
                     grammar.awaiting_trigger = false;
-                    // get from the first matched capturing group to the end of the string
-                    size_t start = std::string::npos;
-                    for (auto i = 1u; i < match.size(); i++) {
-                        if (match.length(i) > 0) {
-                            start = match.position(i);
-                            break;
-                        }
-                    }
-                    if (start == std::string::npos) {
-                        start = match.position(0);
-                    }
 
                     // replay tokens that overlap with [start, end)
                     for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
index a4c978ac1154549ba4df8a4590d2ec214c94a044..b5a0e588e903be1f3277026ebfbc81cb966216f2 100644 (file)
@@ -119,6 +119,8 @@ struct llama_grammar_parser {
 struct llama_grammar_trigger_pattern {
     std::string pattern;
     std::regex  regex;
+
+    size_t find(const std::string & input) const;
 };
 
 struct llama_grammar {
index 1d0d7197e1f7fde65fd7fd24a3974545deae1cf0..374ff1ebf3a2a41d498b6a0476593272b0707518 100644 (file)
@@ -12,6 +12,7 @@
 #include <cassert>
 #include <cmath>
 #include <cstring>
+#include <unordered_set>
 
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
@@ -32,7 +33,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
     bool res = true;
 
     res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
-    res &= (!embd   && !params.ubatch.embd)  || (embd   &&   embd->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd   && !params.ubatch.embd)  || (embd   &&   embd->ne[1] == params.ubatch.n_tokens);
 
     return res;
 }
@@ -62,7 +63,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
 bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
     bool res = true;
 
-    res &= pos->ne[0] == params.ubatch.n_tokens;
+    res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
 
     return res;
 }
@@ -521,6 +522,43 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
+    // set the inputs only for the active samplers in the current ubatch
+    std::unordered_set<llama_seq_id> active_samplers;
+    for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
+        if (ubatch->output[i]) {
+            llama_seq_id seq_id = ubatch->seq_id[i][0];
+            active_samplers.insert(seq_id);
+        }
+    }
+
+    for (auto seq_id : active_samplers) {
+        if (samplers.find(seq_id) == samplers.end()) {
+            continue;
+        }
+
+        auto & sampler = samplers[seq_id];
+
+        if (sampler->iface->backend_set_input) {
+            sampler->iface->backend_set_input(sampler);
+        }
+    }
+}
+
+bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
+    if (samplers.size() != params.samplers.size()) {
+        return false;
+    }
+
+    for (const auto & [seq_id, sampler] : params.samplers) {
+        if (samplers[seq_id] != sampler) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 //
 // llm_graph_result
 //
@@ -541,6 +579,10 @@ void llm_graph_result::reset() {
     t_logits      = nullptr;
     t_embd        = nullptr;
     t_embd_pooled = nullptr;
+    t_sampled.clear();
+    t_sampled_probs.clear();
+    t_sampled_logits.clear();
+    t_candidates.clear();
 
     params = {};
 
@@ -565,6 +607,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_result::set_outputs() {
+    if (t_logits != nullptr) {
+        ggml_set_output(t_logits);
+    }
+    if (t_embd != nullptr) {
+        ggml_set_output(t_embd);
+    }
+    if (t_embd_pooled != nullptr) {
+        ggml_set_output(t_embd_pooled);
+    }
+    for (auto & [seq_id, t] : t_sampled) {
+        if (t != nullptr) {
+            ggml_set_output(t);
+        }
+    }
+    for (auto & [seq_id, t] : t_sampled_probs) {
+        if (t != nullptr) {
+            ggml_set_output(t);
+        }
+    }
+    for (auto & [seq_id, t] : t_sampled_logits) {
+        if (t != nullptr) {
+            ggml_set_output(t);
+        }
+    }
+    for (auto & [seq_id, t] : t_candidates) {
+        if (t != nullptr) {
+            ggml_set_output(t);
+        }
+    }
+}
+
 bool llm_graph_result::can_reuse(const llm_graph_params & params) {
     if (!this->params.allow_reuse(params)) {
         if (debug > 1) {
@@ -646,6 +720,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     loras            (params.loras),
     mctx             (params.mctx),
     cross            (params.cross),
+    samplers         (params.samplers),
     cb_func          (params.cb),
     res              (params.res),
     ctx0             (res->get_ctx()),
@@ -1251,6 +1326,10 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 
     res->add_input(std::move(inp));
 
+    // make sure the produced embeddings are immediately materialized in the ggml graph
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18599
+    ggml_build_forward_expand(gf, cur);
+
     return cur;
 }
 
@@ -1834,8 +1913,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
 
         inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
+        ggml_set_name(inp->self_kq_mask, "self_kq_mask");
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+        ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
     }
 
     {
@@ -1848,8 +1929,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
 
         inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
+        ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
     }
 
     return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
@@ -1988,14 +2071,18 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 void llm_graph_context::build_dense_out(
     ggml_tensor * dense_2,
     ggml_tensor * dense_3) const {
-    if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
+    if (!cparams.embeddings || !(dense_2 || dense_3)) {
         return;
     }
     ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
     GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
 
-    cur = ggml_mul_mat(ctx0, dense_2, cur);
-    cur = ggml_mul_mat(ctx0, dense_3, cur);
+    if (dense_2) {
+        cur = ggml_mul_mat(ctx0, dense_2, cur);
+    }
+    if (dense_3) {
+        cur = ggml_mul_mat(ctx0, dense_3, cur);
+    }
     cb(cur, "result_embd_pooled", -1);
     res->t_embd_pooled = cur;
     ggml_build_forward_expand(gf, cur);
@@ -2086,6 +2173,87 @@ void llm_graph_context::build_pooling(
     ggml_build_forward_expand(gf, cur);
 }
 
+void llm_graph_context::build_sampling() const {
+    if (samplers.empty() || !res->t_logits) {
+        return;
+    }
+
+    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
+    res->add_input(std::move(inp_sampling));
+
+    std::map<llama_seq_id, int32_t> seq_to_logit_row;
+    int32_t logit_row_idx = 0;
+
+    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+        if (ubatch.output[i]) {
+            llama_seq_id seq_id = ubatch.seq_id[i][0];
+            seq_to_logit_row[seq_id] = logit_row_idx;
+            logit_row_idx++;
+        }
+    }
+
+    // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
+    GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
+
+    // add a dummy row of logits
+    // this trick makes the graph static, regardless of which samplers are activated
+    // this is important in order to minimize graph reallocations
+    // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
+    ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
+
+    for (const auto & [seq_id, sampler] : samplers) {
+        const auto it = seq_to_logit_row.find(seq_id);
+
+        // inactive samplers always work on the first row
+        const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
+
+        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
+        ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
+
+        struct llama_sampler_data data = {
+            /*.logits      =*/ logits_seq,
+            /*.probs       =*/ nullptr,
+            /*.sampled     =*/ nullptr,
+            /*.candidates  =*/ nullptr,
+        };
+
+        assert(sampler->iface->backend_apply);
+        sampler->iface->backend_apply(sampler, ctx0, gf, &data);
+
+        if (data.sampled != nullptr) {
+            res->t_sampled[seq_id] = data.sampled;
+            ggml_build_forward_expand(gf, data.sampled);
+        }
+
+        if (data.probs != nullptr) {
+            res->t_sampled_probs[seq_id] = data.probs;
+            ggml_build_forward_expand(gf, data.probs);
+        }
+
+        if (data.logits != nullptr) {
+            res->t_sampled_logits[seq_id] = data.logits;
+            ggml_build_forward_expand(gf, data.logits);
+        }
+
+        if (data.candidates != nullptr) {
+            res->t_candidates[seq_id] = data.candidates;
+            ggml_build_forward_expand(gf, data.candidates);
+        }
+    }
+
+    // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
+    /*
+    for (const auto & [seq_id, sampler] : samplers) {
+        if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
+            ggml_tensor * selected_token = it->second;
+            if (selected_token != nullptr) {
+                llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
+            }
+        }
+    }
+    */
+}
+
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;
index 81ac329cc311b10ea9c5c722f7679744306c88ba..503ffd695aa76b3388d8b08a6da78227ccb5f5d3 100644 (file)
@@ -10,6 +10,7 @@
 #include <memory>
 #include <set>
 #include <functional>
+#include <map>
 
 struct ggml_cgraph;
 struct ggml_context;
@@ -396,6 +397,18 @@ public:
     const llama_memory_hybrid_context * mctx;
 };
 
+class llm_graph_input_sampling : public llm_graph_input_i {
+public:
+    llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
+        samplers(std::move(samplers)) { }
+    virtual ~llm_graph_input_sampling() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+    bool can_reuse(const llm_graph_params & params) override;
+
+    std::map<llama_seq_id, llama_sampler *> samplers;
+};
+
 //
 // llm_graph_result
 //
@@ -429,6 +442,23 @@ struct llm_graph_params {
     const llama_memory_context_i * mctx;
     const llama_cross            * cross;
 
+    std::map<llama_seq_id, llama_sampler *> samplers;
+
+    static bool samplers_equal(
+          const std::map<llama_seq_id, llama_sampler *> & lhs,
+          const std::map<llama_seq_id, llama_sampler *> & rhs) {
+        if (lhs.size() != rhs.size()) {
+            return false;
+        }
+        for (const auto & [seq_id, sampler] : lhs) {
+            auto it = rhs.find(seq_id);
+            if (it == rhs.end() || it->second != sampler) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     uint32_t n_outputs;
 
     llm_graph_cb cb;
@@ -468,15 +498,36 @@ struct llm_graph_params {
             return false;
         }
 
+        if (n_outputs != other.n_outputs) {
+            return false;
+        }
+
+        if (!samplers_equal(samplers, other.samplers)) {
+            return false;
+        }
+
+        if (samplers.size() > 0) {
+            if (!ubatch.data || !other.ubatch.data) {
+                return false;
+            }
+
+            // check that the outputs are the same for all samplers
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                if (ubatch.output[i]    != other.ubatch.output[i] ||
+                    ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
+                    return false;
+                }
+            }
+        }
+
         return
             cparams.embeddings  == other.cparams.embeddings  &&
             cparams.causal_attn == other.cparams.causal_attn &&
-            arch      == other.arch  &&
-            gtype     == other.gtype &&
-            cvec      == other.cvec  &&
-            loras     == other.loras &&
-            cross     == other.cross &&
-            n_outputs == other.n_outputs;
+            arch  == other.arch  &&
+            gtype == other.gtype &&
+            cvec  == other.cvec  &&
+            loras == other.loras &&
+            cross == other.cross;
     }
 };
 
@@ -499,6 +550,7 @@ public:
     void reset();
 
     void set_inputs(const llama_ubatch * ubatch);
+    void set_outputs();
 
     // try to update the existing graph result using the new graph parameters in order to reuse it
     // this can only be done if we determine that the resulting graph using the new graph parameters
@@ -517,6 +569,11 @@ public:
     ggml_tensor * t_embd        = nullptr;
     ggml_tensor * t_embd_pooled = nullptr;
 
+    std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
+    std::map<llama_seq_id, ggml_tensor*> t_candidates;
+    std::map<llama_seq_id, ggml_tensor*> t_sampled;
+    std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
+
     std::vector<llm_graph_input_ptr> inputs;
 
     ggml_context_ptr ctx_compute;
@@ -592,6 +649,8 @@ struct llm_graph_context {
     const llama_memory_context_i * mctx;
     const llama_cross            * cross;
 
+    std::map<llama_seq_id, llama_sampler *> samplers;
+
     const llm_graph_cb & cb_func;
 
     llm_graph_result * res;
@@ -832,6 +891,12 @@ struct llm_graph_context {
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
 
+    //
+    // sampling (backend sampling)
+    //
+
+    void build_sampling() const;
+
     //
     // dense (out)
     //
index fe1fa4341d4288c209bc2ccebb6141fd8c9a1e7c..c847ef91b7aa26b8a4e76b94675550c6293ece12 100644 (file)
@@ -72,6 +72,10 @@ uint32_t llama_hparams::n_embd_inp() const {
     return n_embd_inp;
 }
 
+uint32_t llama_hparams::get_n_embd_out() const {
+    return n_embd_out > 0 ? n_embd_out : n_embd;
+}
+
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     const uint32_t n_head_kv = this->n_head_kv(il);
 
index 42def73f06f772ea546dc27f59f2b226b19f495e..7ae3ec292efed1c8cd62f0dda95a868e44861c94 100644 (file)
@@ -105,9 +105,9 @@ struct llama_hparams {
 
     float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
-    float    rope_freq_base_train_swa;
+    float    rope_freq_base_train_swa  = 10000.0f;
     float    rope_freq_scale_train;
-    float    rope_freq_scale_train_swa;
+    float    rope_freq_scale_train_swa = 1.0f;
 
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul = 0.0f;
@@ -162,6 +162,9 @@ struct llama_hparams {
     // for Classifiers
     uint32_t n_cls_out = 1;
 
+    // output embedding dimension (0 = use n_embd)
+    uint32_t n_embd_out = 0;
+
     // llama4 smallthinker
     uint32_t n_moe_layer_step        = 0;
     uint32_t n_no_rope_layer_step    = 4;
@@ -234,6 +237,9 @@ struct llama_hparams {
     // dimension of main + auxiliary input embeddings
     uint32_t n_embd_inp() const;
 
+    // dimension of output embeddings
+    uint32_t get_n_embd_out() const;
+
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
 
index 23b648a2e3b63b6e6dd526814f305dac4099d000..2da857b3aaec22b90782fd1d61afd922af1a45b8 100644 (file)
@@ -110,7 +110,7 @@ struct llama_file::impl {
         }
     }
 
-    void read_raw(void * ptr, size_t len) const {
+    void read_raw(void * ptr, size_t len) {
         size_t bytes_read = 0;
         while (bytes_read < len) {
             size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
@@ -127,7 +127,7 @@ struct llama_file::impl {
         }
     }
 
-    uint32_t read_u32() const {
+    uint32_t read_u32() {
         uint32_t val;
         read_raw(&val, sizeof(val));
         return val;
@@ -154,8 +154,8 @@ struct llama_file::impl {
         write_raw(&val, sizeof(val));
     }
 
-    void read_aligned_chunk(size_t offset, void * dest, size_t size) const {
-        throw std::runtime_error("DirectIO is not implemented on Windows.");
+    bool has_direct_io() const {
+        return true;
     }
 
     ~impl() {
@@ -164,33 +164,45 @@ struct llama_file::impl {
         }
     }
 #else
-    impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) {
+    impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) : fname(fname) {
 #ifdef __linux__
         // Try unbuffered I/O for read only
         if (use_direct_io && std::strcmp(mode, "rb") == 0) {
-            fd = open(fname, O_RDONLY | O_DIRECT);
+            if (init_fd()) {
+                return;
+            }
+            LLAMA_LOG_WARN("Failed to open file '%s' with error: %s. Falling back to buffered I/O",
+                           fname, strerror(errno));
+        }
+#endif
+        init_fp(mode);
+    }
 
-            if (fd != -1) {
-                struct stat file_stats{};
-                fstat(fd, &file_stats);
+#ifdef __linux__
+    bool init_fd() {
+        fd = open(fname.c_str(), O_RDONLY | O_DIRECT);
 
-                size = file_stats.st_size;
-                alignment = file_stats.st_blksize;
+        if (fd != -1) {
+            struct stat file_stats{};
+            fstat(fd, &file_stats);
 
-                off_t ret = lseek(fd, 0, SEEK_SET);
-                if (ret == -1) {
-                    throw std::runtime_error(format("seek error: %s", strerror(errno)));
-                }
-                return;
-            }
+            size = file_stats.st_size;
+            alignment = file_stats.st_blksize;
 
-            LLAMA_LOG_WARN("Failed to open model %s with error: %s. Falling back to buffered I/O",
-                fname, strerror(errno));
+            off_t ret = lseek(fd, 0, SEEK_SET);
+            if (ret == -1) {
+                throw std::runtime_error(format("seek error: %s", strerror(errno)));
+            }
+            return true;
         }
+        return false;
+    }
 #endif
-        fp = ggml_fopen(fname, mode);
+
+    void init_fp(const char * mode) {
+        fp = ggml_fopen(fname.c_str(), mode);
         if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
+            throw std::runtime_error(format("failed to open %s: %s", fname.c_str(), strerror(errno)));
         }
         seek(0, SEEK_END);
         size = tell();
@@ -226,7 +238,7 @@ struct llama_file::impl {
         }
     }
 
-    void read_raw(void * ptr, size_t len) const {
+    void read_raw_unsafe(void * ptr, size_t len) {
         if (len == 0) {
             return;
         }
@@ -240,26 +252,45 @@ struct llama_file::impl {
                 throw std::runtime_error("unexpectedly reached end of file");
             }
         } else {
-            bool successful = false;
-            while (!successful) {
-                off_t ret = read(fd, ptr, len);
+            size_t bytes_read = 0;
+            while (bytes_read < len) {
+                const size_t to_read = len - bytes_read;
+                ssize_t ret = ::read(fd, reinterpret_cast<char *>(ptr) + bytes_read, to_read);
 
                 if (ret == -1) {
                     if (errno == EINTR) {
                         continue;  // Interrupted by signal, retry
                     }
+                    // Fallback to std::fread in case the DMA controller cannot access the buffer
+                    if (errno == EFAULT) {
+                        auto curr_off = tell();
+                        close(fd);
+                        fd = -1;
+                        alignment = 1;
+                        init_fp("rb");
+                        seek(curr_off, SEEK_SET);
+                        read_raw_unsafe(ptr, len);
+                        return;
+                    }
                     throw std::runtime_error(format("read error: %s", strerror(errno)));
                 }
                 if (ret == 0) {
+                    // EOF: allow if this read was only pulling alignment padding past file end
+                    off_t pos = lseek(fd, 0, SEEK_CUR);
+                    if (pos != -1 && (size_t) pos == size) {
+                        std::memset(reinterpret_cast<char *>(ptr) + bytes_read, 0, len - bytes_read);
+                        return;
+                    }
                     throw std::runtime_error("unexpectedly reached end of file");
                 }
 
-                successful = true;
+                bytes_read += (size_t) ret;
             }
         }
     }
 
-    void read_aligned_chunk(size_t offset, void * dest, size_t size) const {
+    void read_aligned_chunk(void * dest, size_t size) {
+        size_t offset = tell();
         off_t aligned_offset = offset & ~(alignment - 1);
         off_t offset_from_alignment = offset - aligned_offset;
         size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1);
@@ -276,13 +307,21 @@ struct llama_file::impl {
         std::unique_ptr<void, aligned_buffer_deleter> buffer(raw_buffer);
 
         seek(aligned_offset, SEEK_SET);
-        read_raw(buffer.get(), bytes_to_read);
+        read_raw_unsafe(buffer.get(), bytes_to_read);
 
         uintptr_t actual_data = reinterpret_cast<uintptr_t>(buffer.get()) + offset_from_alignment;
         memcpy(dest, reinterpret_cast<void *>(actual_data), size);
     }
 
-    uint32_t read_u32() const {
+    void read_raw(void * ptr, size_t len) {
+        if (has_direct_io()) {
+            read_aligned_chunk(ptr, len);
+        } else {
+            read_raw_unsafe(ptr, len);
+        }
+    }
+
+    uint32_t read_u32() {
         uint32_t ret;
         read_raw(&ret, sizeof(ret));
         return ret;
@@ -303,6 +342,10 @@ struct llama_file::impl {
         write_raw(&val, sizeof(val));
     }
 
+    bool has_direct_io() const {
+        return fd != -1 && alignment > 1;
+    }
+
     ~impl() {
         if (fd != -1) {
             close(fd);
@@ -311,17 +354,9 @@ struct llama_file::impl {
         }
     }
     int fd = -1;
+    std::string fname;
 #endif
 
-    void read_raw_at(void * ptr, size_t len, size_t offset) const {
-        if (alignment != 1) {
-            read_aligned_chunk(offset, ptr, len);
-        } else {
-            seek(offset, SEEK_SET);
-            read_raw(ptr, len);
-        }
-    }
-
     size_t read_alignment() const {
         return alignment;
     }
@@ -340,6 +375,7 @@ size_t llama_file::tell() const { return pimpl->tell(); }
 size_t llama_file::size() const { return pimpl->size; }
 
 size_t llama_file::read_alignment() const { return pimpl->read_alignment(); }
+bool llama_file::has_direct_io() const { return pimpl->has_direct_io(); }
 
 int llama_file::file_id() const {
 #ifdef _WIN32
@@ -354,10 +390,14 @@ int llama_file::file_id() const {
 }
 
 void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
-void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
-void llama_file::read_raw_at(void * ptr, size_t len, size_t offset) const { pimpl->read_raw_at(ptr, len, offset); }
+void llama_file::read_raw(void * ptr, size_t len) { pimpl->read_raw(ptr, len); }
+#ifdef _WIN32
+void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw(ptr, len); }
+#else
+void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw_unsafe(ptr, len); }
+#endif
 
-uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
+uint32_t llama_file::read_u32() { return pimpl->read_u32(); }
 
 void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
 void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }
index 729aac164b838b9ce88e1dc9faa61c9b8522483c..29ce4d246857874b6d05d98ffa5ae0f6614a8ee1 100644 (file)
@@ -24,15 +24,16 @@ struct llama_file {
 
     void seek(size_t offset, int whence) const;
 
-    void read_raw(void * ptr, size_t len) const;
-    void read_raw_at(void * ptr, size_t len, size_t offset) const;
-    void read_aligned_chunk(size_t offset, void * dest, size_t size) const;
-    uint32_t read_u32() const;
+    void read_raw(void * ptr, size_t len);
+    void read_raw_unsafe(void * ptr, size_t len);
+    void read_aligned_chunk(void * dest, size_t size);
+    uint32_t read_u32();
 
     void write_raw(const void * ptr, size_t len) const;
     void write_u32(uint32_t val) const;
 
     size_t read_alignment() const;
+    bool has_direct_io() const;
 private:
     struct impl;
     std::unique_ptr<impl> pimpl;
index 5003b4fbf5301b46d4d3a4fcc63b96bbb6874474..e66febaa021b6b20c8121b38fb8e6f1c6d4c20c1 100644 (file)
@@ -495,6 +495,7 @@ llama_model_loader::llama_model_loader(
         const std::string & fname,
         std::vector<std::string> & splits,
         bool use_mmap,
+        bool use_direct_io,
         bool check_tensors,
         bool no_alloc,
         const llama_model_kv_override * param_overrides_p,
@@ -527,9 +528,17 @@ llama_model_loader::llama_model_loader(
     get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
     llm_kv = LLM_KV(llm_arch_from_string(arch_name));
 
-    files.emplace_back(new llama_file(fname.c_str(), "rb", !use_mmap));
+    files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io));
     contexts.emplace_back(ctx);
 
+    use_direct_io = use_direct_io && files.back()->has_direct_io();
+
+    // Disable mmap in case Direct I/O is enabled and available
+    if (use_direct_io && use_mmap) {
+        use_mmap = false;
+        LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__);
+    }
+
     // Save tensors data offset of the main file.
     // For subsidiary files, `meta` tensor data offset must not be used,
     // so we build a unified tensors index for weights.
@@ -595,7 +604,7 @@ llama_model_loader::llama_model_loader(
                 }
             }
 
-            files.emplace_back(new llama_file(fname_split, "rb", !use_mmap));
+            files.emplace_back(new llama_file(fname_split, "rb", use_direct_io));
             contexts.emplace_back(ctx);
 
             // Save tensors data offset info of the shard.
@@ -739,6 +748,7 @@ llama_model_loader::llama_model_loader(
     }
 
     this->use_mmap = use_mmap;
+    this->use_direct_io = use_direct_io;
     this->check_tensors = check_tensors;
     this->no_alloc = no_alloc;
 }
@@ -1100,7 +1110,8 @@ bool llama_model_loader::load_all_data(
             const auto & file = files.at(weight->idx);
 
             if (ggml_backend_buffer_is_host(cur->buffer)) {
-                file->read_raw_at(cur->data, n_size, weight->offs);
+                file->seek(weight->offs, SEEK_SET);
+                file->read_raw(cur->data, n_size);
                 if (check_tensors) {
                     validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
                         return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
@@ -1132,7 +1143,7 @@ bool llama_model_loader::load_all_data(
                         ggml_backend_event_synchronize(events[buffer_idx]);
 
                         // Read aligned chunk from file
-                        file->read_raw(reinterpret_cast<void *>(ptr_dest_aligned), read_size);
+                        file->read_raw_unsafe(reinterpret_cast<void *>(ptr_dest_aligned), read_size);
 
                         // Calculate actual data portion (excluding alignment padding)
                         uintptr_t ptr_data = ptr_dest_aligned;
@@ -1162,7 +1173,8 @@ bool llama_model_loader::load_all_data(
                     }
                 } else {
                     read_buf.resize(n_size);
-                    file->read_raw_at(read_buf.data(), n_size, weight->offs);
+                    file->seek(weight->offs, SEEK_SET);
+                    file->read_raw(read_buf.data(), n_size);
                     ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
                     if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
                         throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
index d13299ad3f12819174826a8e732eee4f08ac5f80..65953dd3d5a6950c38b713d8649dabd4081f1ee7 100644 (file)
@@ -70,6 +70,7 @@ struct llama_model_loader {
     size_t   n_bytes    = 0;
 
     bool use_mmap = false;
+    bool use_direct_io = false;
     bool check_tensors;
     bool no_alloc;
 
@@ -97,6 +98,7 @@ struct llama_model_loader {
         const std::string & fname,
         std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
         bool use_mmap,
+        bool use_direct_io,
         bool check_tensors,
         bool no_alloc,
         const llama_model_kv_override * param_overrides_p,
index 563823dc35d8eef29b5e7b62589c6500face3a97..ae27c71ce2300ac5cc477aca2f47a9d86625ef90 100644 (file)
@@ -146,6 +146,9 @@ void llama_model_saver::add_kv_from_model() {
     add_kv(LLM_KV_VOCAB_SIZE,                        vocab.n_tokens());
     add_kv(LLM_KV_CONTEXT_LENGTH,                    hparams.n_ctx_train);
     add_kv(LLM_KV_EMBEDDING_LENGTH,                  hparams.n_embd);
+    if (hparams.n_embd_out > 0) {
+        add_kv(LLM_KV_EMBEDDING_LENGTH_OUT,          hparams.n_embd_out);
+    }
     add_kv(LLM_KV_BLOCK_COUNT,                       hparams.n_layer);
     add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT,         hparams.n_layer_dense_lead);
     add_kv(LLM_KV_FEED_FORWARD_LENGTH,               hparams.n_ff_arr, true);
index 5e664c8c574a045c5177ae558c9a9b3629d8a21c..f6cea8f8db4ed926aab90f1e04e08d6c59f6caf6 100644 (file)
@@ -126,6 +126,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_31B_A3_5B:     return "31B.A3.5B";
         case LLM_TYPE_80B_A3B:       return "80B.A3B";
         case LLM_TYPE_100B_A6B:      return "100B.A6B";
+        case LLM_TYPE_102B_A12B:     return "102B.A12B";
         case LLM_TYPE_106B_A12B:     return "106B.A12B";
         case LLM_TYPE_230B_A10B:     return "230B.A10B";
         case LLM_TYPE_235B_A22B:     return "235B.A22B";
@@ -506,6 +507,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     ml.get_key(LLM_KV_CONTEXT_LENGTH,          hparams.n_ctx_train);
     ml.get_key(LLM_KV_EMBEDDING_LENGTH,        hparams.n_embd);
+    ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT,    hparams.n_embd_out, false);
     ml.get_key(LLM_KV_BLOCK_COUNT,             hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT,            hparams.n_expert,        false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT,       hparams.n_expert_used,   false);
@@ -577,6 +579,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
     GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED);
 
+    // TODO: Handle SWA metadata similarly when models start implementing it
     // rope_freq_scale (inverse of the kv) is optional
     float ropescale = 0.0f;
     if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
@@ -585,10 +588,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     }
     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
-    // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
-    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
-    hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
-
     ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
 
     // non-transformer models do not have attention heads
@@ -676,6 +675,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.f_attn_temp_scale       = 0.1f;
                     hparams.f_attn_temp_offset      = 1.0f;
                     hparams.set_swa_pattern(4);   // pattern: 3 chunked - 1 full
+
+                    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                    hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+                    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                 }
 
                 switch (hparams.n_expert) {
@@ -721,6 +724,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 if (hparams.n_swa > 0) {
                     hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                     hparams.set_swa_pattern(4);
+
+                    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                    hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+                    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                 } else {
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
                 }
@@ -1109,6 +1116,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MAINCODER:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 32: type = LLM_TYPE_1B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_QWEN3VL:
             {
                 ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false);
@@ -1234,7 +1249,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 if (found_swa && hparams.n_swa > 0) {
                     uint32_t swa_period = 8;
                     hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                    hparams.rope_freq_scale_train_swa = 1.0f;
                     ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
                     ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
                     hparams.set_swa_pattern(swa_period);
@@ -1300,7 +1314,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.n_swa = 4096; // default value of gemma 2
                 hparams.set_swa_pattern(2);
                 hparams.attn_soft_cap = true;
+                hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
 
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA,          hparams.rope_freq_base_train_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
@@ -1325,8 +1342,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                     hparams.set_swa_pattern(6);
 
-                    hparams.rope_freq_base_train_swa  = 10000.0f;
-                    hparams.rope_freq_scale_train_swa = 1.0f;
+                    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                 } else {
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
                 }
@@ -1356,10 +1372,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.set_swa_pattern(5);
 
                 hparams.n_layer_kv_from_start     = 20;
-                hparams.rope_freq_base_train_swa  = 10000.0f;
-                hparams.rope_freq_scale_train_swa = 1.0f;
                 hparams.f_attention_scale         = 1.0f;
 
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA,          hparams.rope_freq_base_train_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -1375,9 +1390,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.set_swa_pattern(6);
 
                 hparams.causal_attn = false; // embeddings do not use causal attention
-                hparams.rope_freq_base_train_swa = 10000.0f;
-                hparams.rope_freq_scale_train_swa = 1.0f;
 
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
@@ -1516,7 +1530,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.set_swa_pattern(4);
+                hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
 
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA,       hparams.rope_freq_base_train_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE,              hparams.f_logit_scale);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
@@ -1555,6 +1572,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 if (found_swa && hparams.n_swa > 0) {
                     hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                     hparams.set_swa_pattern(4);
+
+                    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                    hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp
+                    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                 } else {
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
                 }
@@ -1682,7 +1703,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,        hparams.n_expert_shared);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,       hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,       hparams.expert_weights_scale, false);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,        hparams.expert_weights_norm, false);
                 ml.get_key(LLM_KV_EXPERT_GATING_FUNC,         hparams.expert_gating_func, false);
                 if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
@@ -1778,6 +1799,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
                 switch (hparams.n_layer) {
                     case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
+                    case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open
                     case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
                     default: type = LLM_TYPE_UNKNOWN;
                 }
@@ -1896,6 +1918,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                     hparams.n_swa = 4096;
                     hparams.set_swa_pattern(4);
+
+                    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                    hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+                    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                 }
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
@@ -2198,6 +2224,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.set_swa_pattern(2);
 
+                hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
+
                 switch (hparams.n_layer) {
                     case 24: type = LLM_TYPE_20B; break;
                     case 36: type = LLM_TYPE_120B; break;
@@ -2242,6 +2272,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.swa_type      = LLAMA_SWA_TYPE_STANDARD;
                     hparams.n_swa         = 4096;
                     hparams.set_swa_pattern(4, true);
+
+                    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+                    hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+                    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
                 } else {
                     hparams.swa_type             = LLAMA_SWA_TYPE_NONE;
                     hparams.n_no_rope_layer_step = hparams.n_layer;
@@ -2406,7 +2440,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
     const bool use_mmap_buffer = true;
 
-    LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false");
+    LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n",
+        __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false");
 
     // build a list of buffer types for the CPU and GPU devices
     pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host);
@@ -2417,6 +2452,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         pimpl->gpu_buft_list.emplace(dev, std::move(buft_list));
     }
 
+    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    if (cpu_dev == nullptr) {
+        throw std::runtime_error(format("%s: no CPU backend found", __func__));
+    }
+
     // calculate the split points
     bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; });
     std::vector<float> splits(n_devices());
@@ -2427,6 +2467,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             size_t total;
             size_t free;
             ggml_backend_dev_memory(dev, &free, &total);
+
+            // devices can return 0 bytes for free and total memory if they do not
+            // have any to report. in this case, we will use the host memory as a fallback
+            // fixes: https://github.com/ggml-org/llama.cpp/issues/18577
+            if (free == 0 && total == 0) {
+                ggml_backend_dev_memory(cpu_dev, &free, &total);
+            }
             splits[i] = free;
         }
     } else {
@@ -2443,10 +2490,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         splits[i] /= split_sum;
     }
 
-    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    if (cpu_dev == nullptr) {
-        throw std::runtime_error(format("%s: no CPU backend found", __func__));
-    }
     const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0);
     const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1);
     auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
@@ -3320,7 +3363,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, TENSOR_NOT_REQUIRED);
 
                         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0);
+
+                        const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i);
+                        ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str());
+                        const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff;
+
+                        GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2);
+                        layer.ffn_up   = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED);
 
                         layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
                         layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
@@ -4776,7 +4826,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    // try to load output.weight, if not found, use token_embd (tied embeddings)
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    if (!output) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -4839,7 +4893,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    // try to load output.weight, if not found, use token_embd (tied embeddings)
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    if (!output) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -5206,9 +5264,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags);
                         layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags);
                         layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags);
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags);
 
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags);
 
@@ -6421,6 +6479,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0);
                         }
                     }
+
+                    // for LFM2-ColBert-350M
+                    dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED);
                 } break;
             case LLM_ARCH_SMALLTHINKER:
                 {
@@ -6702,7 +6763,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         } else {
                             // Linear attention (gated delta net) specific tensors
                             // Create tensors with calculated dimensions
-                            layer.ssm_in         = create_tensor(tn(LLM_TENSOR_SSM_IN,         "weight", i), { n_embd, qkvz_dim }, 0);
+                            // note: ssm_in is used by legacy GGUF
+                            layer.ssm_in         = create_tensor(tn(LLM_TENSOR_SSM_IN,         "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED);
+                            layer.wqkv           = create_tensor(tn(LLM_TENSOR_ATTN_QKV,       "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
+                            layer.wqkv_gate      = create_tensor(tn(LLM_TENSOR_ATTN_GATE,      "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
                             layer.ssm_conv1d     = create_tensor(tn(LLM_TENSOR_SSM_CONV1D,     "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
                             layer.ssm_dt         = create_tensor(tn(LLM_TENSOR_SSM_DT,         "bias",   i), { hparams.ssm_dt_rank }, 0);
                             layer.ssm_a          = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN,             i), { hparams.ssm_dt_rank }, 0);
@@ -6761,6 +6825,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
                     }
                 } break;
+            case LLM_ARCH_MAINCODER:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -7042,6 +7137,10 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type.c_str());
         LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
         LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
+        if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+            LLAMA_LOG_INFO("%s: freq_base_swa    = %.1f\n",   __func__, hparams.rope_freq_base_train_swa);
+            LLAMA_LOG_INFO("%s: freq_scale_swa   = %g\n",     __func__, hparams.rope_freq_scale_train_swa);
+        }
         LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n",   __func__, hparams.rope_yarn_log_mul);
         LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
@@ -7406,6 +7505,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_llama<true>>(*this, params);
             } break;
+        case LLM_ARCH_MAINCODER:
+            {
+                llm = std::make_unique<llm_build_maincoder>(*this, params);
+            } break;
         case LLM_ARCH_DECI:
             {
                 llm = std::make_unique<llm_build_deci>(*this, params);
@@ -7440,7 +7543,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_MODERN_BERT:
             {
-                llm = std::make_unique<llm_build_modern_bert<true>>(*this, params);
+                llm = std::make_unique<llm_build_modern_bert>(*this, params);
             } break;
         case LLM_ARCH_NEO_BERT:
             {
@@ -7850,12 +7953,17 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     // add on pooling layer
     llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
 
+    // add backend sampling layers (if any)
+    llm->build_sampling();
+
     // if the gguf model was converted with --sentence-transformers-dense-modules
     // there will be two additional dense projection layers
     // dense linear projections are applied after pooling
     // TODO: move reranking logic here and generalize
     llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
 
+    llm->res->set_outputs();
+
     return llm->res->get_gf();
 }
 
@@ -7877,6 +7985,7 @@ llama_model_params llama_model_default_params() {
         /*.kv_overrides                =*/ nullptr,
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
+        /*.use_direct_io               =*/ true,
         /*.use_mlock                   =*/ false,
         /*.check_tensors               =*/ false,
         /*.use_extra_bufts             =*/ true,
@@ -7911,6 +8020,10 @@ int32_t llama_model_n_embd_inp(const llama_model * model) {
     return model->hparams.n_embd_inp();
 }
 
+int32_t llama_model_n_embd_out(const llama_model * model) {
+    return model->hparams.get_n_embd_out();
+}
+
 int32_t llama_model_n_layer(const llama_model * model) {
     return model->hparams.n_layer;
 }
@@ -8014,6 +8127,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ERNIE4_5_MOE:
         case LLM_ARCH_MISTRAL3:
         case LLM_ARCH_LLAMA_EMBED:
+        case LLM_ARCH_MAINCODER:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
index f4f44a92b63ad9787e25e431b12b2c3034a1f1ff..79200a0d97a84cadf82657335561ce9861028e73 100644 (file)
@@ -119,6 +119,7 @@ enum llm_type {
     LLM_TYPE_31B_A3_5B,
     LLM_TYPE_80B_A3B, // Qwen3 Next
     LLM_TYPE_100B_A6B,
+    LLM_TYPE_102B_A12B, // Solar-Open
     LLM_TYPE_106B_A12B, // GLM-4.5-Air
     LLM_TYPE_230B_A10B, // Minimax M2
     LLM_TYPE_235B_A22B,
index bc4b05c3b50b05023d8ce2af2bca55dc1c24e79a..048d65a75c21a2a955ff068835ad8e74329c2756 100644 (file)
@@ -596,7 +596,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     }
 
     std::vector<std::string> splits = {};
-    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
     ml.init_mappings(false); // no prefetching
 
     llama_model model(llama_model_default_params());
index f3891453e4b8a2b3f2a34cf7c227c36d6acf9549..11f0394c4ceb15daf5ab21c6c92af33a94b5575e 100644 (file)
@@ -4,6 +4,8 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
+#include "ggml-cpp.h"
+
 #include <array>
 #include <algorithm>
 #include <cassert>
@@ -346,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) {
 
 // llama_sampler API
 
-struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
+struct llama_sampler * llama_sampler_init(
+        struct llama_sampler_i * iface,
+        llama_sampler_context_t ctx) {
     return new llama_sampler {
         /* .iface = */ iface,
         /* .ctx   = */ ctx,
@@ -421,6 +425,202 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }
 
+// empty sampler
+
+struct llama_sampler_empty {
+    const char * name;
+};
+
+static struct llama_sampler * llama_sampler_init_empty(const char * name);
+
+static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_empty *) smpl->ctx;
+    return ctx->name;
+}
+
+static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(token);
+}
+
+static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(cur_p);
+}
+
+static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
+    GGML_UNUSED(smpl);
+}
+
+static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_empty *) smpl->ctx;
+    return llama_sampler_init_empty(ctx->name);
+}
+
+static void llama_sampler_empty_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_empty *) smpl->ctx;
+}
+
+static bool llama_sampler_empty_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(buft);
+
+    return true;
+}
+
+static void llama_sampler_empty_backend_accept(
+        struct llama_sampler * smpl,
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        struct ggml_tensor * selected_token) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(selected_token);
+}
+
+static void llama_sampler_empty_backend_apply(
+          struct llama_sampler      * smpl,
+          struct ggml_context       * ctx,
+          struct ggml_cgraph        * gf,
+          struct llama_sampler_data * data) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(data);
+}
+
+static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
+    GGML_UNUSED(smpl);
+}
+
+static struct llama_sampler_i llama_sampler_empty_i = {
+    /* .name              = */ llama_sampler_empty_name,
+    /* .accept            = */ llama_sampler_empty_accept,
+    /* .apply             = */ llama_sampler_empty_apply,
+    /* .reset             = */ llama_sampler_empty_reset,
+    /* .clone             = */ llama_sampler_empty_clone,
+    /* .free              = */ llama_sampler_empty_free,
+    /* .backend_init      = */ llama_sampler_empty_backend_init,
+    /* .backend_accept    = */ llama_sampler_empty_backend_accept,
+    /* .backend_apply     = */ llama_sampler_empty_backend_apply,
+    /* .backend_set_input = */ llama_sampler_empty_backend_set_input,
+};
+
+struct llama_sampler * llama_sampler_init_empty(const char * name) {
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_empty_i,
+        /* .ctx   = */ new llama_sampler_empty {
+            /* .name = */ name,
+        }
+    );
+}
+
+// common backend sampler functionality
+//
+// +name : means that the sampler is support and will run on the backend
+// -name : means that a ggml operator is not supported by the backend
+//
+struct llama_sampler_backend {
+    llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
+
+    const char * get_name() {
+        if (!is_init) {
+            return name.c_str();
+        }
+
+        if (support) {
+            name_ext = "+" + name;
+        } else {
+            name_ext = "-" + name;
+        }
+
+        return name_ext.c_str();
+    }
+
+    void init(bool support) {
+        GGML_ASSERT(this->is_init == false);
+
+        this->is_init = true;
+        this->support = support;
+    }
+
+private:
+    std::string name;
+    std::string name_ext;
+
+    bool is_init;
+    bool support;
+};
+
+// check if all ggml ops used by the sampler are supported by the backend
+static bool llama_sampler_backend_support(
+        llama_sampler              * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * device = ggml_backend_buft_get_device(buft);
+    if (!device) {
+        // CPU backend always supported
+        return true;
+    }
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    if (!ctx_ptr) {
+        throw std::runtime_error(format("failed to create ggml context"));
+    }
+
+    ggml_context * ctx = ctx_ptr.get();
+
+    const int64_t n = 1024*1024;
+
+    llama_sampler_data data = {
+        /*.logits     = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
+        /*.probs      = */ nullptr,
+        /*.sampled    = */ nullptr,
+        /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
+    };
+
+    ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    smpl->iface->backend_apply(smpl, ctx, gf, &data);
+
+    if (data.logits) {
+        ggml_build_forward_expand(gf, data.logits);
+    }
+
+    if (data.probs) {
+        ggml_build_forward_expand(gf, data.probs);
+    }
+
+    if (data.sampled) {
+        ggml_build_forward_expand(gf, data.sampled);
+    }
+
+    if (data.candidates) {
+        ggml_build_forward_expand(gf, data.candidates);
+    }
+
+    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+        struct ggml_tensor * op = ggml_graph_node(gf, i);
+
+        if (!ggml_backend_dev_supports_op(device, op)) {
+            LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
+                    __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
+
+            return false;
+        }
+    }
+
+    return true;
+}
+
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -432,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
 
     time_meas tm(chain->t_sample_us, chain->params.no_perf);
 
-    for (auto * smpl : chain->samplers) {
-        llama_sampler_accept(smpl, token);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_accept(smpl.ptr, token);
     }
 
     chain->n_sample++;
@@ -444,16 +644,28 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
 
     time_meas tm(chain->t_sample_us, chain->params.no_perf);
 
-    for (auto * smpl : chain->samplers) {
-        llama_sampler_apply(smpl, cur_p);
+    bool is_backend = chain->is_init;
+
+    for (auto & smpl : chain->samplers) {
+        if (is_backend && smpl.is_backend) {
+            continue;
+        }
+
+        is_backend = false;
+
+        if (smpl.ptr->iface->apply == nullptr) {
+            continue;
+        }
+
+        llama_sampler_apply(smpl.ptr, cur_p);
     }
 }
 
 static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-    for (auto * smpl : chain->samplers) {
-        llama_sampler_reset(smpl);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_reset(smpl.ptr);
     }
 }
 
@@ -462,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
 
     auto * result = llama_sampler_chain_init(chain_src->params);
 
-    for (auto * smpl : chain_src->samplers) {
-        llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+    for (const auto & smpl : chain_src->samplers) {
+        llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
     }
 
     return result;
@@ -472,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
 static void llama_sampler_chain_free(struct llama_sampler * smpl) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
 
-    for (auto * smpl : chain->samplers) {
-        llama_sampler_free(smpl);
+    for (auto & smpl : chain->samplers) {
+        llama_sampler_free(smpl.ptr);
     }
 
     delete chain;
 }
 
+static bool llama_sampler_chain_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
+
+    chain->is_init = true;
+
+    bool res = true;
+
+    for (auto & smpl : chain->samplers) {
+        bool res_cur = true;
+
+        // to be able to run a sampler on the backend, it has to:
+        // - have the .backend_init() API implemented
+        // - return true during .backend_init()
+        if (smpl.ptr->iface->backend_init) {
+            if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
+                res_cur = false;
+            }
+        } else {
+            res_cur = false;
+        }
+
+        smpl.is_backend = res_cur;
+
+        res = res && res_cur;
+    }
+
+    return res;
+}
+
+static void llama_sampler_chain_backend_accept(
+        struct llama_sampler * smpl,
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        struct ggml_tensor * selected_token) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_accept) {
+            smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
+        }
+    }
+}
+
+static void llama_sampler_chain_backend_apply(
+          struct llama_sampler      * smpl,
+          struct ggml_context       * ctx,
+          struct ggml_cgraph        * gf,
+          struct llama_sampler_data * data) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
+
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_apply) {
+            smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
+        }
+    }
+}
+
+static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
+    auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+    for (auto & smpl : chain->samplers) {
+        if (!smpl.is_backend) {
+            break;
+        }
+
+        if (smpl.ptr->iface->backend_set_input) {
+            smpl.ptr->iface->backend_set_input(smpl.ptr);
+        }
+    }
+}
+
 static struct llama_sampler_i llama_sampler_chain_i = {
-    /* .name   = */ llama_sampler_chain_name,
-    /* .accept = */ llama_sampler_chain_accept,
-    /* .apply  = */ llama_sampler_chain_apply,
-    /* .reset  = */ llama_sampler_chain_reset,
-    /* .clone  = */ llama_sampler_chain_clone,
-    /* .free   = */ llama_sampler_chain_free,
+    /* .name              = */ llama_sampler_chain_name,
+    /* .accept            = */ llama_sampler_chain_accept,
+    /* .apply             = */ llama_sampler_chain_apply,
+    /* .reset             = */ llama_sampler_chain_reset,
+    /* .clone             = */ llama_sampler_chain_clone,
+    /* .free              = */ llama_sampler_chain_free,
+    /* .backend_init      = */ llama_sampler_chain_backend_init,
+    /* .backend_accept    = */ llama_sampler_chain_backend_accept,
+    /* .backend_apply     = */ llama_sampler_chain_backend_apply,
+    /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
 };
 
 struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
@@ -493,6 +794,7 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
         /* .iface = */ &llama_sampler_chain_i,
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
+            /* .is_init     = */ false,
             /* .samplers    = */ {},
             /* .cur         = */ {},
             /* .t_sample_us = */ 0,
@@ -502,7 +804,16 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
 }
 
 llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
-    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const llama_token   sampled_token  = llama_get_sampled_token_ith     (ctx, idx);
+    const float *       sampled_probs  = llama_get_sampled_probs_ith     (ctx, idx);
+    const float *       sampled_logits = llama_get_sampled_logits_ith    (ctx, idx);
+    const llama_token * sampled_ids    = llama_get_sampled_candidates_ith(ctx, idx);
+
+    // If a backend sampler has already sampled a token, return it.
+    if (sampled_token != LLAMA_TOKEN_NULL) {
+        LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
+        return sampled_token;
+    }
 
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -521,9 +832,26 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     }
 
     auto & cur = *cur_ptr;
-    cur.resize(n_vocab);
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+
+    if (sampled_probs) {
+        const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
+        cur.resize(sampled_probs_count);
+        for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
+        }
+    } else if (sampled_logits) {
+        const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
+        cur.resize(sampled_logits_count);
+        for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
+        }
+    } else {
+        const auto * logits = llama_get_logits_ith(ctx, idx);
+        GGML_ASSERT(logits != nullptr);
+        cur.resize(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
     }
 
     llama_token_data_array cur_p = {
@@ -544,19 +872,35 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     return token;
 }
 
+
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
-    p->samplers.push_back(smpl);
+    p->samplers.push_back({
+        /* .is_backend = */ false,
+        /* .ptr        = */ smpl,
+    });
 }
 
-struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) {
+    if (chain == nullptr) {
+        return nullptr;
+    }
+
+    if (chain->iface != &llama_sampler_chain_i) {
+        return nullptr;
+    }
+
+    if (i == -1) {
+        return chain;
+    }
+
     const auto * p = (const llama_sampler_chain *) chain->ctx;
 
     if (i < 0 || (size_t) i >= p->samplers.size()) {
         return nullptr;
     }
 
-    return p->samplers[i];
+    return p->samplers[i].ptr;
 }
 
 struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
@@ -566,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
         return nullptr;
     }
 
-    auto * result = p->samplers[i];
+    auto * result = p->samplers[i].ptr;
     p->samplers.erase(p->samplers.begin() + i);
 
     return result;
@@ -584,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
 
 // greedy
 
-static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
-    return "greedy";
+struct llama_sampler_greedy : public llama_sampler_backend {
+};
+
+static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
+    return sctx->get_name();
+}
+
+static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_greedy *) smpl->ctx;
+    GGML_UNUSED(ctx);
+}
+
+static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
+    auto * result = llama_sampler_init_greedy();
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_greedy *) result->ctx;
+
+        GGML_UNUSED(ctx);
+        GGML_UNUSED(result_ctx);
+    }
+
+    return result;
+}
+
+static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_greedy *) smpl->ctx;
 }
 
 static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
@@ -597,33 +969,72 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
     }
 }
 
+static bool llama_sampler_greedy_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_greedy_backend_apply(
+        struct llama_sampler      * smpl,
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(gf);
+    GGML_UNUSED(smpl);
+
+    struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
+    ggml_set_name(curl, "greedy_argmax");
+
+    data->sampled = curl;
+}
+
 static struct llama_sampler_i llama_sampler_greedy_i = {
-    /* .name   = */ llama_sampler_greedy_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_greedy_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ nullptr,
-    /* .free   = */ nullptr,
+    /* .name              = */ llama_sampler_greedy_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_greedy_apply,
+    /* .reset             = */ llama_sampler_greedy_reset,
+    /* .clone             = */ llama_sampler_greedy_clone,
+    /* .free              = */ llama_sampler_greedy_free,
+    /* .backend_init      = */ llama_sampler_greedy_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_greedy_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_greedy() {
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_greedy_i,
-        /* .ctx   = */ nullptr
+        /* .ctx   = */ new llama_sampler_greedy {
+            ("greedy"),
+        }
     );
 }
 
 // dist
 
-struct llama_sampler_dist {
+struct llama_sampler_dist : public llama_sampler_backend {
     const uint32_t seed;
           uint32_t seed_cur;
 
     std::mt19937 rng;
+
+    // backend input
+    struct ggml_tensor * inp_uniform;
+
+    ggml_context_ptr        inp_ctx;
+    ggml_backend_buffer_ptr inp_buf;
 };
 
-static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
-    return "dist";
+static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -698,6 +1109,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
 #endif
 }
 
+static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
     auto * result = llama_sampler_init_dist(ctx->seed);
@@ -712,23 +1129,127 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
     return result;
 }
 
-static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
-    auto * ctx = (llama_sampler_dist *) smpl->ctx;
-    ctx->seed_cur = get_rng_seed(ctx->seed);
-    ctx->rng.seed(ctx->seed_cur);
-}
-
 static void llama_sampler_dist_free(struct llama_sampler * smpl) {
     delete (llama_sampler_dist *) smpl->ctx;
 }
 
+static bool llama_sampler_dist_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;
+
+    // allocate inputs
+    {
+        ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+
+        sctx->inp_ctx.reset(ggml_init(params));
+
+        // Create the uniform random scalar input tensor. This will be set by
+        // llama_sampler_dist_backend_set_input after this graph is built.
+        sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
+        ggml_set_name (sctx->inp_uniform, "uniform");
+        ggml_set_input(sctx->inp_uniform);
+
+        // Allocate all tensors from our context to the backend
+        sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
+
+        ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
+    }
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    if (!res) {
+        sctx->inp_ctx.reset(nullptr);
+        sctx->inp_buf.reset(nullptr);
+    }
+
+    return res;
+}
+
+static void llama_sampler_dist_backend_apply(
+        struct llama_sampler      * smpl,
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(gf);
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;
+
+    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
+    ggml_set_name(probs, "dist_probs");
+
+    struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
+    ggml_set_name(cumsum, "dist_cumsum");
+
+    // The uniform tensor has a random value and we subtract this tensor with
+    // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
+    // Recall that each entry in cumsum is the cumulative probability up to that
+    // index so values stay negative while the cumulative total is below the
+    // random value, and become zero/positive once the threshold is crossed.
+    struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
+    ggml_set_name(diff, "dist_cumsum");
+
+    // The ggml_step function produces a tensor where entries are 1 if the
+    // corresponding entry in diff is > 0, and 0 otherwise. So all values up to
+    // the index where the cumulative probability exceeds the random value are 0,
+    // and all entries after that are 1.
+    struct ggml_tensor * mask = ggml_step(ctx, diff);
+    ggml_set_name(mask, "dist_mask");
+
+    // Taking the sum of the mask gives us the sum of elements after the threshold
+    // we are interested in.
+    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
+    ggml_set_name(idxf, "dist_index_f32");
+
+    // Use ggml_scale_bias to scale the index value by -1 and then add the size
+    // of the mask to that value so we get the correct index ((-1 * idxf) + n).
+    struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
+    ggml_set_name(idx, "dist_index_i32");
+
+    // Map back to original vocab ids if a candidates tensor is available.
+    struct ggml_tensor * sampled_token = idx;
+    if (data->candidates != nullptr) {
+        struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates));
+
+        sampled_token = ggml_get_rows(ctx, candidates, idx);
+        ggml_set_name(sampled_token, "dist_sampled_token");
+    }
+
+    data->sampled = sampled_token;
+    data->probs = probs;
+}
+
+static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_dist *) smpl->ctx;
+    GGML_ASSERT(sctx->inp_uniform != nullptr);
+
+    // We sample in double precision and cast to float to match rnd numbers of
+    // llama_dampler_dist which uses double precision (sampling from
+    // std::uniform_real_distribution<double> and
+    // std::uniform_real_distribution<float> with same rng will produce
+    // different sequences).
+    std::uniform_real_distribution<double> dist(0.0f, 1.0f);
+    const float rnd = dist(sctx->rng);
+
+    ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
+}
+
 static struct llama_sampler_i llama_sampler_dist_i = {
-    /* .name   = */ llama_sampler_dist_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_dist_apply,
-    /* .reset  = */ llama_sampler_dist_reset,
-    /* .clone  = */ llama_sampler_dist_clone,
-    /* .free   = */ llama_sampler_dist_free,
+    /* .name              = */ llama_sampler_dist_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_dist_apply,
+    /* .reset             = */ llama_sampler_dist_reset,
+    /* .clone             = */ llama_sampler_dist_clone,
+    /* .free              = */ llama_sampler_dist_free,
+    /* .backend_init      = */ llama_sampler_dist_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_dist_backend_apply,
+    /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
 };
 
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
@@ -736,21 +1257,26 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
-            /* .seed     = */ seed,
-            /* .seed_cur = */ seed_cur,
-            /* .rng      = */ std::mt19937(seed_cur),
+            ("dist"),
+            /* .seed        = */ seed,
+            /* .seed_cur    = */ seed_cur,
+            /* .rng         = */ std::mt19937(seed_cur),
+            /* .inp_uniform = */ nullptr,
+            /* .inp_ctx     = */ nullptr,
+            /* .inp_buf     = */ nullptr,
         }
     );
 }
 
 // top-k
 
-struct llama_sampler_top_k {
+struct llama_sampler_top_k : public llama_sampler_backend {
     const int32_t k;
 };
 
-static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
-    return "top-k";
+static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_top_k *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -767,19 +1293,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
     delete (llama_sampler_top_k *) smpl->ctx;
 }
 
+static bool llama_sampler_top_k_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * sctx = (llama_sampler_top_k *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_top_k_backend_apply(
+        struct llama_sampler      * smpl,
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_top_k *) smpl->ctx;
+
+    struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
+    ggml_set_name(top_k, "top_k");
+
+    if (data->candidates) {
+        struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
+        data->candidates = ggml_get_rows(ctx, candidates_rows, top_k);
+        data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k);
+        ggml_set_name(data->candidates, "top_k_candidates");
+    } else {
+        data->candidates = top_k;
+    }
+
+    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+    struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
+    data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
+    ggml_set_name(top_k_rows, "top_k_rows");
+
+    GGML_UNUSED(gf);
+}
+
 static struct llama_sampler_i llama_sampler_top_k_i = {
-    /* .name   = */ llama_sampler_top_k_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_top_k_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_top_k_clone,
-    /* .free   = */ llama_sampler_top_k_free,
+    /* .name              = */ llama_sampler_top_k_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_top_k_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_top_k_clone,
+    /* .free              = */ llama_sampler_top_k_free,
+    /* .backend_init      = */ llama_sampler_top_k_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_top_k_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
+    const bool is_empty = (k <= 0);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?top-k");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_k_i,
         /* .ctx   = */ new llama_sampler_top_k {
+            ("top-k"),
             /* .k = */ k,
         }
     );
@@ -787,15 +1363,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
 
 // top-p
 
-struct llama_sampler_top_p {
+struct llama_sampler_top_p : public llama_sampler_backend {
     const float  p;
     const size_t min_keep;
 
     std::vector<llama_token_data> buf_sort;
 };
 
-static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
-    return "top-p";
+static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_top_p *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -862,19 +1439,118 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
     delete (llama_sampler_top_p *) smpl->ctx;
 }
 
+static bool llama_sampler_top_p_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * sctx = (llama_sampler_top_p *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_top_p_backend_apply(
+        struct llama_sampler      * smpl,
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_top_p *) smpl->ctx;
+
+    auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
+        GGML_ASSERT(ggml_nrows(a) == 1);
+        struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
+        struct ggml_tensor * a_sorted   = ggml_get_rows(ctx, a_reshaped, b);
+        return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
+    };
+
+    // Get the sorted logits in descending order.
+    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
+    ggml_set_name(sorted_idx, "top_p_sorted_idx");
+
+    // Do the sorting via reshape + get_rows
+    struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
+    ggml_set_name(sorted_logits, "top_p_sorted_logits");
+
+    struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
+    ggml_set_name(softmax, "top_p_softmax");
+
+    // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
+    if (data->candidates) {
+        data->candidates = ggml_sort(data->candidates, sorted_idx);
+    } else {
+        data->candidates = sorted_idx;
+    }
+    ggml_set_name(data->candidates, "top_p_candidates");
+
+    // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
+    struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
+    ggml_set_name(cdf, "top_p_cdf");
+
+    // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
+    struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
+    ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
+
+    struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
+    ggml_set_name(mask, "top_p_mask");
+
+    // Taking the sum of the mask gives us the sum of elements after the threshold
+    // we are interested in.
+    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
+    ggml_set_name(idxf, "top_p_index_f32");
+
+    // prevent out-of-bounds access
+    idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
+
+    // construct ones tensor to set the value in the mask
+    struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
+    ggml_set_name(ones, "top_p_ones");
+
+    // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
+    struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
+
+    mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
+    mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
+
+    // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
+    // top_p_bias = (mask * 1e9f) - 1e9f.
+    // So entries in the mask that we want to discard will become -1e9f, and
+    // others will be 0 (meaning that will not effect the logits).
+    const float large_val = 1e9f;
+    struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+    ggml_set_name(top_p_bias, "top_p_bias");
+
+    data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
+    ggml_set_name(data->logits, "top_p_logits");
+
+    GGML_UNUSED(gf);
+}
+
 static struct llama_sampler_i llama_sampler_top_p_i = {
-    /* .name   = */ llama_sampler_top_p_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_top_p_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_top_p_clone,
-    /* .free   = */ llama_sampler_top_p_free,
+    /* .name              = */ llama_sampler_top_p_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_top_p_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_top_p_clone,
+    /* .free              = */ llama_sampler_top_p_free,
+    /* .backend_init      = */ llama_sampler_top_p_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_top_p_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
+    const bool is_empty = p >= 1.0f;
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?top-p");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_p_i,
         /* .ctx   = */ new llama_sampler_top_p {
+            ("top-p"),
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
             /* .buf_sort = */ {},
@@ -884,13 +1560,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
 
 // min-p
 
-struct llama_sampler_min_p {
+struct llama_sampler_min_p : public llama_sampler_backend {
     const float  p;
     const size_t min_keep;
 };
 
-static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
-    return "min-p";
+static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_min_p *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -956,19 +1633,85 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
     delete (llama_sampler_min_p *) smpl->ctx;
 }
 
+static bool llama_sampler_min_p_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * sctx = (llama_sampler_min_p *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_min_p_backend_apply(
+        struct llama_sampler      * smpl,
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_min_p *) smpl->ctx;
+
+    struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
+    ggml_set_name(max_idx, "max_idx");
+
+    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+    ggml_set_name(logits_rows, "logits_rows");
+
+    struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
+    ggml_set_name(max_logit, "max_logit");
+
+    // Calculate the threshold value.
+    struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
+    ggml_set_name(threshold, "min_p_threshold");
+
+    // Subtract the threshold from logits.
+    struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
+
+    // Create a mask where logits below the threshold are 0 (discard),
+    // and others are 1 (keep).
+    struct ggml_tensor * mask = ggml_step(ctx, sub);
+    ggml_set_name(mask, "min_p_mask");
+
+    // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
+    // min_p_bias = (mask * 1e9f) - 1e9f.
+    // So entries in the mask that we want to discard will become -1e9f, and
+    // others will be 0 (meaning that will not effect the logits).
+    const float large_val = 1e9f;
+    struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+    ggml_set_name(min_p_bias, "min_p_bias");
+
+    // Add the min_p bias to the logits.
+    data->logits = ggml_add(ctx, data->logits, min_p_bias);
+    ggml_set_name(data->logits, "min_p_logits");
+
+    GGML_UNUSED(gf);
+}
+
 static struct llama_sampler_i llama_sampler_min_p_i = {
-    /* .name   = */ llama_sampler_min_p_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_min_p_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_min_p_clone,
-    /* .free   = */ llama_sampler_min_p_free,
+    /* .name              = */ llama_sampler_min_p_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_min_p_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_min_p_clone,
+    /* .free              = */ llama_sampler_min_p_free,
+    /* .backend_init      = */ llama_sampler_min_p_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_min_p_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
+    const bool is_empty = (p <= 0.0f);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?min-p");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_min_p_i,
         /* .ctx   = */ new llama_sampler_min_p {
+            ("min-p"),
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
         }
@@ -1056,15 +1799,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_typical_i = {
-    /* .name   = */ llama_sampler_typical_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_typical_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_typical_clone,
-    /* .free   = */ llama_sampler_typical_free,
+    /* .name              = */ llama_sampler_typical_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_typical_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_typical_clone,
+    /* .free              = */ llama_sampler_typical_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
+    const bool is_empty = (p >= 1.0f);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?typical");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx   = */ new llama_sampler_typical {
@@ -1076,12 +1829,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
 
 // temp
 
-struct llama_sampler_temp {
+struct llama_sampler_temp : public llama_sampler_backend {
     const float temp;
 };
 
-static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
-    return "temp";
+static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1099,19 +1853,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp *) smpl->ctx;
 }
 
+static void llama_sampler_backend_temp_sampling(
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data,
+        float                       temp) {
+    if (temp <= 0.0f) {
+        // Find the most probable token index.
+        struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
+        ggml_set_name(max_idx, "temp_max_idx");
+
+        if (data->candidates) {
+            struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
+            data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx);
+        } else {
+            data->candidates = max_idx;
+        }
+
+        struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+        data->logits = ggml_get_rows(ctx, logits_rows, max_idx);
+
+        return;
+    }
+
+    data->logits = ggml_scale(ctx, data->logits, 1.0f / temp);
+
+    GGML_UNUSED(gf);
+}
+
+static bool llama_sampler_temp_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_temp_backend_apply(
+        struct llama_sampler      * smpl,
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_temp *) smpl->ctx;
+    llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
+}
+
 static struct llama_sampler_i llama_sampler_temp_i = {
-    /* .name   = */ llama_sampler_temp_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_temp_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_temp_clone,
-    /* .free   = */ llama_sampler_temp_free,
+    /* .name              = */ llama_sampler_temp_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_temp_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_temp_clone,
+    /* .free              = */ llama_sampler_temp_free,
+    /* .backend_init      = */ llama_sampler_temp_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_temp_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_temp(float temp) {
+    const bool is_empty = temp == 1.0f;
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?temp");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_i,
         /* .ctx   = */ new llama_sampler_temp {
+            ("temp"),
             /*.temp = */ temp,
         }
     );
@@ -1119,14 +1933,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
 
 // temp-ext
 
-struct llama_sampler_temp_ext {
+struct llama_sampler_temp_ext : public llama_sampler_backend {
     const float temp;
     const float delta;
     const float exponent;
 };
 
-static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
-    return "temp-ext";
+static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -1209,24 +2024,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp_ext *) smpl->ctx;
 }
 
+static bool llama_sampler_temp_ext_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
+
+    const bool res = llama_sampler_backend_support(smpl, buft);
+
+    sctx->init(res);
+
+    return res;
+}
+
+static void llama_sampler_temp_ext_backend_apply(
+        struct llama_sampler      * smpl,
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data) {
+    auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
+
+    // Revert to standard temperature scaling if delta or temp are non-positive.
+    if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
+        llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
+        return;
+    }
+
+    // Calculate min_temp, max_temp, and max_entropy.
+    const float min_temp    = std::max(0.0f, sctx->temp - sctx->delta);
+    const float max_temp    = sctx->temp + sctx->delta;
+    const float max_entropy = logf(data->logits->ne[0]);
+
+    // Calculate the probabilities.
+    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
+    ggml_set_name(probs, "temp_ext_softmax_probs");
+
+    // Clamp probabilities to avoid log(0) which would give -inf
+    struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
+    ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
+
+    // Calculate the entropy, entropy = -Σ(p * log(p)).
+    struct ggml_tensor * log_probs   = ggml_log(ctx, probs_clamped);
+    struct ggml_tensor * p_log_p     = ggml_mul(ctx, probs_clamped, log_probs);
+    struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
+    struct ggml_tensor * entropy     = ggml_scale(ctx, sum_p_log_p, -1.0f);
+    ggml_set_name(log_probs,   "temp_ext_log_probs");
+    ggml_set_name(p_log_p,     "temp_ext_p_log_p");
+    ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
+    ggml_set_name(entropy,     "temp_ext_entropy");
+
+    // Normalize the entropy, norm_entropy = entropy / max_entropy
+    struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
+    ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
+
+    // Calculate the dynamic temperature:
+    // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
+    //
+    // Calculate powf(normalized_entropy, exponent) as
+    // norm_entropy^exponent = exp(exponent * log(norm_entropy))
+    struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
+    struct ggml_tensor * scaled_log       = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
+    struct ggml_tensor * pow_entropy      = ggml_exp(ctx, scaled_log);
+    // With pow_entropy computed we can now compute dyn_temp, scaling by
+    // (max_temp - min_temp) and then adding min_temp.
+    struct ggml_tensor * dyn_temp         = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
+    ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
+    ggml_set_name(scaled_log,       "temp_ext_scaled_log");
+    ggml_set_name(pow_entropy,      "temp_ext_pow_entropy");
+    ggml_set_name(dyn_temp,         "temp_ext_dyn_temp");
+
+    // Scale the logits by the dynamic temperature
+    struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
+    ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
+
+    data->logits = scaled_logits;
+}
+
 static struct llama_sampler_i llama_sampler_temp_ext_i = {
-    /* .name   = */ llama_sampler_temp_ext_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_temp_ext_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_temp_ext_clone,
-    /* .free   = */ llama_sampler_temp_ext_free,
+    /* .name              = */ llama_sampler_temp_ext_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_temp_ext_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_temp_ext_clone,
+    /* .free              = */ llama_sampler_temp_ext_free,
+    /* .backend_init      = */ llama_sampler_temp_ext_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_temp_ext_backend_apply,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-    return llama_sampler_init(
+    const bool is_empty = temp == 1.0f && delta <= 0.0f;
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?temp-ext");
+    }
+
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
+            ("temp-ext"),
             /* .temp     = */ temp,
             /* .delta    = */ delta,
             /* .exponent = */ exponent,
         }
     );
+
+    return res;
 }
 
 // xtc
@@ -1239,7 +2142,7 @@ struct llama_sampler_xtc {
     const uint32_t seed;
     uint32_t       seed_cur;
 
-    std::mt19937    rng;
+    std::mt19937   rng;
 };
 
 static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
@@ -1304,16 +2207,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_xtc_i = {
-    /* .name   = */ llama_sampler_xtc_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sample_xtc_apply,
-    /* .reset  = */ llama_sampler_xtc_reset,
-    /* .clone  = */ llama_sampler_xtc_clone,
-    /* .free   = */ llama_sampler_xtc_free,
+    /* .name              = */ llama_sampler_xtc_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sample_xtc_apply,
+    /* .reset             = */ llama_sampler_xtc_reset,
+    /* .clone             = */ llama_sampler_xtc_clone,
+    /* .free              = */ llama_sampler_xtc_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
-    auto seed_cur = get_rng_seed(seed);
+    const bool is_empty = (p <= 0.0f || t > 0.5f);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?xtc");
+    }
+
+    const auto seed_cur = get_rng_seed(seed);
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
@@ -1412,16 +2326,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_mirostat_i = {
-    /* .name   = */ llama_sampler_mirostat_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_mirostat_apply,
-    /* .reset  = */ llama_sampler_mirostat_reset,
-    /* .clone  = */ llama_sampler_mirostat_clone,
-    /* .free   = */ llama_sampler_mirostat_free,
+    /* .name              = */ llama_sampler_mirostat_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_mirostat_apply,
+    /* .reset             = */ llama_sampler_mirostat_reset,
+    /* .clone             = */ llama_sampler_mirostat_clone,
+    /* .free              = */ llama_sampler_mirostat_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
-    auto seed_cur = get_rng_seed(seed);
+    const auto seed_cur = get_rng_seed(seed);
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
@@ -1511,12 +2430,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
-    /* .name   = */ llama_sampler_mirostat_v2_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_mirostat_v2_apply,
-    /* .reset  = */ llama_sampler_mirostat_v2_reset,
-    /* .clone  = */ llama_sampler_mirostat_v2_clone,
-    /* .free   = */ llama_sampler_mirostat_v2_free,
+    /* .name              = */ llama_sampler_mirostat_v2_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_mirostat_v2_apply,
+    /* .reset             = */ llama_sampler_mirostat_v2_reset,
+    /* .clone             = */ llama_sampler_mirostat_v2_clone,
+    /* .free              = */ llama_sampler_mirostat_v2_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@@ -1628,12 +2551,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_grammar_i = {
-    /* .name   = */ llama_sampler_grammar_name,
-    /* .accept = */ llama_sampler_grammar_accept_impl,
-    /* .apply  = */ llama_sampler_grammar_apply,
-    /* .reset  = */ llama_sampler_grammar_reset,
-    /* .clone  = */ llama_sampler_grammar_clone,
-    /* .free   = */ llama_sampler_grammar_free,
+    /* .name              = */ llama_sampler_grammar_name,
+    /* .accept            = */ llama_sampler_grammar_accept_impl,
+    /* .apply             = */ llama_sampler_grammar_apply,
+    /* .reset             = */ llama_sampler_grammar_reset,
+    /* .clone             = */ llama_sampler_grammar_clone,
+    /* .free              = */ llama_sampler_grammar_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 static struct llama_sampler * llama_sampler_init_grammar_impl(
@@ -1835,12 +2762,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_penalties_i = {
-    /* .name   = */ llama_sampler_penalties_name,
-    /* .accept = */ llama_sampler_penalties_accept,
-    /* .apply  = */ llama_sampler_penalties_apply,
-    /* .reset  = */ llama_sampler_penalties_reset,
-    /* .clone  = */ llama_sampler_penalties_clone,
-    /* .free   = */ llama_sampler_penalties_free,
+    /* .name              = */ llama_sampler_penalties_name,
+    /* .accept            = */ llama_sampler_penalties_accept,
+    /* .apply             = */ llama_sampler_penalties_apply,
+    /* .reset             = */ llama_sampler_penalties_reset,
+    /* .clone             = */ llama_sampler_penalties_clone,
+    /* .free              = */ llama_sampler_penalties_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_penalties(
@@ -1850,6 +2781,12 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
+    const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?penalties");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
@@ -1887,9 +2824,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
     for (size_t i = 0; i < cur_p->size; ++i) {
         // Only count non-negative infinity values
         if (cur_p->data[i].logit != -INFINITY) {
-            if (cur_p->data[i].logit > max) {
-                max = cur_p->data[i].logit;
-            }
+            max = std::max(max, cur_p->data[i].logit);
             logits_sum += cur_p->data[i].logit;
             valid_count++;
         }
@@ -1926,15 +2861,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
-    /* .name   = */ llama_sampler_top_n_sigma_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_top_n_sigma_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_top_n_sigma_clone,
-    /* .free   = */ llama_sampler_top_n_sigma_free,
+    /* .name              = */ llama_sampler_top_n_sigma_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_top_n_sigma_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_top_n_sigma_clone,
+    /* .free              = */ llama_sampler_top_n_sigma_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
+    const bool is_empty = (n <= 0.0f);
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?top-n-sigma");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_n_sigma_i,
         /* .ctx   = */ new llama_sampler_top_n_sigma {
@@ -2256,12 +3201,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_dry_i = {
-    /* .name   = */ llama_sampler_dry_name,
-    /* .accept = */ llama_sampler_dry_accept,
-    /* .apply  = */ llama_sampler_dry_apply,
-    /* .reset  = */ llama_sampler_dry_reset,
-    /* .clone  = */ llama_sampler_dry_clone,
-    /* .free   = */ llama_sampler_dry_free,
+    /* .name              = */ llama_sampler_dry_name,
+    /* .accept            = */ llama_sampler_dry_accept,
+    /* .apply             = */ llama_sampler_dry_apply,
+    /* .reset             = */ llama_sampler_dry_reset,
+    /* .clone             = */ llama_sampler_dry_clone,
+    /* .free              = */ llama_sampler_dry_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
@@ -2272,6 +3221,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
 
     const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
 
+    if (!dry_enabled) {
+        return llama_sampler_init_empty("?dry");
+    }
+
     if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
         // Process sequence breakers
         for (size_t i = 0; i < num_breakers; ++i) {
@@ -2342,16 +3295,23 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
 
 // logit-bias
 
-struct llama_sampler_logit_bias {
+struct llama_sampler_logit_bias : public llama_sampler_backend {
     const int32_t n_vocab;
 
     const std::vector<llama_logit_bias> logit_bias;
 
     std::vector<llama_logit_bias> to_search;
+
+    struct ggml_tensor * inp_logit_bias;
+    struct ggml_tensor * inp_logit_idxs;
+
+    ggml_context_ptr        inp_ctx;
+    ggml_backend_buffer_ptr inp_buf;
 };
 
-static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
-    return "logit-bias";
+static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+    return ctx->get_name();
 }
 
 static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
@@ -2396,25 +3356,123 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
     delete (llama_sampler_logit_bias *) smpl->ctx;
 }
 
+static void llama_sampler_logit_bias_backend_apply(
+        struct llama_sampler      * smpl,
+        struct ggml_context       * ctx,
+        struct ggml_cgraph        * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(gf);
+    GGML_UNUSED(ctx);
+
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+    if (sctx->logit_bias.empty()) {
+        return;
+    }
+
+    ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
+
+    cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
+    cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
+    cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
+
+    data->logits = ggml_add(ctx, data->logits, cur);
+}
+
+static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+    if (sctx->logit_bias.empty()) {
+        return;
+    }
+
+    GGML_ASSERT(sctx->inp_logit_bias != nullptr);
+    GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
+
+    const size_t n = sctx->logit_bias.size();
+
+    std::vector<float>   data_logit_bias(n, 0.0f);
+    std::vector<int32_t> data_logit_idxs(n, 0);
+    for (size_t i = 0; i < n; ++i) {
+        const auto & lb = sctx->logit_bias[i];
+        GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
+        data_logit_bias[i] = lb.bias;
+        data_logit_idxs[i] = lb.token;
+    }
+
+    ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
+    ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
+}
+
+static bool llama_sampler_logit_bias_backend_init(
+        struct llama_sampler       * smpl,
+        ggml_backend_buffer_type_t   buft) {
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+
+    sctx->init(true);
+
+    if (sctx->logit_bias.empty()) {
+        return true;
+    }
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+
+    sctx->inp_ctx.reset(ggml_init(params));
+
+    const size_t n = sctx->logit_bias.size();
+
+    sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
+    ggml_set_name(sctx->inp_logit_bias, "logit_bias");
+    ggml_set_input(sctx->inp_logit_bias);
+
+    sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
+    ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
+    ggml_set_input(sctx->inp_logit_idxs);
+
+    // Allocate all tensors from our context to the backend
+    sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
+
+    ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
+
+    return true;
+}
+
 static struct llama_sampler_i llama_sampler_logit_bias_i = {
-    /* .name   = */ llama_sampler_logit_bias_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_logit_bias_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_logit_bias_clone,
-    /* .free   = */ llama_sampler_logit_bias_free,
+    /* .name              = */ llama_sampler_logit_bias_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_logit_bias_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_logit_bias_clone,
+    /* .free              = */ llama_sampler_logit_bias_free,
+    /* .backend_init      = */ llama_sampler_logit_bias_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_logit_bias_backend_apply,
+    /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
 };
 
 struct llama_sampler * llama_sampler_init_logit_bias(
                          int32_t   n_vocab,
                          int32_t   n_logit_bias,
           const llama_logit_bias * logit_bias) {
+    const bool is_empty = n_logit_bias <= 0;
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?logit-bias");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
-            /* .n_vocab    = */ n_vocab,
-            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
-            /* .to_search  = */ {},
+            ("logit-bias"),
+            /* .n_vocab        = */ n_vocab,
+            /* .logit_bias     = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .to_search      = */ {},
+            /* .inp_logit_bias = */ nullptr,
+            /* .inp_logit_idxs = */ nullptr,
+            /* .inp_ctx        = */ nullptr,
+            /* .inp_buf        = */ nullptr,
         }
     );
 }
@@ -2627,12 +3685,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_infill_i = {
-    /* .name   = */ llama_sampler_infill_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_infill_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_infill_clone,
-    /* .free   = */ llama_sampler_infill_free,
+    /* .name              = */ llama_sampler_infill_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_infill_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_infill_clone,
+    /* .free              = */ llama_sampler_infill_free,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_set_input = */ nullptr,
+    /* .backend_init      = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
@@ -2664,7 +3726,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
     if (smpl->iface == &llama_sampler_chain_i) {
         const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
         for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
-            const uint32_t seed = llama_sampler_get_seed(*it);
+            const uint32_t seed = llama_sampler_get_seed(it->ptr);
             if (seed != LLAMA_DEFAULT_SEED) {
                 return seed;
             }
index 1e3de4e2ec4988ae4d5b41cc420099ca928c7810..6a963c0bb73c1e0fbe8f4872d36760f165a1244d 100644 (file)
@@ -14,7 +14,16 @@ struct llama_grammar;
 struct llama_sampler_chain {
     llama_sampler_chain_params params;
 
-    std::vector<struct llama_sampler *> samplers;
+    // has .backend_init() been called?
+    bool is_init = false;
+
+    struct info {
+        bool is_backend;
+
+        llama_sampler * ptr;
+    };
+
+    std::vector<info> samplers;
 
     // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations
     std::vector<llama_token_data> cur;
@@ -27,9 +36,9 @@ struct llama_sampler_chain {
 };
 
 struct llama_sampler * llama_sampler_init_dry_testing(
-                         int32_t   context_size,
-                           float   dry_multiplier,
-                           float   dry_base,
-                         int32_t   dry_allowed_length,
-                         int32_t   dry_penalty_last_n,
-  const std::vector<std::vector<llama_token>>& seq_breakers);
+        int32_t context_size,
+        float   dry_multiplier,
+        float   dry_base,
+        int32_t dry_allowed_length,
+        int32_t dry_penalty_last_n,
+        const std::vector<std::vector<llama_token>> & seq_breakers);
index cd4092ca0772ad0af6ca01ece0a6f31eddf44a82..a20c6525e46ac35e95c786d08c8e9fe07295fd32 100644 (file)
@@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_YOUTU:
+                regex_exprs = {
+                    "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+",
+                    "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
                 regex_exprs = {
                     "[\r\n]",
@@ -355,6 +361,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
             case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
             case LLAMA_VOCAB_PRE_TYPE_QWEN2:
             case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
+            case LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN:
                 regex_exprs = {
                     // original regex from tokenizer.json
                     // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@@ -1860,6 +1867,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     tokenizer_pre == "deepseek-v3") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
                 clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "youtu") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU;
+                clean_spaces = false;
+                ignore_merges = true;
             } else if (
                     tokenizer_pre == "falcon") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
@@ -2015,6 +2027,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "minimax-m2") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "solar-open") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -2187,6 +2203,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         //       for now, we apply this workaround to find the tokens based on their text
 
         for (const auto & t : token_to_id) {
+            auto & attr = id_to_token[t.second].attr;
+
             // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
             if (special_eot_id == LLAMA_TOKEN_NULL) {
                 if (false
@@ -2202,10 +2220,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<end_of_utterance>" // smoldocling
                    ) {
                     special_eot_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2216,10 +2234,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|eom_id|>"
                         ) {
                     special_eom_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2236,10 +2254,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|code_prefix|>" // GLM-4.5
                         ) {
                     special_fim_pre_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2256,10 +2274,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|code_suffix|>" // GLM-4.5
                         ) {
                     special_fim_suf_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2276,10 +2294,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|code_middle|>" // GLM-4.5
                         ) {
                     special_fim_mid_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2293,10 +2311,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<PAD>"
                         ) {
                     special_fim_pad_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2311,10 +2329,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<reponame>"    // Granite
                         ) {
                     special_fim_rep_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
@@ -2325,15 +2343,41 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|file_sep|>" // Qwen
                         ) {
                     special_fim_sep_id = t.second;
-                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                                 __func__, t.second, t.first.c_str());
-                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                     }
                 }
             }
         }
 
+        // auto-detect unused tokens: e.g. control tokens with the word "unused"
+        // ideally, these tokens should be marked as unused during conversion
+        {
+            uint32_t n_unused = 0;
+
+            for (const auto & t : token_to_id) {
+                auto & attr = id_to_token[t.second].attr;
+
+                if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    continue;
+                }
+
+                if ((attr & LLAMA_TOKEN_ATTR_UNUSED) == 0) {
+                    if (strstr(t.first.c_str(), "unused") != NULL) {
+                        attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_UNUSED);
+                    }
+                }
+
+                if (attr & LLAMA_TOKEN_ATTR_UNUSED) {
+                    n_unused++;
+                }
+            }
+
+            LLAMA_LOG_INFO("%s: %u unused tokens\n", __func__, n_unused);
+        }
+
         // maintain a list of tokens that cause end-of-generation
         // this is currently determined based on the token text, which is obviously not ideal
         // ref: https://github.com/ggerganov/llama.cpp/issues/9606
@@ -2352,12 +2396,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         }
 
         for (const auto & t : token_to_id) {
+            auto & attr = id_to_token[t.second].attr;
+
             if (false
                     || t.first == "<|eot_id|>"
                     || t.first == "<|im_end|>"
                     || t.first == "<|end|>"
                     || t.first == "<|return|>" // o200k_harmony
                     || t.first == "<|call|>"   // o200k_harmony
+                    || t.first == "<|flush|>"  // solar-open
+                    || t.first == "<|calls|>"  // solar-open
                     || t.first == "<end_of_turn>"
                     || t.first == "<|endoftext|>"
                     || t.first == "<|eom_id|>"
@@ -2367,24 +2415,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<end_of_utterance>" // smoldocling
                ) {
                 special_eog_ids.insert(t.second);
-                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                     LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
                             __func__, t.second, t.first.c_str());
-                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL);
                 }
             } else {
-                // token is control, but not marked as EOG -> print a debug log
-                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
-                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
-                            __func__, t.second, t.first.c_str());
+                if (attr & LLAMA_TOKEN_ATTR_CONTROL && !(attr & LLAMA_TOKEN_ATTR_UNUSED)) {
+                    // token is control, but not marked as EOG -> print a debug log
+                    if (special_eog_ids.count(t.second) == 0) {
+                        LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                                __func__, t.second, t.first.c_str());
+                    }
                 }
             }
         }
 
         // @ngxson : quick hack for gpt-oss, always render these tokens
         for (const auto & t : token_to_id) {
+            auto & attr = id_to_token[t.second].attr;
+
             if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") {
-                id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+                attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED);
             }
         }
 
@@ -2404,34 +2456,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
         }
 
-        // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
-        //       we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
+        // TODO: workaround for o200k_harmony and solar-open tokenizer: the "<|end|>" token should not be EOG
+        //       we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens ("<|calls|>" and "<|flush|>" for solar-open),
         //       we remove the "<|end|>" token from the EOG list
         {
             bool has_return = false;
             bool has_call   = false;
             bool has_end    = false;
+            bool has_flush  = false;
 
             llama_token end_id = LLAMA_TOKEN_NULL;
 
             LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
             for (auto tid : special_eog_ids) {
-                LLAMA_LOG_INFO("%s:   - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
+                auto & text = id_to_token[tid].text;
 
-                if (id_to_token[tid].text == "<|return|>") {
+                LLAMA_LOG_INFO("%s:   - %d ('%s')\n", __func__, tid, text.c_str());
+
+                if (text == "<|return|>") {
                     has_return = true;
-                } else if (id_to_token[tid].text == "<|call|>") {
+                } else if (text == "<|call|>" || text == "<|calls|>") {
                     has_call = true;
-                } else if (id_to_token[tid].text == "<|end|>") {
+                } else if (text == "<|flush|>") {
+                    has_flush = true;
+                } else if (text == "<|end|>") {
                     has_end = true;
                     end_id = tid;
                 }
             }
 
-            if (has_return && has_call && has_end) {
+            if ((has_return && has_call && has_end) || (has_call && has_flush && has_end)) {
                 special_eog_ids.erase(end_id);
-                id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
-                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
+
+                auto & attr = id_to_token[end_id].attr;
+                attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED);
+
+                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
             }
         }
     }
index 55f8f3923c95bcecd3033c94dcee1584c3e5060e..2b240a5491bed19bebf2aa83fc5477eafe97eab5 100644 (file)
@@ -51,6 +51,8 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
     LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2      = 41,
     LLAMA_VOCAB_PRE_TYPE_AFMOE           = 42,
+    LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN      = 43,
+    LLAMA_VOCAB_PRE_TYPE_YOUTU           = 44,
 };
 
 struct LLM_KV;
index 76b3acbadb62906dc42474f380ee2b70b2bd99ac..f1096d960e130ce83140157d020b7af191fd8a5c 100644 (file)
@@ -111,8 +111,20 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
         }
     }
     for (size_t i = 0; i < ret.size(); i++) {
-        size_t free, total;
+        size_t free;
+        size_t total;
         ggml_backend_dev_memory(model->devices[i], &free, &total);
+
+        // devices can return 0 bytes for free and total memory if they do not
+        // have any to report. in this case, we will use the host memory as a fallback
+        // fixes: https://github.com/ggml-org/llama.cpp/issues/18577
+        if (free == 0 && total == 0) {
+            ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+            if (cpu_dev == nullptr) {
+                throw std::runtime_error(format("%s: no CPU backend found", __func__));
+            }
+            ggml_backend_dev_memory(cpu_dev, &free, &total);
+        }
         ret[i].free  = free;
         ret[i].total = total;
     }
@@ -147,9 +159,8 @@ class llama_params_fit_exception : public std::runtime_error {
 static void llama_params_fit_impl(
         const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
         float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
-        size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
+        size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
     constexpr int64_t MiB = 1024*1024;
-    const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
     typedef std::vector<llama_device_memory_data> dmds_t;
     const llama_model_params default_mparams = llama_model_default_params();
 
@@ -168,6 +179,12 @@ static void llama_params_fit_impl(
         return;
     }
 
+    std::vector<int64_t> margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
+    margins.reserve(nd);
+    for (size_t id = 0; id < nd; id++) {
+        margins.push_back(margins_s[id]);
+    }
+
     std::vector<std::string> dev_names;
     {
         dev_names.reserve(nd);
@@ -187,9 +204,10 @@ static void llama_params_fit_impl(
 
     int64_t sum_free            = 0;
     int64_t sum_projected_free  = 0;
-    int64_t min_projected_free  = INT64_MAX;
     int64_t sum_projected_used  = 0;
     int64_t sum_projected_model = 0;
+    std::vector<int64_t> projected_free_per_device;
+    projected_free_per_device.reserve(nd);
 
     if (nd > 1) {
         LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -199,45 +217,63 @@ static void llama_params_fit_impl(
 
         const int64_t projected_used = dmd.mb.total();
         const int64_t projected_free = dmd.free - projected_used;
+        projected_free_per_device.push_back(projected_free);
 
         sum_free            += dmd.free;
         sum_projected_used  += projected_used;
         sum_projected_free  += projected_free;
-        min_projected_free   = std::min(min_projected_free, projected_free);
         sum_projected_model += dmd.mb.model;
 
         if (nd > 1) {
-            LLAMA_LOG_INFO("%s:   - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
-                __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB,
-                projected_free >= 0 ? "surplus" : "deficit");
+            LLAMA_LOG_INFO("%s:   - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n",
+                __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB);
         }
     }
     assert(sum_free >= 0 && sum_projected_used >= 0);
     LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
         __func__, sum_projected_used/MiB, sum_free/MiB);
-    if (min_projected_free >= margin) {
-        if (nd == 1) {
+    if (nd == 1) {
+        if (projected_free_per_device[0] >= margins[0]) {
             LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
-                __func__, min_projected_free/MiB, margin/MiB);
+                __func__, projected_free_per_device[0]/MiB, margins[0]/MiB);
+            return;
+        }
+    } else {
+        bool changes_needed = false;
+        for (size_t id = 0; id < nd; id++) {
+            if (projected_free_per_device[id] < margins[id]) {
+                changes_needed = true;
+                break;
+            }
+        }
+        if (!changes_needed) {
+            LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__);
             return;
         }
-        LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n",
-            __func__, min_projected_free/MiB, margin/MiB);
-        return;
     }
 
     // step 2: try reducing memory use by reducing the context size
 
     {
-        int64_t global_surplus = sum_projected_free - int64_t(nd)*margin;
+        int64_t global_surplus = sum_projected_free;
+        for (size_t id = 0; id < nd; id++) {
+            global_surplus -= margins[id];
+        }
         if (global_surplus < 0) {
-            LLAMA_LOG_INFO(nd == 1 ?
-                "%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" :
-                "%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n",
-                __func__, margin/MiB, -global_surplus/MiB);
+            if (nd == 1) {
+                LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n",
+                    __func__, margins[0]/MiB, -global_surplus/MiB);
+            } else {
+                LLAMA_LOG_INFO(
+                    "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n",
+                    __func__, -global_surplus/MiB);
+            }
             if (cparams->n_ctx == 0) {
                 if (hp_nct > n_ctx_min) {
-                    int64_t sum_used_target = sum_free - nd*margin_s;
+                    int64_t sum_used_target = sum_free;
+                    for (size_t id = 0; id < nd; id++) {
+                        sum_used_target -= margins[id];
+                    }
                     if (nd > 1) {
                         // for multiple devices we need to be more conservative in terms of how much context we think can fit:
                         //   - for dense models only whole layers can be assigned to devices
@@ -359,6 +395,11 @@ static void llama_params_fit_impl(
 
         // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE:
         layer_fraction_t overflow_type = LAYER_FRACTION_MOE;
+
+        uint32_t n_full() const {
+            assert(n_layer >= n_part);
+            return n_layer - n_part;
+        }
     };
 
     const size_t ntbo = llama_max_tensor_buft_overrides();
@@ -382,7 +423,7 @@ static void llama_params_fit_impl(
 
         size_t itbo = 0;
         for (size_t id = 0; id < nd; id++) {
-            il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part;
+            il0 += ngl_per_device[id].n_full();
             for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) {
                 if (itbo + 1 >= ntbo) {
                     tensor_buft_overrides[itbo].pattern = nullptr;
@@ -393,7 +434,7 @@ static void llama_params_fit_impl(
                         + std::to_string(ntbo) + " is insufficient for model");
                 }
                 tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
-                tensor_buft_overrides[itbo].buft = overflow_bufts[id];
+                tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type();
                 itbo++;
             }
             il0 += ngl_per_device[id].n_part;
@@ -443,9 +484,9 @@ static void llama_params_fit_impl(
         const dmds_t dmds_cpu_moe = llama_get_device_memory_data(
             path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
 
-        for (const llama_device_memory_data & dmd : dmds_cpu_moe) {
-            global_surplus_cpu_moe += dmd.free;
-            global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin;
+        for (size_t id = 0; id < nd; id++) {
+            global_surplus_cpu_moe += dmds_cpu_moe[id].free;
+            global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id];
         }
 
         if (global_surplus_cpu_moe > 0) {
@@ -464,24 +505,18 @@ static void llama_params_fit_impl(
     std::vector<int64_t> targets; // maximum acceptable memory use per device
     targets.reserve(nd);
     for (size_t id = 0; id < nd; id++) {
-        targets.push_back(dmds_full[id].free - margin);
+        targets.push_back(dmds_full[id].free - margins[id]);
         LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
     }
 
-    std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the partial layers of a device overflow to:
+    std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the first partial layer of a device overflows to:
     overflow_bufts.reserve(nd);
-    for (size_t id = 0; id < nd - 1; ++id) {
-        overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1]));
+    for (size_t id = 0; id < nd; id++) {
+        overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
     }
-    overflow_bufts.push_back(ggml_backend_cpu_buffer_type());
 
     std::vector<ngl_t> ngl_per_device(nd);
     std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts);
-    if (hp_nex > 0) {
-        for (size_t id = 0; id < nd; id++) {
-            ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE;
-        }
-    }
 
     // optimize the number of layers per device using the method of false position:
     //   - ngl_per_device has 0 layers for each device, lower bound
@@ -512,9 +547,6 @@ static void llama_params_fit_impl(
             if (mem_high[id] > targets[id]) {
                 assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
                 uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
-                if (hp_nex > 0 && size_t(id) == nd - 1) {
-                    delta--;
-                }
                 LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
                 while (delta > 1) {
                     uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
@@ -524,7 +556,8 @@ static void llama_params_fit_impl(
                     std::vector<ngl_t> ngl_per_device_test = ngl_per_device;
                     ngl_per_device_test[id].n_layer += step_size;
                     if (hp_nex) {
-                        ngl_per_device_test[id].n_part += step_size;
+                        ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ?
+                            step_size - 1 : step_size; // the first layer is the output layer which must always be full
                     }
                     const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
 
@@ -573,7 +606,7 @@ static void llama_params_fit_impl(
     assert(id_dense_start < nd);
 
     LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__);
-    for (size_t id = 0; id <= id_dense_start; id++) {
+    for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) {
         std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
         for (size_t jd = id_dense_start; jd < nd; jd++) {
             const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1;
@@ -585,12 +618,8 @@ static void llama_params_fit_impl(
         std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts);
 
         if (mem_high[id] > targets[id]) {
-            assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part);
-            assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part);
-            assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
-                   >= ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
-            uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
-                - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
+            assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
+            uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
             while (delta > 1) {
                 uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
                 step_size = std::max(step_size, uint32_t(1));
@@ -606,7 +635,7 @@ static void llama_params_fit_impl(
                     ngl_per_device_test[id].n_layer += n_convert_jd;
                     n_converted_test += n_convert_jd;
 
-                    if (ngl_per_device_test[id_dense_start_test].n_layer > 0) {
+                    if (ngl_per_device_test[id_dense_start_test].n_part > 0) {
                         break;
                     }
                 }
@@ -625,8 +654,8 @@ static void llama_params_fit_impl(
                     LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n",
                         __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high);
                 }
-                delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part)
-                    - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part);
+                assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full());
+                delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full();
             }
         } else {
             ngl_per_device = ngl_per_device_high;
@@ -644,14 +673,19 @@ static void llama_params_fit_impl(
             ngl_per_device_test[id_dense_start_test].n_part--;
             ngl_per_device_test[id].n_layer++;
             ngl_per_device_test[id].n_part++;
-            if (ngl_per_device_test[id_dense_start_test].n_layer == 0) {
+            if (ngl_per_device_test[id_dense_start_test].n_part == 0) {
                 id_dense_start_test++;
             }
             ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
+            std::vector<ggml_backend_buffer_type_t> overflow_bufts_test = overflow_bufts;
+            if (id < nd - 1) {
+                overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]);
+            }
             LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
-            std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
+            std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
             if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
                 ngl_per_device = ngl_per_device_test;
+                overflow_bufts = overflow_bufts_test;
                 mem            = mem_test;
                 id_dense_start = id_dense_start_test;
                 LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n",
@@ -659,9 +693,10 @@ static void llama_params_fit_impl(
 
                 ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
                 LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
-                mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
+                mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
                 if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
                     ngl_per_device = ngl_per_device_test;
+                    overflow_bufts = overflow_bufts_test;
                     mem            = mem_test;
                     id_dense_start = id_dense_start_test;
                     LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n",
@@ -670,9 +705,10 @@ static void llama_params_fit_impl(
             } else {
                 ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
                 LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
-                mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
+                mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test);
                 if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
                     ngl_per_device = ngl_per_device_test;
+                    overflow_bufts = overflow_bufts_test;
                     mem            = mem_test;
                     id_dense_start = id_dense_start_test;
                     LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n",
@@ -687,17 +723,25 @@ static void llama_params_fit_impl(
             __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
     }
 
+    // print info for devices that were not changed during the conversion from dense only to full layers:
+    for (size_t id = id_dense_start + 1; id < nd; id++) {
+        const int64_t projected_margin = dmds_full[id].free - mem[id];
+        LLAMA_LOG_INFO(
+            "%s:   - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n",
+            __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB);
+    }
+
     set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
 }
 
 enum llama_params_fit_status llama_params_fit(
         const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
         float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
-        size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
+        size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) {
     const int64_t t0_us = llama_time_us();
     llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS;
     try {
-        llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
+        llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level);
         LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
     } catch (const llama_params_fit_exception & e) {
         LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());
@@ -713,7 +757,7 @@ enum llama_params_fit_status llama_params_fit(
 
 struct llama_sampler_chain_params llama_sampler_chain_default_params() {
     struct llama_sampler_chain_params result = {
-        /*.no_perf                     =*/ true,
+        /*.no_perf =*/ true,
     };
 
     return result;
@@ -786,7 +830,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
     model.t_start_us = tm.t_start_us;
 
     try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
 
         ml.print_info();
 
index 8b3c8a7b10a5c3e3c5963176bbbc9797385447f4..1c17efb9fa1c9f2e2b45cd073704cd35de80e5c8 100644 (file)
@@ -309,6 +309,7 @@ extern "C" {
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only;      // only load the vocabulary, no weights
         bool use_mmap;        // use mmap if possible
+        bool use_direct_io;   // use direct io, takes precedence over use_mmap
         bool use_mlock;       // force system to keep model in RAM
         bool check_tensors;   // validate model tensor data
         bool use_extra_bufts; // use extra buffer types (used for weight repacking)
@@ -316,6 +317,11 @@ extern "C" {
         bool no_alloc;        // only load metadata and simulate memory allocations
     };
 
+    struct llama_sampler_seq_config {
+        llama_seq_id           seq_id;
+        struct llama_sampler * sampler;
+    };
+
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
     //       https://github.com/ggml-org/llama.cpp/pull/7544
     struct llama_context_params {
@@ -364,6 +370,12 @@ extern "C" {
         bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
                           // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                           // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+
+        // [EXPERIMENTAL]
+        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
+        // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
+        struct llama_sampler_seq_config * samplers;
+        size_t                            n_samplers;
     };
 
     // model quantization parameters
@@ -483,7 +495,7 @@ extern "C" {
                     struct llama_context_params * cparams,
                                           float * tensor_split,          // writable buffer for tensor split, needs at least llama_max_devices elements
         struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
-                                         size_t   margin,                // margin of memory to leave per device in bytes
+                                         size_t * margins,               // margins of memory to leave per device in bytes
                                        uint32_t   n_ctx_min,             // minimum context size to set when trying to reduce memory use
                             enum ggml_log_level   log_level);            // minimum log level to print during fitting, lower levels go to debug log
 
@@ -524,6 +536,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
@@ -992,6 +1005,32 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    //
+    // backend sampling API [EXPERIMENTAL]
+    // note: use only if the llama_context was created with at least one llama_sampler_seq_config
+    //
+
+    // Get the backend sampled token for the ith token.
+    // Returns LLAMA_TOKEN_NULL if no token was sampled.
+    LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled probabilites for the ith token
+    // The index matches llama_get_sampled_token_ith().
+    // Returns NULL if no probabilites were generated.
+    LLAMA_API float *  llama_get_sampled_probs_ith      (struct llama_context * ctx, int32_t i);
+    LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled logits for the ith token
+    // Returns NULL if no logits were sampled.
+    LLAMA_API float *  llama_get_sampled_logits_ith      (struct llama_context * ctx, int32_t i);
+    LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled candidates (token ids) for the ith token
+    // These are needed to map probability/logit indices to vocab token ids.
+    // Returns NULL if no candidates were sampled.
+    LLAMA_API llama_token * llama_get_sampled_candidates_ith      (struct llama_context * ctx, int32_t i);
+    LLAMA_API uint32_t      llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
+
     //
     // Vocab
     //
@@ -1163,11 +1202,16 @@ extern "C" {
     //
     //    llama_sampler_free(smpl);
     //
-    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
-    //
 
     typedef void * llama_sampler_context_t;
 
+    struct llama_sampler_data {
+        struct ggml_tensor * logits;
+        struct ggml_tensor * probs;
+        struct ggml_tensor * sampled;
+        struct ggml_tensor * candidates;
+    };
+
     // user code can implement the interface below in order to create custom llama_sampler
     struct llama_sampler_i {
         const char *           (*name)  (const struct llama_sampler * smpl);                                 // can be NULL
@@ -1177,17 +1221,45 @@ extern "C" {
         struct llama_sampler * (*clone) (const struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
         void                   (*free)  (      struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
 
-        // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
-        //void (*apply_ggml) (struct llama_sampler * smpl, ...);
+        // [EXPERIMENTAL]
+        // backend sampling interface:
+
+        // return true if the backend supports all ops needed by the sampler
+        // note: call once per sampler
+        bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
+
+        // call after .backend_apply()
+        void (*backend_accept)(
+                struct llama_sampler * smpl,
+                struct ggml_context  * ctx,
+                struct ggml_cgraph   * gf,
+                struct ggml_tensor   * selected_token);
+
+        // call after .backend_init()
+        void (*backend_apply)(
+                struct llama_sampler      * smpl,
+                struct ggml_context       * ctx,
+                struct ggml_cgraph        * gf,
+                struct llama_sampler_data * data);
+
+        // called before graph execution to set inputs for the current ubatch
+        void (*backend_set_input)(struct llama_sampler * smpl);
     };
 
     struct llama_sampler {
-        const struct llama_sampler_i * iface;
-        llama_sampler_context_t        ctx;
+        struct llama_sampler_i * iface;
+
+        llama_sampler_context_t ctx;
     };
 
+    // [EXPERIMENTAL]
+    // attach a sampler to the context
+    // note: prefer initializing the context with llama_context_params.samplers when possible
+    // note: changing the samplers of a context can cause graph reallocations and degraded performance
+    LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
+
     // mirror of llama_sampler_i:
-    LLAMA_API struct llama_sampler * llama_sampler_init  (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
+    LLAMA_API struct llama_sampler * llama_sampler_init  (      struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
     LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
     LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
@@ -1203,7 +1275,15 @@ extern "C" {
 
     // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
     LLAMA_API void                   llama_sampler_chain_add(      struct llama_sampler * chain, struct llama_sampler * smpl);
-    LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+
+    // return NULL if:
+    //   - the sampler is NULL
+    //   - the sampler is not a llama_sampler_chain
+    //   - the index is out of bounds, unless i == -1
+    //   - if i == -1, returns the chain itself (can be used to check if the sampler is a chain)
+    LLAMA_API struct llama_sampler * llama_sampler_chain_get(      struct llama_sampler * chain, int32_t i);
+
+    // the total number of samplers in the chain
     LLAMA_API int                    llama_sampler_chain_n  (const struct llama_sampler * chain);
 
     // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
@@ -1212,7 +1292,9 @@ extern "C" {
     // available samplers:
 
     LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
-    LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);
+
+    /// seed == LLAMA_DEFAULT_SEED to use a random seed.
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     /// Setting k <= 0 makes this a noop
index 0192e344ca03ada299b946765e2bc17863470331..6a752a403f6f12beab2446c8df741f7cfcfefe7d 100644 (file)
@@ -22,8 +22,15 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
     const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         ggml_tensor * inpSA = inpL;
 
+        // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
+        const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
+                              (il + 1) % hparams.n_no_rope_layer_step != 0;
+
         // dual attention normalization (pre)
         cur = build_norm(inpL,
                 model.layers[il].attn_norm, NULL,
@@ -56,19 +63,16 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
             cb(Qcur, "Qcur_normed", il);
             cb(Kcur, "Kcur_normed", il);
 
-            // RoPE only for sliding_attention layers
-            const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
-                                ((il + 1) % hparams.n_no_rope_layer_step) != 0;
             if (use_rope) {
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur_rope", il);
 
                 Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur_rope", il);
             }
index 3274fa3b99dd1a2a042b9a102f5597fe77b61d4b..bca0e254fc51bc875abe8a65c224b37b6d89476e 100644 (file)
@@ -142,11 +142,13 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
                     LLM_FFN_GELU, LLM_FFN_SEQ, il);
             cb(cur, "ffn_out", il);
         } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
+            const bool up_contains_gate = !model.layers[il].ffn_gate && model.layers[il].ffn_up->ne[1] != hparams.n_ff();
+            auto type_op = up_contains_gate ? LLM_FFN_GEGLU : LLM_FFN_GELU;
             cur = build_ffn(cur,
-                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
                     model.layers[il].ffn_gate, NULL, NULL,
                     model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL,
-                    model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
+                    type_op, LLM_FFN_PAR, il);
             cb(cur, "ffn_out", il);
         } else {
             cur = build_ffn(cur,
index edf0d1424ceaea43606f59c18b3fe3aa28a33959..0ceae3aaeb550cc812fca7735166b6bbcb1ad7be 100644 (file)
@@ -3,12 +3,14 @@
 llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
-    float         kq_scale    = 1.0f / sqrtf(float(n_embd_head));
+    const float   kq_scale    = 1.0f / sqrtf(float(n_embd_head));
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-    ggml_tensor *inpL, *cur;
+    ggml_tensor * inpL;
+    ggml_tensor * cur;
+
     inpL = build_inp_embd(model.tok_embd);
 
     ggml_tensor * inp_pos = build_inp_pos();
@@ -44,7 +46,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa
         }
 
         ggml_tensor * inpSA = inpL;
-        cur                 = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
 
         // build self attention
         {
index b18aa8c4e6c69e44e53390c73737792eb4bb7b72..9334b5e42634f4905781bd2650d91d92a529dfae 100644 (file)
@@ -21,6 +21,9 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const
 
     for (int il = 0; il < n_layer; ++il) {
         const bool is_swa = hparams.is_swa(il);
+        // UNUSED:
+        // const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        // const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
         // norm
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);
index 49382874baae10ec831ed07c24fd6f7933e1a57f..ca63a62ad1b1e3e6f69e0ba21a05cf8b71288112 100644 (file)
@@ -215,7 +215,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 model.layers[il].ffn_exp_probs_b,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, hparams.expert_weights_norm,
-                true, hparams.expert_weights_scale,
+                hparams.expert_weights_scale, hparams.expert_weights_scale,
                 (llama_expert_gating_func_type) hparams.expert_gating_func,
                 il);
             cb(moe_out, "ffn_moe_out", il);
index 90a98f7abf0fd3aaa8578884d86ea910e675bc71..944c198bf9502459101d966c58c0ef670a470474 100644 (file)
@@ -1,7 +1,5 @@
 #include "models.h"
 
-
-
 llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_k;
@@ -12,10 +10,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model,
     inpL = build_inp_embd(model.tok_embd);
 
     // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
-    if (ubatch.token) {
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-    }
+    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
+    cb(inpL, "inp_scaled", -1);
 
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
index 9cc59a53ee5c1f24aec16fd10bafb1ab536586b1..7a9198193acc50b4f10f5fedae0caa0a2d5bae66 100644 (file)
@@ -19,6 +19,9 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         // norm
         cur = build_norm(inpL,
                 model.layers[il].attn_norm, NULL,
@@ -43,12 +46,12 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll
 
             Qcur = ggml_rope_ext(
                     ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow);
 
             Kcur = ggml_rope_ext(
                     ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow);
 
             cb(Qcur, "Qcur", il);
index ae60ef4790c9796d05d80f7dbd31683997aee200..dec3fc4b8bc3962217d4f5ecb99c6cb3260034bb 100644 (file)
@@ -10,10 +10,9 @@ llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_gr
     inpL = build_inp_embd(model.tok_embd);
 
     // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
-    if (ubatch.token) {
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-    }
+    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
+    cb(inpL, "inp_scaled", -1);
+
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
index a0bdd6a15a123efa44d6f9fbe0813fffd6ff1742..93defbeef9c1d8cd0f477d8530a3342db92d0588 100644 (file)
@@ -1,7 +1,5 @@
 #include "models.h"
 
-
-
 llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model),
@@ -15,10 +13,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
     inpL = build_inp_embd(model.tok_embd);
 
     // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
-    if (ubatch.token) {
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-    }
+    inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f);
+    cb(inpL, "inp_scaled", -1);
+
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
@@ -248,7 +245,7 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
 // equivalent to get_per_layer_inputs() in python code
 // output shape: [n_embd_altup, n_layer, n_tokens]
 ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
-    auto          inp = std::make_unique<llm_graph_input_embd>();
+    auto inp = std::make_unique<llm_graph_input_embd>();
     ggml_tensor * inp_per_layer;
     if (ubatch.token) {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
@@ -258,10 +255,20 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
         inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
         inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
         cb(inp_per_layer, "inp_per_layer_selected", -1);
+        res->add_input(std::move(inp));
     } else {
-        GGML_ABORT("TODO: support embd input");
+        // Vision embedding path: use padding token (ID=0) embedding
+        const int64_t embd_size = model.tok_embd_per_layer->ne[0];  // n_embd_altup * n_layer
+
+        // Extract and dequantize padding token embedding (column 0)
+        ggml_tensor * padding_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
+        ggml_tensor * padding_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size);
+        inp_per_layer = ggml_cpy(ctx0, padding_q, padding_f32);
+
+        // Reshape to [n_embd_altup, n_layer, 1]
+        inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
+        cb(inp_per_layer, "inp_per_layer_vision", -1);
     }
-    res->add_input(std::move(inp));
     return inp_per_layer;
 }
 
@@ -279,7 +286,7 @@ ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp
                                               -1);  // [n_embd_altup, n_layer, n_tokens]
     cb(per_layer_proj, "per_layer_proj", -1);
 
-    inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
+    inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
     inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
     cb(inp_per_layer, "inp_per_layer", -1);
 
index 03f80616821149b797c7c692fc08175d6dabc338..61dd2c179f1dd18d967f91e39670d1030add45c6 100644 (file)
@@ -25,8 +25,12 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         ggml_tensor * inpSA = inpL;
 
+        // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
         const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
                               (il + 1) % hparams.n_no_rope_layer_step != 0;
 
@@ -67,13 +71,13 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_
             if (use_rope) {
                 Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
 
                 Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, rope_factors,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
             } else if (inp_attn_scale) {
diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp
new file mode 100644 (file)
index 0000000..da57308
--- /dev/null
@@ -0,0 +1,117 @@
+#include "models.h"
+
+llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Kcur, "Kcur_normed", il);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   NULL, NULL,
+                model.layers[il].ffn_gate, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
index e2cd4e484f797b45b559a0ca435d737455820c21..6c40f48042b4bffb1d1c47352db138c011f4212f 100644 (file)
@@ -312,6 +312,10 @@ struct llm_build_llama_iswa : public llm_graph_context {
     llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_maincoder : public llm_graph_context {
+    llm_build_maincoder(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_mamba : public llm_graph_context_mamba {
     llm_build_mamba(const llama_model & model, const llm_graph_params & params);
 };
@@ -332,7 +336,6 @@ struct llm_build_mistral3 : public llm_graph_context {
     llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
 };
 
-template <bool iswa>
 struct llm_build_modern_bert : public llm_graph_context {
     llm_build_modern_bert(const llama_model & model, const llm_graph_params & params);
 };
@@ -463,7 +466,8 @@ private:
                 ggml_tensor * cur,
                         int   il);
 
-    ggml_tensor * build_delta_net_chunking(
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
                 ggml_tensor * q,
                 ggml_tensor * k,
                 ggml_tensor * v,
@@ -475,7 +479,8 @@ private:
                 ggml_tensor * diag_mask,
                         int   il);
 
-    ggml_tensor * build_delta_net_autoregressive(
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
                 ggml_tensor * q,
                 ggml_tensor * k,
                 ggml_tensor * v,
@@ -490,6 +495,11 @@ private:
                 ggml_tensor * gate,
                         int   layer);
 
+    // returns pair of qkv, z
+    std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
+                ggml_tensor * input,
+                        int   il);
+
     const llama_model & model;
 };
 
index c7809bdedfa8ee4dd434f645b377b5f260577231..bb12ed819f735122a5ce3068bf26e8c05809e298 100644 (file)
@@ -1,7 +1,6 @@
 #include "models.h"
 
-template <bool iswa>
-llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
@@ -24,13 +23,8 @@ llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, co
     auto * inp_attn = build_attn_inp_no_cache();
 
     for (int il = 0; il < n_layer; ++il) {
-        float freq_base_l  = 0.0f;
-
-        if constexpr (iswa) {
-            freq_base_l = model.get_rope_freq_base(cparams, il);
-        } else {
-            freq_base_l = freq_base;
-        }
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
         cur = inpL;
 
@@ -55,13 +49,13 @@ llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, co
         // RoPE
         Qcur = ggml_rope_ext(
                 ctx0, Qcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                 ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
         Kcur = ggml_rope_ext(
                 ctx0, Kcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                 ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
@@ -120,7 +114,3 @@ llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, co
     res->t_embd = cur;
     ggml_build_forward_expand(gf, cur);
 }
-
-// Explicit template instantiations
-template struct llm_build_modern_bert<false>;
-template struct llm_build_modern_bert<true>;
index 96596709eec5611e14c95ef74564ac2381f8cd88..dbe3ca1851feae5131c45665ff8a0bdd6fdba10e 100644 (file)
@@ -14,6 +14,9 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         ggml_tensor * inpSA = inpL;
 
         // norm
@@ -49,13 +52,13 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,
 
             Qcur = ggml_rope_ext(
                     ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow
                     );
 
             Kcur = ggml_rope_ext(
                     ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                     ext_factor, attn_factor, beta_fast, beta_slow
                     );
 
index 775b3135d3507b5449edaf7897f9d7721249c891..57b6659baf0869d8ef28a5a5f1e273b3cc3e1a05 100644 (file)
@@ -86,7 +86,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_build_forward_expand(gf, cur);
 }
 
-ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
+// utility to get one slice from the third dimension
+// input dim:  [x, y, c, b]
+// output dim: [x, y, 1, b]
+static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
+    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
+        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -187,18 +195,16 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
 
     ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
+    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
 
-    cb(g_cumsum, "g_cumsum", il);
-
-    ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
+    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
     ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
 
     ggml_tensor * gcs_j_broadcast =
         ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
 
     ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-
-    cb(decay_mask, "decay_mask", il);
+    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
     decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
     decay_mask = ggml_exp(ctx0, decay_mask);
@@ -208,8 +214,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
 
     ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
     ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-
-    cb(attn, "attn_pre_solve", il);
+    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
     ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
     ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
@@ -217,8 +222,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
     attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
     attn                     = ggml_add(ctx0, attn, identity);
-
-    cb(attn, "attn_solved", il);
+    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
     v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
 
@@ -226,116 +230,126 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
 
     ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-
-    cb(kbeta_gexp, "kbeta_gexp", il);
+    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
 
     ggml_tensor * k_cumdecay =
         ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
+    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
-    cb(k_cumdecay, "k_cumdecay", il);
+    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
+    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
+    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
+    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
-    ggml_tensor * core_attn_out = nullptr;
-    ggml_tensor * new_state = ggml_dup(ctx0, state);
 
-    cb(new_state, "new_state", il);
+    // vectorized calculation of key_gdiff
+    // improved from the chunked version:
+    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
+    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
+    //   key_gdiff = key * g_diff.unsqueeze(-1)
+    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
 
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        auto chunkify = [=](ggml_tensor * t) {
-            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
+    // get last element in g_cumsum along chunk_size dimension (ne0)
+    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
+    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
+                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
+                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
+    g_last = ggml_cont(ctx0, g_last);
+    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
 
-        auto chunkify_g = [=](ggml_tensor * t) {
-            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
+    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
+    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
+
+    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
+    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
+
+    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
+
+
+    // state to be updated per chunk
+    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
+    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
 
-        ggml_tensor * k_chunk = chunkify(k);
-        ggml_tensor * q_chunk = chunkify(q);
-        ggml_tensor * v_chunk = chunkify(v);
+    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
+    ggml_tensor * core_attn_out = nullptr;
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
+        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
 
-        ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum);
-        ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk));
+        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
+        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
 
-        ggml_tensor * decay_mask_chunk = chunkify(decay_mask);
-        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
+        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
+        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
 
-        ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t);
+        // shape: (chunk_size, 1, H_v * n_seqs)
+        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
 
         // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
-        attn = ggml_mul(ctx0, attn, decay_mask_chunk);
-        attn = ggml_mul(ctx0, attn, diag_mask);
+        // replaced by precomputed attn_kq
+        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
+        cb(attn_chunk, "attn_chunk", il);
 
         ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
 
         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
 
         // v_new = v_i - v_prime
         ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
         ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+        cb(v_new, "v_new_chunk", il);
 
         // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
         ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
         ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
+        cb(attn_inter, "attn_inter_chunk", il);
 
         // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
+        cb(v_attn, "v_attn_chunk", il);
 
         ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
 
-        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
+        core_attn_out = core_attn_out == nullptr
+            ? core_attn_out_chunk
+            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
 
-        // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-        // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-        // key_gdiff = key * g_diff.unsqueeze(-1)
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-        ggml_tensor * g_cum_last =
-            ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3],
-                                        g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3],
-                                        g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1)));
-
-        ggml_tensor * gexp_last =
-            ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
-
-        ggml_tensor * g_cum_last_3d =
-            ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
-
-        ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]);
-
-        ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
-
-        ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-
-        ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk,
-                                        ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
-                                                        g_diff_exp->ne[2] * g_diff_exp->ne[3]));
-
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
+        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
 
+        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
         new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)),
+            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
             ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
     }
 
-    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
-
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
+    // truncate padded tokens
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+            S_v, n_tokens, H_v, n_seqs,
+            ggml_row_size(core_attn_out->type, S_v),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
     cb(output_tokens, "output_tokens", il);
 
-    // flatten output
-    ggml_tensor * flat_output =
-        ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-
-    ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
+    // permute back to (S_v, H_v, n_tokens, n_seqs)
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);
 
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+    return {output_tokens, new_state};
 }
 
-ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -419,11 +433,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
     cb(core_attn_out, "output_tokens", il);
     cb(state, "new_state", il);
 
-    // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
-    ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
-    ggml_tensor * flat_state  = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
-
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+    return {core_attn_out, state};
 }
 
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
@@ -523,6 +533,88 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     return cur;
 }
 
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
+                ggml_tensor * input,
+                        int   il) {
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    if (model.layers[il].wqkv) {
+        // optimized path
+        ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+        cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+        ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+        cb(z, "z", il);
+
+        return { qkv_mixed, z };
+
+    } else {
+        // legacy (slower) path
+        ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
+        cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+
+        int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
+        ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+        // Split mixed_qkvz into query, key, value, z
+        int64_t split_sizes_qkvz[4] = {
+            head_k_dim,                              // query size
+            head_k_dim,                              // key size
+            head_v_dim * num_v_heads / num_k_heads,  // value size
+            head_v_dim * num_v_heads / num_k_heads   // z size
+        };
+
+        ggml_tensor * query =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
+        cb(query, "q", il);
+
+        ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+                                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                        split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
+        cb(key, "k", il);
+
+        ggml_tensor * value =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+                        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                        (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
+        cb(value, "v", il);
+
+        ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+                                    mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                                    (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
+        z = ggml_cont(ctx0, z);
+        cb(z, "z", il);
+
+        // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
+        // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
+        ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+        cb(query_flat, "query_flat", il);
+
+        // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
+        ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+        cb(key_flat, "key_flat", il);
+
+        // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
+        ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+        cb(value_flat, "value_flat", il);
+
+        // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
+        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
+        qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
+        cb(qkv_mixed, "qkv_mixed", il);
+
+        return { qkv_mixed, z };
+    }
+}
+
 ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
         llm_graph_input_rs * inp,
         ggml_tensor *        cur,
@@ -547,15 +639,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     // Input projections
-    ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
-    cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z         = qkvz.second;
 
     ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-    int64_t       qkvz_new_dim        = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
-    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
-
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t       ba_new_dim        = 2 * num_v_heads / num_k_heads;
     ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
@@ -575,8 +665,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
                                    split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
     cb(a, "a", il);
 
-    // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
-    ggml_tensor * beta  = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
+
+    // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
 
     ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
@@ -585,48 +676,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
-    // Split mixed_qkvz into query, key, value, z
-    int64_t split_sizes_qkvz[4] = {
-        head_k_dim,                              // query size
-        head_k_dim,                              // key size
-        head_v_dim * num_v_heads / num_k_heads,  // value size
-        head_v_dim * num_v_heads / num_k_heads   // z size
-    };
-
-    ggml_tensor * query =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
-                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
-    cb(query, "q", il);
-
-    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
-                                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                                     split_sizes_qkvz[0] * sizeof(float));
-    cb(key, "k", il);
-
-    ggml_tensor * value =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
-                     mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                     (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
-    cb(value, "v", il);
-
-    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
-                                   mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-                                   (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
-    cb(z, "z", il);
-
-    // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
-    // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
-    cb(query_flat, "query_flat", il);
-
-    // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
-    cb(key_flat, "key_flat", il);
-
-    // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
-    cb(value_flat, "value_flat", il);
-
     // Get convolution states from cache
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
@@ -637,17 +686,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
 
-    // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
-    ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
-    qkv_mixed               = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
-    cb(qkv_mixed, "qkv_mixed", il);
-
-    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
-    cb(qkv_mixed, "qkv_mixed_permuted", il);
-
-    // Calculate the total conv dimension
-    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
-
     // Calculate convolution kernel size
     ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
@@ -655,6 +693,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
+    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
+    cb(qkv_mixed, "qkv_mixed_permuted", il);
+
     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
     cb(conv_input, "conv_input", il);
 
@@ -677,26 +718,25 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
     cb(conv_output_proper, "conv_output_raw", il);
 
-    conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
-    cb(conv_output_proper, "conv_output_pre_silu", il);
-
     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
     cb(conv_output_silu, "conv_output_silu", il);
 
-    ggml_tensor * conv_qkv_mix =
-        ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
-    cb(conv_qkv_mix, "conv_qkv_mix", il);
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    // Calculate the total conv dimension
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
 
     // Extract the convolved Q, K, V from conv_output
     ggml_tensor * q_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0);
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
     cb(q_conv, "q_conv", il);
     ggml_tensor * k_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
                      head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(k_conv, "k_conv", il);
     ggml_tensor * v_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
+        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
                      2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(v_conv, "v_conv", il);
 
@@ -705,8 +745,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
     v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
-    beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
-
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
     state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
     cb(state, "state_predelta", il);
@@ -738,45 +776,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(v_conv, "v_conv_predelta", il);
 
     // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    ggml_tensor * attn_out;
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
     if (n_seq_tokens == 1) {
         attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
     } else {
         attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
     }
-    cb(attn_out, "attn_out", il);
-
-    // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
-    ggml_tensor * attn_out_1d      = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
-    cb(attn_out_1d, "attn_out_1d", il);
-
-    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-    cb(attn_out_final, "attn_out_reshaped", il);
-
-    // Extract the state part (second part of the concatenated tensor)
-    // State starts after n_tokens elements along dimension 1
-    const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
-
-    ggml_tensor * state_1d =
-        ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
-    cb(state_1d, "state_1d", il);
+    ggml_tensor * output    = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
 
     // Update the recurrent states
     ggml_build_forward_expand(gf,
-                              ggml_cpy(ctx0, state_1d,
+                              ggml_cpy(ctx0, new_state,
                                        ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
                                                     kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
 
-    GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
-
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final =
-        ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
@@ -828,12 +850,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
             shared_gate = ggml_sigmoid(ctx0, shared_gate);
             cb(shared_gate, "shared_expert_gate_sigmoid", il);
 
-            // The gate needs to be broadcast to match the dimensions of ffn_shexp
-            // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
-            // We need to repeat the gate along the feature dimension
-            shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
-            cb(shared_gate, "shared_expert_gate_broadcast", il);
-
             // Apply the gate to the shared expert output
             ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
             cb(ffn_shexp, "ffn_shexp_gated", il);
index 277eec29554940a36d2416f5deeec058283b7de7..4c497ca76f4295b0b85c7f3a0b4806c4bd900a5d 100644 (file)
@@ -26,10 +26,16 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model,
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
+        const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
         ggml_tensor * inpSA  = inpL;
-        ggml_tensor * probs  = nullptr;
 
-        probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL);  // [n_expert, n_tokens]
+        // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous
+        const bool use_rope = hparams.n_no_rope_layer_step == n_layer ||
+                              il % hparams.n_no_rope_layer_step != 0;
+
+        ggml_tensor * probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL);  // [n_expert, n_tokens]
         cb(probs, "ffn_moe_logits", il);
 
         // norm
@@ -52,11 +58,11 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model,
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) {
-                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            if (use_rope) {
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                     ext_factor, attn_factor, beta_fast, beta_slow);
 
-                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                                     ext_factor, attn_factor, beta_fast, beta_slow);
             }
             cb(Qcur, "Qcur", il);
index bb44edfaddffdbf159c26629a815fa8c7a4d8cd8..b47dcbe6198a82cd42ab171925aa1200eb1603e8 100644 (file)
@@ -964,6 +964,11 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
         { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
         { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
         { "\\p{S}", unicode_cpt_flags::SYMBOL },
+        { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter
+        { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter
+        { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter
+        { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter
+        { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter
     };
 
     static const std::map<int, int> k_ucat_cpt = {
@@ -1074,22 +1079,26 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                         continue;
                     }
 
-                    if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() &&
+                    // Match \p{...} Unicode properties of varying lengths
+                    if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() &&
                         regex_expr[i + 1] == 'p' &&
-                        regex_expr[i + 2] == '{' &&
-                        regex_expr[i + 4] == '}') {
-                        const std::string pat = regex_expr.substr(i, 5);
-                        if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
-                            if (!inside) {
-                                regex_expr_collapsed += '[';
+                        regex_expr[i + 2] == '{') {
+                        // Find the closing brace
+                        size_t closing_brace = regex_expr.find('}', i + 3);
+                        if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit
+                            const std::string pat = regex_expr.substr(i, closing_brace - i + 1);
+                            if (k_ucat_enum.find(pat) != k_ucat_enum.end()) {
+                                if (!inside) {
+                                    regex_expr_collapsed += '[';
+                                }
+                                regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
+                                regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
+                                if (!inside) {
+                                    regex_expr_collapsed += ']';
+                                }
+                                i = closing_brace;
+                                continue;
                             }
-                            regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat));
-                            regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat));
-                            if (!inside) {
-                                regex_expr_collapsed += ']';
-                            }
-                            i += 4;
-                            continue;
                         }
                     }