return it->second;
}
+bool common_speculative_is_compat(llama_context * ctx_tgt) {
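+ // contexts without a memory module cannot support speculative decoding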
+ auto * mem = llama_get_memory(ctx_tgt);
+ if (mem == nullptr) {
+ return false;
+ }
+
+ bool res = true;
+
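+ // start from an empty memory so the probe tokens are placed at positions 0 and 1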
+ llama_memory_clear(mem, true);
+
+ // evaluate 2 dummy tokens so that a partial (suffix) removal can be attempted below
+ std::vector<llama_token> tmp = { 0, 0 };
+
+ int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), (int32_t) tmp.size()));
+ if (ret != 0) {
+ LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
+ res = false;
+ goto done;
+ }
+
+ // try to remove the last token - speculative decoding requires support for
+ // removing a suffix of the sequence (i.e. rolling back rejected draft tokens)
+ if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
+ LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
+ res = false;
+ goto done;
+ }
+
+done:
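+ // leave the context in a clean state, regardless of the outcome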
+ llama_memory_clear(mem, true);
+ llama_synchronize(ctx_tgt);
+
+ return res;
+}
+
// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
+// check if the llama_context is compatible with speculative decoding
+// note: this call clears the memory of the context
+bool common_speculative_is_compat(llama_context * ctx_tgt);
+
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);
slots.clear();
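+ // probe the target context once - if it lacks the required memory operations,
+ // speculative decoding is disabled for all slots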
+ const bool can_spec = common_speculative_is_compat(ctx);
+ if (!can_spec) {
+ SRV_WRN("%s", "speculative decoding not supported by this context\n");
+ }
+
// initialize slots
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding
- {
+ if (can_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {