}
} catch (const std::exception & err) {
// fallback to full vocab list
+ GGML_UNUSED(err);
}
return sampling.token_ids_full_vocab.data();
//
uint32_t llama_context::output_reserve(int32_t n_outputs) {
-
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
offset += embd.size * sizeof(float);
- sampling.logits = {nullptr, 0};
- sampling.probs = {nullptr, 0};
- sampling.sampled = {nullptr, 0};
- sampling.candidates = {nullptr, 0};
-
if (has_sampling) {
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
offset += sampling.logits.size * sizeof(float);
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
+ } else {
+ sampling.logits = {nullptr, 0};
+ sampling.probs = {nullptr, 0};
+ sampling.sampled = {nullptr, 0};
+ sampling.candidates = {nullptr, 0};
+
+ sampling.logits_count.clear();
+ sampling.probs_count.clear();
+ sampling.candidates_count.clear();
}
// set all ids as invalid (negative)
}
}
- if (sampling.logits.has_data()) {
+ if (!sampling.samplers.empty()) {
+ assert(sampling.logits.size > 0);
+ assert(sampling.probs.size > 0);
+ assert(sampling.candidates.size > 0);
+ assert(sampling.sampled.size > 0);
+ assert(sampling.logits_count.size() > 0);
+ assert(sampling.probs_count.size() > 0);
+ assert(sampling.candidates_count.size() > 0);
+
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
}
- }
- if (sampling.probs.has_data()) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
}
- }
- if (sampling.candidates.has_data()) {
for (uint64_t k = 0; k < n_vocab; ++k) {
std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
}
- }
-
- if (sampling.sampled.has_data()) {
- std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
- }
-
- if (!sampling.logits_count.empty()) {
- std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
- }
-
- if (!sampling.probs_count.empty()) {
- std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
- }
- if (!sampling.candidates_count.empty()) {
+ std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
+ std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
+ std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
}
}
std::unique_ptr<llama_memory_i> memory;
// decode output (2-dimensional array: [n_outputs][n_vocab])
- struct buffer_view<float> logits = {nullptr, 0};
+ buffer_view<float> logits = {nullptr, 0};
// embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
- struct buffer_view<float> embd = {nullptr, 0};
+ buffer_view<float> embd = {nullptr, 0};
struct sampling_info {
+ // a sequence has sampling enabled iff it has an entry here;
+ // use !samplers.empty() to check whether any samplers are active
std::map<llama_seq_id, llama_sampler *> samplers;
- struct buffer_view<float> logits = {nullptr, 0};
- struct buffer_view<llama_token> sampled = {nullptr, 0};
- struct buffer_view<float> probs = {nullptr, 0};
- struct buffer_view<llama_token> candidates = {nullptr, 0};
+ buffer_view<float> logits = {nullptr, 0};
+ buffer_view<llama_token> sampled = {nullptr, 0};
+ buffer_view<float> probs = {nullptr, 0};
+ buffer_view<llama_token> candidates = {nullptr, 0};
std::vector<uint32_t> logits_count;
std::vector<uint32_t> probs_count;
std::vector<uint32_t> candidates_count;
+ // optimization: cached list of all vocab token ids, returned as the
+ // fallback candidate list so it need not be rebuilt on each call
std::vector<llama_token> token_ids_full_vocab;
};