static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
- const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
- if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
- // encoder-only model
- if (llama_encode(ctx, batch) < 0) {
- LOG_ERR("%s : failed to encode\n", __func__);
- }
- } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
- // decoder-only model
- if (llama_decode(ctx, batch) < 0) {
- LOG_ERR("%s : failed to decode\n", __func__);
- }
+ if (llama_encode(ctx, batch) < 0) {
+ LOG_ERR("%s : failed to encode\n", __func__);
}
for (int i = 0; i < batch.n_tokens; i++) {
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);
- // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
- // Stores the encoder output internally for later use by the decoder cross-attention layers.
+ // Process a batch of tokens.
+ // In contrast to llama_decode() - this call does not use the KV cache.
+ // For encoder-decoder contexts, processes the batch using the encoder.
+ // Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success
// < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);
+ // Process a batch of tokens.
+ // Requires KV cache.
+ // For encoder-decoder contexts, processes the batch using the decoder.
// Positive return values do not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
}
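// --- usage sketch (not part of this patch) ----------------------------------
// Minimal example of the llama_encode() contract documented above, assuming a
// context created with embeddings enabled and llama.h / common/common.h
// included; common_batch_add() is the batch helper from common/common.h, and
// the function name encode_for_embeddings is illustrative only.
static bool encode_for_embeddings(llama_context * ctx, const llama_token * tokens, int32_t n_tokens) {
    llama_batch batch = llama_batch_init(n_tokens, 0, 1);

    for (int32_t i = 0; i < n_tokens; ++i) {
        // single sequence (id 0), output requested for every token
        common_batch_add(batch, tokens[i], i, { 0 }, true);
    }

    const bool ok = llama_encode(ctx, batch) >= 0;

    llama_batch_free(batch);
    return ok;
}
// -----------------------------------------------------------------------------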
// reserve worst-case graph
- if (!hparams.vocab_only) {
+ if (!hparams.vocab_only && memory) {
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
t_compute_start_us = ggml_time_us();
}
+ embd_seq.clear();
+
n_queued_tokens += n_tokens;
const int64_t n_embd = hparams.n_embd;
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);
- GGML_ASSERT(embd != nullptr);
-
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
+ GGML_ASSERT(embd != nullptr);
+
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
} break;
} break;
case LLAMA_POOLING_TYPE_RANK:
{
- // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
- // wait for an encoder model that requires this pooling type in order to test it
- // https://github.com/ggerganov/llama.cpp/pull/9510
- GGML_ABORT("RANK pooling not implemented yet");
- }
+ // extract the rerank score - a single float per sequence
+ auto & embd_seq_out = embd_seq;
+
+ for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id[s][0];
+ if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+ continue;
+ }
+ embd_seq_out[seq_id].resize(1);
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ }
+ } break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
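// --- usage sketch (not part of this patch) ----------------------------------
// The single float stored per sequence for LLAMA_POOLING_TYPE_RANK above is
// what a caller later reads back through llama_get_embeddings_seq(); the
// helper name below is illustrative only.
static float get_rerank_score(llama_context * ctx, llama_seq_id seq_id) {
    const float * score = llama_get_embeddings_seq(ctx, seq_id);
    return score != nullptr ? score[0] : 0.0f;
}
// -----------------------------------------------------------------------------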
int llama_context::decode(llama_batch & inp_batch) {
+ if (!memory) {
+ LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+ return encode(inp_batch);
+ }
+
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
llama_memory_i * res;
switch (arch) {
+ case LLM_ARCH_BERT:
+ case LLM_ARCH_JINA_BERT_V2:
+ case LLM_ARCH_NOMIC_BERT:
+ case LLM_ARCH_NOMIC_BERT_MOE:
+ {
+ res = nullptr;
+ } break;
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
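// --- sketch (not part of this patch) -----------------------------------------
// With res = nullptr for the BERT-family archs above, those contexts are
// created without a KV cache. An existing caller that still issues
// llama_decode() on such a model keeps working, because decode() now forwards
// the batch to encode() when there is no memory (see the fallback added in
// llama_context::decode above); the helper name below is illustrative only.
static int32_t embed_with_legacy_call(llama_context * ctx, llama_batch & batch) {
    return llama_decode(ctx, batch); // routed to encode() internally for encoder-only models
}
// ------------------------------------------------------------------------------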
batch.logits + i,
};
- const int ret = llama_decode(ctx, batch_view);
+ int ret = 0;
+
+ if (params_base.embedding || params_base.reranking) {
+ ret = llama_encode(ctx, batch_view);
+ } else {
+ ret = llama_decode(ctx, batch_view);
+ }
+
metrics.on_decoded(slots);
if (ret != 0) {
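// --- sketch (not part of this patch) -----------------------------------------
// The same routing factored into a hypothetical helper: batches for embedding
// or reranking requests go through llama_encode(), which does not use the KV
// cache, while regular completion batches keep using llama_decode().
static int32_t process_batch_view(llama_context * ctx, llama_batch & batch_view, bool embd_or_rerank) {
    return embd_or_rerank ? llama_encode(ctx, batch_view)
                          : llama_decode(ctx, batch_view);
}
// ------------------------------------------------------------------------------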
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
server_task_type type,
json & data,
- std::function<bool()> is_connection_closed,
+ const std::function<bool()> & is_connection_closed,
httplib::Response & res,
oaicompat_type oaicompat) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);