params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
+ } else if (arg == "--samplers") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.samplers_sequence = parse_samplers_input(argv[i]);
+ } else if (arg == "--sampling-seq") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.samplers_sequence = argv[i];
} else if (arg == "--top-p") {
if (++i >= argc) {
invalid_param = true;
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
+ printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
+ printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
GGML_UNREACHABLE();
}
+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input) {
+ std::string output;
+ // sampler names are written in multiple ways (underscores, dashes, aliases),
+ // so map every accepted spelling to the single-character code used internally
+ std::unordered_map<std::string, char> samplers_symbols {
+ {"top_k", 'k'},
+ {"top-k", 'k'},
+ {"top_p", 'p'},
+ {"top-p", 'p'},
+ {"nucleus", 'p'},
+ {"typical_p", 'y'},
+ {"typical-p", 'y'},
+ {"typical", 'y'},
+ {"min_p", 'm'},
+ {"min-p", 'm'},
+ {"tfs_z", 'f'},
+ {"tfs-z", 'f'},
+ {"tfs", 'f'},
+ {"temp", 't'},
+ {"temperature",'t'}
+ };
+ // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
+ size_t separator = input.find(';');
+ while (separator != input.npos) {
+ std::string name = input.substr(0, separator);
+ input = input.substr(separator + 1);
+ separator = input.find(';');
+
+ if (samplers_symbols.find(name) != samplers_symbols.end()) {
+ output += samplers_symbols[name];
+ }
+ }
+ if (samplers_symbols.find(input) != samplers_symbols.end()) {
+ output += samplers_symbols[input];
+ }
+ return output;
+}
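
For illustration, a minimal self-contained check of the mapping (a hypothetical test harness, not part of the patch; it relies only on the declaration added to common.h below):

#include <cassert>
#include <string>

std::string parse_samplers_input(std::string input); // declared in common.h by this patch

int main() {
    // full names map to the single-letter codes, in the given order
    assert(parse_samplers_input("top_k;tfs;typical;top_p;min_p;temp") == "kfypmt");
    // dashed spellings and aliases are accepted as well
    assert(parse_samplers_input("temperature;top-k;nucleus") == "tkp");
    // unrecognized names are silently dropped
    assert(parse_samplers_input("top_k;bogus;temp") == "kt");
    return 0;
}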
+
//
// Model utils
//
void process_escapes(std::string& input);
+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input);
+
//
// Model utils
//
return std::string(result);
}
+std::string llama_sampling_order_print(const llama_sampling_params & params) {
+ std::string result = "CFG -> Penalties ";
+ if (params.mirostat == 0) {
+ for (auto s : params.samplers_sequence) {
+ switch (s) {
+ case 'k': result += "-> top_k "; break;
+ case 'f': result += "-> tfs_z "; break;
+ case 'y': result += "-> typical_p "; break;
+ case 'p': result += "-> top_p "; break;
+ case 'm': result += "-> min_p "; break;
+ case 't': result += "-> temp "; break;
+ default : break;
+ }
+ }
+ } else {
+ result += "-> mirostat ";
+ }
+
+ return result;
+}
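
As a sketch of the resulting string (assuming the default samplers_sequence "kfypmt" and mirostat disabled):

llama_sampling_params sp; // defaults: samplers_sequence = "kfypmt", mirostat = 0
printf("%s\n", llama_sampling_order_print(sp).c_str());
// prints: CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temp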
+
+// no reason to expose this function in the header
+void sampler_queue(
+ struct llama_context * ctx_main,
+ const llama_sampling_params & params,
+ llama_token_data_array & cur_p,
+ size_t & min_keep) {
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+ const float temp = params.temp;
+ const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
+ const float top_p = params.top_p;
+ const float min_p = params.min_p;
+ const float tfs_z = params.tfs_z;
+ const float typical_p = params.typical_p;
+ const std::string & samplers_sequence = params.samplers_sequence;
+
+ for (auto s : samplers_sequence) {
+ switch (s) {
+ case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
+ case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
+ case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
+ case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
+ case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
+ case 't': llama_sample_temp (ctx_main, &cur_p, temp); break;
+ default : break;
+ }
+ }
+}
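
To make the dispatch concrete, a caller-side sketch (hypothetical sequence; the ctx_main and cur_p setup is elided): with params.samplers_sequence = "tk", the loop above is equivalent to

// temperature first, then top-k; the remaining samplers are skipped entirely
llama_sample_temp (ctx_main, &cur_p, params.temp);
llama_sample_top_k(ctx_main, &cur_p, top_k, min_keep);
// i.e. the sequence string acts as a tiny program: one character per sampler,
// executed left to right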
+
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const float temp = params.temp;
- const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
- const float top_p = params.top_p;
- const float min_p = params.min_p;
- const float tfs_z = params.tfs_z;
- const float typical_p = params.typical_p;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
// temperature sampling
size_t min_keep = std::max(1, params.n_probs);
- llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
- llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
- llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
- llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
- llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep);
- llama_sample_temp (ctx_main, &cur_p, temp);
+ sampler_queue(ctx_main, params, cur_p, min_keep);
id = llama_sample_token(ctx_main, &cur_p);
// sampling parameters
typedef struct llama_sampling_params {
- int32_t n_prev = 64; // number of previous tokens to remember
- int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
- int32_t top_k = 40; // <= 0 to use vocab size
- float top_p = 0.95f; // 1.0 = disabled
- float min_p = 0.05f; // 0.0 = disabled
- float tfs_z = 1.00f; // 1.0 = disabled
- float typical_p = 1.00f; // 1.0 = disabled
- float temp = 0.80f; // 1.0 = disabled
- int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
- float penalty_repeat = 1.10f; // 1.0 = disabled
- float penalty_freq = 0.00f; // 0.0 = disabled
- float penalty_present = 0.00f; // 0.0 = disabled
- int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
- float mirostat_tau = 5.00f; // target entropy
- float mirostat_eta = 0.10f; // learning rate
- bool penalize_nl = true; // consider newlines as a repeatable token
+ int32_t n_prev = 64; // number of previous tokens to remember
+ int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ int32_t top_k = 40; // <= 0 to use vocab size
+ float top_p = 0.95f; // 1.0 = disabled
+ float min_p = 0.05f; // 0.0 = disabled
+ float tfs_z = 1.00f; // 1.0 = disabled
+ float typical_p = 1.00f; // 1.0 = disabled
+ float temp = 0.80f; // 1.0 = disabled
+ int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float penalty_repeat = 1.10f; // 1.0 = disabled
+ float penalty_freq = 0.00f; // 0.0 = disabled
+ float penalty_present = 0.00f; // 0.0 = disabled
+ int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ float mirostat_tau = 5.00f; // target entropy
+ float mirostat_eta = 0.10f; // learning rate
+ bool penalize_nl = true; // consider newlines as a repeatable token
+ std::string samplers_sequence = "kfypmt"; // top_k, tail_free, typical_p, top_p, min_p, temp
std::string grammar; // optional BNF-like grammar to constrain sampling
// Print sampling parameters into a string
std::string llama_sampling_print(const llama_sampling_params & params);
+// Print sampling order into a string
+std::string llama_sampling_order_print(const llama_sampling_params & params);
+
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
}
}
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+ LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");