server_response --> server_routes
```
-TODO: mention about how batching is handled by `server_slot`
+### Batching
+
+The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch.
+
+Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all.
+
+Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`.
+
+Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration.
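+
+The loop described above can be sketched roughly as follows. All names here (`SlotSketch`, `decode_batch`, `update_slots_sketch`) are illustrative stand-ins rather than the real `server_slot` / llama.cpp API, and the slot-compatibility check (e.g., matching LoRA adapters) is reduced to a comment:
+
+```cpp
+// illustrative stand-ins only -- not the real server_slot / llama.cpp types
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+struct SlotSketch {
+    std::vector<int32_t> prompt_tokens;      // prompt tokens for this request
+    size_t               n_prompt_done = 0;  // how many prompt tokens were already decoded
+    int32_t              last_sampled  = -1; // token sampled after the previous decode
+    bool                 active        = false;
+};
+
+// stand-in for llama_decode(): runs inference over the shared batch
+static void decode_batch(const std::vector<int32_t> & batch) {
+    (void) batch; // this is the expensive part of update_slots()
+}
+
+// one iteration of the update loop: fill a single batch shared by all slots
+static void update_slots_sketch(std::vector<SlotSketch> & slots, size_t batch_capacity) {
+    std::vector<int32_t> batch;
+    batch.reserve(batch_capacity);
+
+    for (auto & slot : slots) {
+        if (batch.size() >= batch_capacity) {
+            break; // batch is full; remaining slots continue on the next iteration
+        }
+        if (!slot.active) {
+            continue;
+        }
+        // NOTE: the real server also checks that slots are compatible
+        // (e.g., same LoRA adapter) before putting them into the same batch
+        if (slot.n_prompt_done < slot.prompt_tokens.size()) {
+            // prompt phase: add as many pending prompt tokens as fit
+            while (slot.n_prompt_done < slot.prompt_tokens.size() && batch.size() < batch_capacity) {
+                batch.push_back(slot.prompt_tokens[slot.n_prompt_done++]);
+            }
+        } else {
+            // generation phase: add the token sampled after the previous decode
+            batch.push_back(slot.last_sampled);
+        }
+    }
+
+    if (!batch.empty()) {
+        decode_batch(batch);
+        // ...then sample the next token (or read embeddings) for each slot whose
+        // prompt is fully processed; slots with prompt tokens left simply wait
+        // for the next update_slots() iteration
+    }
+}
+```
+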
### Thread Management
- All JSON formatting and chat template logic must stay in the HTTP layer.
- Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible.
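+
+As a hypothetical illustration of this guideline, the HTTP layer can parse the request body into a plain C++ struct right at the boundary, so only native types travel further down. The struct and field names below are illustrative, not the real task parameters; `nlohmann::json` is assumed because the server already works with `json` objects:
+
+```cpp
+// hypothetical example -- not the real server_task parameter parsing
+#include <nlohmann/json.hpp>
+#include <string>
+
+struct CompletionParamsSketch {       // native representation used past the HTTP boundary
+    std::string prompt;
+    int         n_predict   = -1;
+    float       temperature = 0.8f;
+};
+
+static CompletionParamsSketch parse_completion_request(const nlohmann::json & body) {
+    CompletionParamsSketch params;
+    params.prompt      = body.at("prompt").get<std::string>();
+    params.n_predict   = body.value("n_predict",   params.n_predict);
+    params.temperature = body.value("temperature", params.temperature);
+    return params;                    // downstream code never sees the raw JSON
+}
+```
+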
+### Example trace of a request
+
+Here is an example trace of an API request for text completion (a self-contained sketch of this flow follows the list):
+
+- A request arrives at the HTTP layer.
+- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked.
+- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`.
+- `server_res_generator` creates a new `task_result_state` for each task:
+ - `task_result_state` stays in the HTTP layer and tracks the current state of the response (e.g., partially parsed tool calls or thinking messages).
+ - `server_task` is moved into `server_queue` inside `server_context`.
+- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`).
+- `update_slots()` processes the task as described in the "Batching" section above.
+- Results may be sent using `send_partial_response` or `send_final_response`; each creates a new `server_task_result` and pushes it to the response queue.
+- At the same time, `server_res_generator` listens to the response queue and retrieves this response.
+- Because the result object itself is stateless, `server_res_generator` calls `response->update()` to apply the tracked state to it.
+- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer.
+
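+The same flow, condensed into a stand-alone sketch (every type and queue here is a simplified stand-in for `server_task`, `task_result_state`, `server_queue`, and the response queue, not the real classes):
+
+```cpp
+// illustrative stand-ins only -- not the real llama-server classes
+#include <queue>
+#include <string>
+#include <utility>
+
+struct TaskSketch   { int id = 0; std::string prompt; };                      // moved into the server context
+struct ResultSketch { int id = 0; std::string chunk; bool is_final = false; };
+struct StateSketch  { std::string generated_text; };                          // stays in the HTTP layer
+
+static std::queue<TaskSketch>   task_queue;    // HTTP layer -> server context (cf. server_queue)
+static std::queue<ResultSketch> result_queue;  // server context -> HTTP layer (cf. response queue)
+
+// sketch of an HTTP handler (cf. handle_completions_impl)
+static std::string handle_completion_sketch(const std::string & prompt) {
+    // 1. parse the request and build a task (the real handler parses JSON here)
+    TaskSketch task;
+    task.id     = 1;
+    task.prompt = prompt;
+
+    // 2. keep the per-request state in the HTTP layer (cf. task_result_state),
+    //    then move the task into the server's queue (cf. server_queue)
+    StateSketch state;
+    task_queue.push(std::move(task));
+
+    // 3. the server context picks the task up on its own thread, batches and
+    //    decodes it (see the "Batching" section), and pushes results onto the
+    //    response queue (cf. send_partial_response / send_final_response);
+    //    in this stand-alone sketch nothing fills the queue, so the loop is a no-op
+
+    // 4. the HTTP layer drains the response queue and applies each stateless
+    //    result to its local state (cf. response->update())
+    while (!result_queue.empty()) {
+        ResultSketch res = result_queue.front();
+        result_queue.pop();
+        state.generated_text += res.chunk;
+        if (res.is_final) {
+            break;
+        }
+    }
+
+    // 5. serialize the accumulated state for the HTTP response (cf. to_json())
+    return state.generated_text;
+}
+```
+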
### Testing
`llama-server` includes an automated test suite based on `pytest`.
int get_slot_n_ctx() {
return slots.back().n_ctx;
}
+
+ server_response_reader get_response_reader() {
+ return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS);
+ }
};
//
return impl->ctx;
}
-std::pair<server_queue &, server_response &> server_context::get_queues() {
- return { impl->queue_tasks, impl->queue_results };
+server_response_reader server_context::get_response_reader() {
+ return impl->get_response_reader();
}
struct server_res_generator : server_http_res {
server_response_reader rd;
server_res_generator(server_context_impl & ctx_server)
- : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {}
+ : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {}
void ok(const json & response_data) {
status = 200;
data = safe_json_to_str(response_data);
try {
std::vector<server_task> tasks;
- // tracking generation state and partial tool calls
- std::vector<task_result_state> states;
-
const auto & prompt = data.at("prompt");
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
}
tasks.reserve(inputs.size());
- states.reserve(inputs.size());
int idx = 0;
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = ctx_server.model_name;
- states.push_back(task.params.oaicompat_chat_syntax);
if (task.params.n_cmpl > 1) {
task.n_children = task.params.n_cmpl - 1;
task.id,
ctx_server.queue_tasks.get_new_id(),
idx++);
- states.push_back(child.params.oaicompat_chat_syntax);
tasks.push_back(std::move(child));
}
}
tasks.push_back(std::move(task));
}
- rd.set_states(std::move(states));
rd.post_tasks(std::move(tasks));
} catch (const std::exception & e) {
res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
// create and queue the task
json responses = json::array();
- server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+ server_response_reader rd = ctx_server.get_response_reader();
{
std::vector<server_task> tasks;
tasks.reserve(documents.size());
// create and queue the task
json responses = json::array();
- server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS);
+ server_response_reader rd = ctx_server.get_response_reader();
{
std::vector<server_task> tasks;
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
// get the underlaying llama_context
llama_context * get_llama_context() const;
- // get the underlaying queue_tasks and queue_results
- // used by CLI application
- std::pair<server_queue &, server_response &> get_queues();
+ // get a new response reader, used by CLI application
+ server_response_reader get_response_reader();
};
// server_response_reader
//
-void server_response_reader::set_states(std::vector<task_result_state> && states) {
- this->states = std::move(states);
+void server_response_reader::post_task(server_task && task) {
+ GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+ id_tasks.insert(task.id);
+ states.push_back(task.create_state());
+ queue_results.add_waiting_task_id(task.id);
+ queue_tasks.post(std::move(task));
}
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
+ GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
id_tasks = server_task::get_list_id(tasks);
+ states.reserve(tasks.size());
+ for (size_t i = 0; i < tasks.size(); i++) {
+ states.push_back(tasks[i].create_state());
+ }
queue_results.add_waiting_tasks(tasks);
queue_tasks.post(std::move(tasks));
}
std::vector<task_result_state> states;
// should_stop function will be called each polling_interval_seconds
- server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
- : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
+ server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
+ : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
~server_response_reader() {
stop();
}
- void set_states(std::vector<task_result_state> && states);
+ void post_task(server_task && task);
void post_tasks(std::vector<server_task> && tasks);
bool has_next() const;
json to_json(bool only_metrics = false) const;
};
+// struct for tracking the state of a task (e.g., for streaming)
+struct task_result_state {
+ // tracking diffs for partial tool calls
+ std::vector<common_chat_msg_diff> diffs;
+ common_chat_syntax oaicompat_chat_syntax;
+ common_chat_msg chat_msg;
+ std::string generated_text; // append new chunks of generated text here
+ std::vector<std::string> generated_tool_call_ids;
+
+ task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
+ : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
+
+ // parse partial tool calls and update the internal state
+ common_chat_msg update_chat_msg(
+ const std::string & text_added,
+ bool is_partial,
+ std::vector<common_chat_msg_diff> & diffs);
+};
+
struct server_task {
int id = -1; // to be filled by server_queue
int index = -1; // used when there are multiple prompts (batch request)
copy.tokens = tokens.clone();
return copy;
}
+
+ // the task will be moved into queue, then onto slots
+ // however, the state must be kept by caller (e.g., HTTP thread)
+ task_result_state create_state() const {
+ return task_result_state(params.oaicompat_chat_syntax);
+ }
};
struct result_timings {
json to_json() const;
};
-// struct for tracking the state of a task (e.g., for streaming)
-struct task_result_state {
- // tracking diffs for partial tool calls
- std::vector<common_chat_msg_diff> diffs;
- common_chat_syntax oaicompat_chat_syntax;
- common_chat_msg chat_msg;
- std::string generated_text; // append new chunks of generated text here
- std::vector<std::string> generated_tool_call_ids;
-
- task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
- : oaicompat_chat_syntax(oaicompat_chat_syntax) {}
-
- // parse partial tool calls and update the internal state
- common_chat_msg update_chat_msg(
- const std::string & text_added,
- bool is_partial,
- std::vector<common_chat_msg_diff> & diffs);
-};
-
struct server_task_result {
int id = -1;
int id_slot = -1;