struct llama_sampler_infill {
const struct llama_vocab * vocab;
+
+ std::vector<char> buf0;
+ std::vector<char> buf1;
};
static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
size_t n_combined = 0; GGML_UNUSED(n_combined);
// combine tokens with common prefix
- for (size_t i = 0; i < cur_p->size; ++i) {
- for (size_t j = 0; j < cur_p->size; ++j) {
- if (cur_p->data[i].logit == -INFINITY) {
+ for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
+ for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
+ if (cur_p->data[i0].logit == -INFINITY) {
break;
}
- if (i == j || cur_p->data[j].logit == -INFINITY) {
+ if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
continue;
}
- if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
- if (cur_p->data[i].p > cur_p->data[j].p) {
- cur_p->data[i].p += cur_p->data[j].p;
- cur_p->data[j].logit = -INFINITY;
- cur_p->data[j].p = 0.0f;
- } else {
- cur_p->data[j].p += cur_p->data[i].p;
- cur_p->data[i].logit = -INFINITY;
- cur_p->data[i].p = 0.0f;
+ int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+ if (len0 < 0) {
+ ctx->buf0.resize(len0);
+ len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+ assert(len0 > 0);
+ }
+
+ int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+ if (len1 < 0) {
+ ctx->buf1.resize(len1);
+ len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+ assert(len1 > 0);
+ }
+
+ // token i0 is a prefix of token i1
+ if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
+ int dst = i0;
+ int src = i1;
+
+ // merge into the token with higher probability
+ if (cur_p->data[i1].p > cur_p->data[i0].p) {
+ std::swap(dst, src);
}
+ cur_p->data[dst].p += cur_p->data[src].p;
+ cur_p->data[src].logit = -INFINITY;
+ cur_p->data[src].p = 0.0f;
+
n_combined++;
}
}
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab,
+ /* .buf0 = */ std::vector<char>(512),
+ /* .buf1 = */ std::vector<char>(512),
},
};
}
return 0;
}
-bool llama_token_is_prefix_impl(
- const struct llama_vocab & vocab,
- llama_token token0,
- llama_token token1) {
- char text_buf_0[128];
- char text_buf_1[128];
-
- const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
- const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
-
- if (len0 <= 0 || len1 <= 0) {
- return false;
- }
-
- return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
-}
-
int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
}
-bool llama_token_is_prefix(
- const struct llama_model * model,
- llama_token token0,
- llama_token token1) {
- return llama_token_is_prefix_impl(model->vocab, token0, token1);
-}
-
int32_t llama_detokenize(
const struct llama_model * model,
const llama_token * tokens,