From: Georgi Gerganov Date: Tue, 3 Feb 2026 20:16:16 +0000 (+0200) Subject: sampling : delegate input allocation to the scheduler (#19266) X-Git-Tag: upstream/0.0.8067~140 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=faa1bc26eed0e56539ac5e7a751607d76f01f10f;p=pkg%2Fggml%2Fsources%2Fllama.cpp sampling : delegate input allocation to the scheduler (#19266) * sampling : delegate input allocation to the scheduler * graph : compute backend samplers only if needed --- diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 203852d0f..95b207e9e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1027,11 +1027,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { llama_sampler_chain_n(sampler) > 0; if (sampler && can_offload) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); - auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); - if (host_buft) { - buft = host_buft; - } + auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); sampler->iface->backend_init(sampler, buft); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 16d42c4ae..54f4ed248 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2419,6 +2419,9 @@ void llm_graph_context::build_sampling() const { return; } + std::array outs; + outs[0] = res->t_logits; + auto inp_sampling = std::make_unique(samplers); res->add_input(std::move(inp_sampling)); @@ -2439,14 +2442,14 @@ void llm_graph_context::build_sampling() const { // add a dummy row of logits // this trick makes the graph static, regardless of which samplers are activated // this is important in order to minimize graph reallocations - // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); for (const auto & [seq_id, sampler] : samplers) { const auto it = seq_to_logit_row.find(seq_id); // inactive samplers always work on the first row - const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; + const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; + const int i_out = it != seq_to_logit_row.end() ? 1 : 0; ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); @@ -2463,22 +2466,26 @@ void llm_graph_context::build_sampling() const { if (data.sampled != nullptr) { res->t_sampled[seq_id] = data.sampled; - ggml_build_forward_expand(gf, data.sampled); + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.probs != nullptr) { res->t_sampled_probs[seq_id] = data.probs; - ggml_build_forward_expand(gf, data.probs); + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.logits != nullptr) { res->t_sampled_logits[seq_id] = data.logits; - ggml_build_forward_expand(gf, data.logits); + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.candidates != nullptr) { res->t_candidates[seq_id] = data.candidates; - ggml_build_forward_expand(gf, data.candidates); + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5dde51306..515d6c163 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend { std::mt19937 rng; - // backend input - struct ggml_tensor * inp_uniform; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; + ggml_tensor * inp_uniform; }; static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { @@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init( ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_dist *) smpl->ctx; - // allocate inputs - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_dist_backend_set_input after this graph is built. - sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); - ggml_set_name (sctx->inp_uniform, "uniform"); - ggml_set_input(sctx->inp_uniform); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - } - const bool res = llama_sampler_backend_support(smpl, buft); sctx->init(res); - if (!res) { - sctx->inp_ctx.reset(nullptr); - sctx->inp_buf.reset(nullptr); - } - return res; } @@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply( struct ggml_cgraph * gf, struct llama_sampler_data * data) { GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); ggml_set_name(probs, "dist_probs"); @@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply( static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); // We sample in double precision and cast to float to match rnd numbers of @@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), /* .inp_uniform = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } @@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend { struct ggml_tensor * inp_logit_bias; struct ggml_tensor * inp_logit_idxs; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; }; static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { @@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply( return; } + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); @@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm static bool llama_sampler_logit_bias_backend_init( struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; sctx->init(true); @@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init( return true; } - ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - const size_t n = sctx->logit_bias.size(); - - sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); - ggml_set_name(sctx->inp_logit_bias, "logit_bias"); - ggml_set_input(sctx->inp_logit_bias); - - sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); - ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); - ggml_set_input(sctx->inp_logit_idxs); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - return true; } @@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias( /* .to_search = */ {}, /* .inp_logit_bias = */ nullptr, /* .inp_logit_idxs = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); }