const auto & logits = lctx.logits;
const auto * plogits = logits.data() + logits.size() - n_logits;
+ if (temp <= 0) {
+ // select the token with the highest logit directly
+ float max_logit = plogits[0];
+ llama_vocab::id max_id = 0;
+
+ for (int i = 1; i < n_logits; ++i) {
+ if (plogits[i] > max_logit) {
+ max_logit = plogits[i];
+ max_id = i;
+ }
+ }
+ return max_id;
+ }
+
std::vector<std::pair<float, llama_vocab::id>> logits_id;
logits_id.reserve(n_logits);