COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
COMMON_SAMPLER_TYPE_XTC = 7,
-
+ COMMON_SAMPLER_TYPE_INFILL = 8,
};
// dimensionality reduction methods, used by cvector-generator
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC,
- COMMON_SAMPLER_TYPE_TEMPERATURE
+ COMMON_SAMPLER_TYPE_TEMPERATURE,
};
std::string grammar; // optional BNF-like grammar to constrain sampling
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
+ case COMMON_SAMPLER_TYPE_INFILL:
+ llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+ break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x';
+ case COMMON_SAMPLER_TYPE_INFILL: return 'i';
default : return '?';
}
}
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
+ case COMMON_SAMPLER_TYPE_INFILL: return "infill";
default : return "";
}
}
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
+ { "infill", COMMON_SAMPLER_TYPE_INFILL },
};
// since sampler names can be written in multiple ways
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
- { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
+ { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
+ { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
};
std::vector<common_sampler_type> samplers;
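For reference, a quick sketch of how the new name and character mappings above can be exercised (illustrative only, not part of this diff). The helper names common_sampler_types_from_names() and common_sampler_types_from_chars() and their signatures are assumed from common/sampling.h:

// illustrative sketch; the helpers below are assumed, not introduced by this change
std::vector<common_sampler_type> by_name = common_sampler_types_from_names({ "top_k", "top_p", "infill" }, /*allow_alt_names=*/true);
std::vector<common_sampler_type> by_chr  = common_sampler_types_from_chars("kpi"); // 'i' -> COMMON_SAMPLER_TYPE_INFILL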
if (!params.ctx_shift){
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break;
- } else {
- if (params.n_predict == -2) {
- LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
- break;
- }
+ }
+
+ if (params.n_predict == -2) {
+ LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+ break;
+ }
- const int n_left = n_past - params.n_keep;
- const int n_discard = n_left/2;
+ const int n_left = n_past - params.n_keep;
+ const int n_discard = n_left/2;
- LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
- n_past, n_left, n_ctx, params.n_keep, n_discard);
+ LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+ n_past, n_left, n_ctx, params.n_keep, n_discard);
- llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
- llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+ llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
+ llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
- n_past -= n_discard;
+ n_past -= n_discard;
- LOG_DBG("after swap: n_past = %d\n", n_past);
+ LOG_DBG("after swap: n_past = %d\n", n_past);
- LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+ LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
- LOG_DBG("clear session path\n");
- path_session.clear();
- }
+ LOG_DBG("clear session path\n");
+ path_session.clear();
}
} else {
// context extension via Self-Extend
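For clarity, a worked example of the context-shift arithmetic in the block above, using hypothetical values (not part of the diff):

// hypothetical: n_past = 512, params.n_keep = 64
// n_left    = 512 - 64 = 448
// n_discard = 448 / 2  = 224
//
// llama_kv_cache_seq_rm (ctx, 0,  64, 288);       // drop cached tokens [64, 288)
// llama_kv_cache_seq_add(ctx, 0, 288, 512, -224); // shift tokens [288, 512) down to [64, 288)
//
// n_past becomes 512 - 224 = 288, i.e. roughly half of the non-kept context is freed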
int32_t lstrip,
bool special);
+ // check if token0 is a prefix of token1
+ LLAMA_API bool llama_token_is_prefix(
+ const struct llama_model * model,
+ llama_token token0,
+ llama_token token1);
+
/// @details Convert the provided tokens into text (inverse of llama_tokenize()).
/// @param text The char pointer must be large enough to hold the resulting text.
/// @return Returns the number of chars/bytes on success, no more than text_len_max.
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);
+ // this sampler is meant to be used for fill-in-the-middle infilling
+ // it's supposed to be used after top_k + top_p sampling
+ //
+ // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+ // 2. combine probs of tokens that have the same prefix
+ //
+ // example:
+ //
+ // - before:
+ // "hel": 0.5
+ // "hell": 0.2
+ // "hello": 0.1
+ // "dummy": 0.1
+ //
+ // - after:
+ // "hel": 0.8
+ // "dummy": 0.1
+ //
+ // 3. discard non-EOG tokens with low prob
+ // 4. if no tokens are left -> pick EOT
+ //
+ LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
};
}
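A minimal usage sketch for the new public sampler (illustrative, not part of this diff). It assumes an already loaded model and ctx and otherwise uses only existing llama.h sampler-chain calls; per the comment above, the infill sampler is placed after top-k/top-p:

// illustrative sketch: a chain ending in the new infill sampler
struct llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());

llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.95f, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_infill(model)); // added in this change
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

const llama_token id = llama_sampler_sample(smpl, ctx, -1); // sample from the last decoded logits

llama_sampler_free(smpl);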
+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+ const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+ return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+ llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+ }
+
+ float p_txt_sum = 0.0f;
+ float p_eog_sum = 0.0f;
+
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+ p_eog_sum += cur_p->data[i].p;
+ } else {
+ p_txt_sum += cur_p->data[i].p;
+ }
+ }
+
+ const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+ LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+ if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+ LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+ // keep just the EOG tokens
+ const auto size_org = cur_p->size;
+
+ cur_p->size = 0;
+
+ float p_sum = 0.0f;
+
+ for (size_t i = 0; i < size_org; ++i) {
+ if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+ p_sum += cur_p->data[i].p;
+
+ cur_p->data[cur_p->size++] = cur_p->data[i];
+ }
+ }
+
+ // normalize probs
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].p /= p_sum;
+ }
+
+ return;
+ }
+
+ size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+ // combine tokens with common prefix
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ for (size_t j = 0; j < cur_p->size; ++j) {
+ if (cur_p->data[i].logit == -INFINITY) {
+ break;
+ }
+
+ if (i == j || cur_p->data[j].logit == -INFINITY) {
+ continue;
+ }
+
+ if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+ if (cur_p->data[i].p > cur_p->data[j].p) {
+ cur_p->data[i].p += cur_p->data[j].p;
+ cur_p->data[j].logit = -INFINITY;
+ cur_p->data[j].p = 0.0f;
+ } else {
+ cur_p->data[j].p += cur_p->data[i].p;
+ cur_p->data[i].logit = -INFINITY;
+ cur_p->data[i].p = 0.0f;
+ }
+
+ n_combined++;
+ }
+ }
+ }
+
+ size_t n_non_eog = 0;
+
+ size_t size_org = cur_p->size;
+
+ float p_sum = 0.0f;
+ float thold = 0.2f;
+
+ cur_p->size = 0;
+
+ LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+ for (size_t i = 0; i < size_org; ++i) {
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+ if (cur_p->data[i].p < thold && !is_eog) {
+ continue;
+ }
+
+ if (!is_eog) {
+ ++n_non_eog;
+ }
+
+ p_sum += cur_p->data[i].p;
+
+ // keep this token
+ cur_p->data[cur_p->size++] = cur_p->data[i];
+ }
+
+ LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+ // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+ if (n_non_eog == 0) {
+ cur_p->size = 1;
+ cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+ cur_p->data[0].logit = 1.0f;
+
+ return;
+ }
+
+ // normalize probs
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].p /= p_sum;
+
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+ }
+
+ size_org = cur_p->size;
+ p_sum = 0.0f;
+ thold = 1.0/(n_non_eog + 1);
+
+ cur_p->size = 0;
+
+ LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+ for (size_t i = 0; i < size_org; ++i) {
+ const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+ if (cur_p->data[i].p < thold && !is_eog) {
+ continue;
+ }
+
+ p_sum += cur_p->data[i].p;
+
+ cur_p->data[cur_p->size++] = cur_p->data[i];
+ }
+
+ // normalize probs
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].p /= p_sum;
+
+ LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+ }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+ return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+ delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+ /* .name = */ llama_sampler_infill_name,
+ /* .accept = */ nullptr,
+ /* .apply = */ llama_sampler_infill_apply,
+ /* .reset = */ nullptr,
+ /* .clone = */ llama_sampler_infill_clone,
+ /* .free = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+ const struct llama_vocab & vocab) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_infill_i,
+ /* .ctx = */ new llama_sampler_infill {
+ /* .vocab = */ &vocab,
+ },
+ };
+}
+
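To make the two thresholding passes above concrete, a worked example with hypothetical post-softmax probabilities (after prefix combining, with the EOG shortcut not triggered):

// hypothetical candidates: "hel" 0.70, "foo" 0.22, "bar" 0.05, <EOG> 0.03
//
// EOG shortcut: 3 * 0.03 * 4 = 0.36 is not greater than p_txt_sum = 0.97 -> keep sampling text
//
// pass 1 (thold = 0.2):             "bar" dropped (0.05 < 0.2), n_non_eog = 2
//                                   renormalize over 0.95 -> "hel" 0.737, "foo" 0.232, <EOG> 0.032
//
// pass 2 (thold = 1/(2+1) ~ 0.333): "foo" dropped (0.232 < 0.333), <EOG> is always kept
//                                   renormalize over 0.768 -> "hel" 0.959, <EOG> 0.041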
// utils
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
#include "llama-grammar.h"
-#include <unordered_map>
-
struct llama_vocab;
struct llama_grammar;
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+ const struct llama_vocab & vocab);
return 0;
}
+bool llama_token_is_prefix_impl(
+ const struct llama_vocab & vocab,
+ llama_token token0,
+ llama_token token1) {
+ char text_buf_0[128];
+ char text_buf_1[128];
+
+ const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
+ const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
+
+ if (len0 <= 0 || len1 <= 0) {
+ return false;
+ }
+
+ return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
+}
+
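The prefix check above compares the detokenized pieces byte-wise; a few hypothetical cases for illustration:

// hypothetical pieces, illustration only:
//   piece(token0) = "hel",   piece(token1) = "hello" -> true  (len0 <= len1 and the first 3 bytes match)
//   piece(token0) = "hello", piece(token1) = "hel"   -> false (len0 > len1)
//   piece(token0) = "hel",   piece(token1) = "world" -> false (bytes differ)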
int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
id special_cls_id = LLAMA_TOKEN_NULL;
id special_mask_id = LLAMA_TOKEN_NULL;
- id linefeed_id = 13;
+ id linefeed_id = 13;
// fim tokens
id special_fim_pre_id = LLAMA_TOKEN_NULL;
int32_t lstrip,
bool special);
+// check if token0 is a prefix of token1
+bool llama_token_is_prefix_impl(
+ const struct llama_vocab & vocab,
+ llama_token token0,
+ llama_token token1);
+
int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
}
+bool llama_token_is_prefix(
+ const struct llama_model * model,
+ llama_token token0,
+ llama_token token1) {
+ return llama_token_is_prefix_impl(model->vocab, token0, token1);
+}
+
int32_t llama_detokenize(
const struct llama_model * model,
const llama_token * tokens,
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
}
+struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
+ return llama_sampler_init_infill_impl(model->vocab);
+}
+
//
// model split
//