// llama_sampler API
+struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
+ return new llama_sampler {
+ /* .iface = */ iface,
+ /* .ctx = */ ctx,
+ };
+}
+
const char * llama_sampler_name(const struct llama_sampler * smpl) {
if (!smpl->iface) {
return "(null)";
}
if (smpl->ctx == nullptr) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ smpl->iface,
- /* .ctx = */ nullptr,
- };
+ /* .ctx = */ nullptr
+ );
}
GGML_ABORT("the sampler does not support cloning");
};
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain {
/* .params = */ params,
/* .samplers = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
- },
- };
+ }
+ );
}
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
};
struct llama_sampler * llama_sampler_init_greedy() {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_greedy_i,
- /* .ctx = */ nullptr,
- };
+ /* .ctx = */ nullptr
+ );
}
// dist
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
- },
- };
+ }
+ );
}
// softmax
};
struct llama_sampler * llama_sampler_init_softmax() {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_softmax_i,
- /* .ctx = */ nullptr,
- };
+ /* .ctx = */ nullptr
+ );
}
// top-k
};
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k {
/* .k = */ k,
- },
- };
+ }
+ );
}
// top-p
};
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_top_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
- },
- };
+ }
+ );
}
// min-p
};
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_min_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
- },
- };
+ }
+ );
}
// typical
};
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical {
/* .p = */ p,
/* .min_keep = */ min_keep,
- },
- };
+ }
+ );
}
// temp
};
struct llama_sampler * llama_sampler_init_temp(float temp) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_temp {
/*.temp = */ temp,
- },
- };
+ }
+ );
}
// temp-ext
};
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp,
/* .delta = */ delta,
/* .exponent = */ exponent,
- },
- };
+ }
+ );
}
// xtc
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
- },
- };
+ }
+ );
}
// mirostat
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed);
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat {
/* .n_vocab = */ n_vocab,
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
- },
- };
+ }
+ );
}
// mirostat v2
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
auto seed_cur = get_rng_seed(seed);
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed,
/* .eta = */ eta,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
- },
- };
+ }
+ );
}
// grammar
};
}
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_grammar_i,
- /* .ctx = */ ctx,
- };
+ /* .ctx = */ ctx
+ );
}
struct llama_sampler * llama_sampler_init_grammar(
float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0);
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n,
/* .penalty_present = */ penalty_present,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {},
- },
- };
+ }
+ );
}
// DRY
}
}
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ context_size,
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
/* .dry_max_token_repeat = */ {},
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
- },
- };
+ }
+ );
}
// wrapper for test-sampling.cpp
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias {
/* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {},
- },
- };
+ }
+ );
}
// infill
};
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
- return new llama_sampler {
+ return llama_sampler_init(
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ vocab,
/* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512),
- },
- };
+ }
+ );
}
// utils