From: klosax Date: Wed, 24 May 2023 07:27:36 +0000 (+0200) Subject: mpt : utf-8 support, perplexity testing, repeat penalty sampling (#184) X-Git-Tag: upstream/0.0.1642~1447 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=9276285583ee322973064e8aaff0400eb7321604;p=pkg%2Fggml%2Fsources%2Fggml mpt : utf-8 support, perplexity testing, repeat penalty sampling (#184) * common: utf-8 decoder, reverted gpt_toeknize utf-8 convert * Update common.h * main: decode utf-8 tokens on load * mpt import: bug fix * common: style fixes * common: style fix * Update common.h * common: revert gpt_tokenize utf-8 convert * Update common.cpp * Update common.cpp * Update common.cpp * Add perplexity to mpt * Update CMakeLists: perplexity * mpt-perplexity: fixes * Update perplexity.cpp * common: add sampling with repeat penalty * mpt-main: add repeat penalty sampling, add commandline parameters * Update common.h * mpt-main: style fixes * Update perplexity.cpp * Delete perplexity.cpp * mpt: move perplexity to main * mpt: move perplexity to main * common.cpp: Use codecvt utf-8 converter * main.cpp: Use codecvt utf-8 converter * mpt : code style changes --------- Co-authored-by: Georgi Gerganov --- diff --git a/examples/common.cpp b/examples/common.cpp index eaaaa606..bb98d2c6 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -8,6 +8,8 @@ #include #include #include +#include +#include #ifndef M_PI #define M_PI 3.14159265358979323846 @@ -212,38 +214,22 @@ void gpt_vocab::add_special_token(const std::string & token) { special_tokens.push_back(token); } -static void append_utf8(char32_t ch, std::string & out) { - if (ch <= 0x7F) { - out.push_back(static_cast(ch)); - } else if (ch <= 0x7FF) { - out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else if (ch <= 0xFFFF) { - out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); - out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else if (ch <= 0x10FFFF) { - out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); - out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); - out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else { - printf("Invalid Unicode code point\n"); - } +std::string convert_to_utf8(const std::wstring & input) { + std::wstring_convert> converter; + return converter.to_bytes(input); +} + +std::wstring convert_to_wstring(const std::string & input) { + std::wstring_convert> converter; + return converter.from_bytes(input); } std::vector gpt_tokenize(const gpt_vocab & vocab, const std::string & text) { std::vector words; - - // Convert input to utf-8 - std::string utf8conv; - for (int w = 0; w < text.size(); w++) { - append_utf8( uint8_t(text[w]), utf8conv); - } // first split the text into words { - std::string str = utf8conv; + std::string str = text; std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; // Generate the subpattern from the special_tokens vector if it's not empty @@ -407,6 +393,122 @@ gpt_vocab::id gpt_sample_top_k_top_p( return logits_id[idx].second; } +gpt_vocab::id gpt_sample_top_k_top_p_repeat( + const gpt_vocab & vocab, + const float * logits, + const int32_t * last_n_tokens_data, + size_t last_n_tokens_data_size, + int top_k, + double top_p, + double temp, + int repeat_last_n, + float repeat_penalty, + std::mt19937 & rng) { + + int n_logits = vocab.id_to_token.size(); + + const auto * plogits = logits; + + const auto last_n_tokens = std::vector(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size); + + if (temp <= 0) { + // select the token with the highest logit directly + float max_logit = plogits[0]; + gpt_vocab::id max_id = 0; + + for (int i = 1; i < n_logits; ++i) { + if (plogits[i] > max_logit) { + max_logit = plogits[i]; + max_id = i; + } + } + return max_id; + } + + + std::vector> logits_id; + logits_id.reserve(n_logits); + + { + const float scale = 1.0f/temp; + for (int i = 0; i < n_logits; ++i) { + // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main + if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) { + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if (plogits[i] < 0.0f) { + logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); + } + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale, i)); + } + } + } + + // find the top K tokens + std::partial_sort( + logits_id.begin(), + logits_id.begin() + top_k, logits_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + logits_id.resize(top_k); + + double maxl = -INFINITY; + for (const auto & kv : logits_id) { + maxl = std::max(maxl, kv.first); + } + + // compute probs for the top K tokens + std::vector probs; + probs.reserve(logits_id.size()); + + double sum = 0.0; + for (const auto & kv : logits_id) { + double p = exp(kv.first - maxl); + probs.push_back(p); + sum += p; + } + + // normalize the probs + for (auto & p : probs) { + p /= sum; + } + + if (top_p < 1.0f) { + double cumsum = 0.0f; + for (int i = 0; i < top_k; i++) { + cumsum += probs[i]; + if (cumsum >= top_p) { + top_k = i + 1; + probs.resize(top_k); + logits_id.resize(top_k); + break; + } + } + + cumsum = 1.0/cumsum; + for (int i = 0; i < (int) probs.size(); i++) { + probs[i] *= cumsum; + } + } + +// printf("\n"); +// for (int i = 0; i < (int) probs.size(); i++) { +// for (int i = 0; i < 10; i++) { +// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); +// } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + return logits_id[idx].second; + +} + bool read_wav(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { drwav wav; std::vector wav_data; // used for pipe input from stdin diff --git a/examples/common.h b/examples/common.h index 29d0792a..73b4a581 100644 --- a/examples/common.h +++ b/examples/common.h @@ -61,6 +61,9 @@ struct gpt_vocab { // poor-man's JSON parsing std::map json_parse(const std::string & fname); +// handle utf-8 coding +void utf8_to_string(std::string const & in, std::string & out); + // split text into tokens // // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 @@ -92,6 +95,18 @@ gpt_vocab::id gpt_sample_top_k_top_p( double temp, std::mt19937 & rng); +gpt_vocab::id gpt_sample_top_k_top_p_repeat( + const gpt_vocab & vocab, + const float * logits, + const int32_t * last_n_tokens_data, + size_t last_n_tokens_data_size, + int top_k, + double top_p, + double temp, + int repeat_last_n, + float repeat_penalty, + std::mt19937 & rng); + // // Audio utils // diff --git a/examples/mpt/convert-h5-to-ggml.py b/examples/mpt/convert-h5-to-ggml.py index b61ec874..0765011c 100644 --- a/examples/mpt/convert-h5-to-ggml.py +++ b/examples/mpt/convert-h5-to-ggml.py @@ -99,11 +99,11 @@ byte_decoder = {v:k for k, v in byte_encoder.items()} counter = 0 # sort by value for key in sorted(encoder, key=encoder.get): - # workaround for key error when c = whitespace + # workaround for key error when c not found text="" for c in key: - if c == " ": - text += " " + if c not in byte_decoder: + text += c else: text += chr(byte_decoder[c] ) text = bytearray( text, encoding="utf-8" ) diff --git a/examples/mpt/main.cpp b/examples/mpt/main.cpp index 94cb44dc..2890884c 100644 --- a/examples/mpt/main.cpp +++ b/examples/mpt/main.cpp @@ -18,8 +18,6 @@ #include #include -int n_ctx = 4096; - // no defaults for now struct mpt_hparams { int32_t d_model = 0; @@ -30,6 +28,8 @@ struct mpt_hparams { float alibi_bias_max = 0; float clip_qkv = 0; int32_t ftype = 0; + int32_t n_ctx = 0; + }; struct mpt_layer { @@ -64,6 +64,111 @@ struct mpt_model { std::map tensors; }; +struct mpt_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + int32_t seed = -1; // RNG seed + int32_t n_predict = 200; // new tokens to predict + int32_t n_batch = 8; // batch size for prompt processing + int32_t n_ctx = 512; + + std::string model = ""; // model path + std::string prompt = ""; + + bool perplexity = false; + + // sampling parameters + int32_t top_k = 0; + float top_p = 1.0f; + float temp = 0.8f; + int32_t repeat_last_n = 64; + float repeat_penalty = 1.02f; + +}; + +void mpt_print_usage(int /*argc*/, char ** argv, const mpt_params & params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); + fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); + fprintf(stderr, " prompt to start generation with (default: random)\n"); + fprintf(stderr, " -f FNAME, --file FNAME\n"); + fprintf(stderr, " load prompt from a file\n"); + fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); + fprintf(stderr, " --top_k N top-k sampling (default: %d, 0 = n_vocab)\n", params.top_k); + fprintf(stderr, " --top_p N top-p sampling (default: %.2f)\n", params.top_p); + fprintf(stderr, " --temp N temperature (default: %.2f)\n", params.temp); + fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n); + fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty); + fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); + fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); + fprintf(stderr, " -m FNAME, --model FNAME\n"); + fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); + fprintf(stderr, "\n"); +} + +bool mpt_params_parse(int argc, char ** argv, mpt_params & params) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-s" || arg == "--seed") { + params.seed = std::stoi(argv[++i]); + } else if (arg == "-t" || arg == "--threads") { + params.n_threads = std::stoi(argv[++i]); + } else if (arg == "-p" || arg == "--prompt") { + params.prompt = argv[++i]; + } else if (arg == "-n" || arg == "--n_predict") { + params.n_predict = std::stoi(argv[++i]); + } else if (arg == "--top_k") { + params.top_k = std::max(1, std::stoi(argv[++i])); + } else if (arg == "--top_p") { + params.top_p = std::stof(argv[++i]); + } else if (arg == "--temp") { + params.temp = std::stof(argv[++i]); + } else if (arg == "--repeat-last-n") { + params.repeat_last_n = std::stof(argv[++i]); + } else if (arg == "--repeat-penalty") { + params.repeat_penalty = std::stof(argv[++i]); + } else if (arg == "--perplexity") { + params.perplexity = true; + } else if (arg == "-c" || arg == "--ctx-size") { + params.n_ctx = std::stoi(argv[++i]); + } else if (arg == "-b" || arg == "--batch_size") { + params.n_batch = std::stoi(argv[++i]); + } else if (arg == "-m" || arg == "--model") { + params.model = argv[++i]; + } else if (arg == "-h" || arg == "--help") { + mpt_print_usage(argc, argv, params); + exit(0); + } else if (arg == "-f" || arg == "--file") { + if (++i > argc) { + fprintf(stderr, "Invalid file param"); + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + break; + } + params.prompt.clear(); + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); + if (params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + mpt_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + // load the model's weights from a file bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) { printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str()); @@ -127,6 +232,13 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo fin.read((char *) buf.data(), len); word.assign(buf.data(), len); + // Convert token from utf-8 + std::wstring word_multibytes = convert_to_wstring(word); + word.resize(word_multibytes.size()); + for (int w = 0; w < word_multibytes.size(); w++) { + word[w] = uint8_t(word_multibytes[w]); + } + vocab.token_to_id[word] = i; vocab.id_to_token[i] = word; } @@ -146,9 +258,10 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo size_t ctx_size = 0; - { - const auto & hparams = model.hparams; + const auto & hparams = model.hparams; + const size_t n_ctx = hparams.n_ctx; + { const size_t n_embd = hparams.d_model; const size_t n_layer = hparams.n_layers; const size_t n_vocab = hparams.n_vocab; @@ -203,7 +316,7 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo model.tensors["transformer.wte.weight"] = model.wte_weight; model.tensors["transformer.norm_f.weight"] = model.norm_f_weight; - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < (int) n_layer; ++i) { auto & layer = model.layers[i]; layer.norm_1_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); @@ -239,7 +352,7 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); - printf("%s: memory_size = %8.2f MB, n_mem = %lld\n", __func__, memory_size / 1024.0 / 1024.0, n_mem); + printf("%s: memory_size = %8.2f MB, n_mem = %ld\n", __func__, memory_size / 1024.0 / 1024.0, n_mem); } // load weights @@ -335,15 +448,16 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo // - embd_w: the predicted logits for the next token // bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, - const std::vector & embd_inp, std::vector & embd_w, size_t & mem_per_token) { + const std::vector & embd_inp, std::vector & embd_w, bool logits_all, size_t & mem_per_token) { const int N = embd_inp.size(); const auto & hparams = model.hparams; - const int n_embd = hparams.d_model; + const int n_embd = hparams.d_model; const int n_layer = hparams.n_layers; - const int n_head = hparams.n_heads; + const int n_head = hparams.n_heads; const int n_vocab = hparams.n_vocab; + const int n_ctx = hparams.n_ctx; static size_t buf_size = 256u * 1024 * 1024; static void * buf = malloc(buf_size); @@ -539,9 +653,15 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, // ggml_graph_dump_dot(&gf, NULL, "mpt-model.dot"); // } - // return result for just the last token - embd_w.resize(n_vocab); - memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); + if (logits_all) { + // return result for all tokens + embd_w.resize(n_vocab *N); + memcpy(embd_w.data(), (float *)ggml_get_data(inpL) , sizeof(float) * n_vocab * N); + } else { + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); + } if (mem_per_token == 0) { mem_per_token = ggml_used_mem(ctx0) / N; @@ -553,23 +673,208 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past, return true; } -int main(int argc, char ** argv) { +std::vector softmax(const std::vector & logits) { + std::vector probs(logits.size()); + float max_logit = logits[0]; + for (float v : logits) max_logit = std::max(max_logit, v); + double sum_exp = 0.0; + for (size_t i = 0; i < logits.size(); i++) { + // Subtract the maximum logit value from the current logit value for numerical stability + const float logit = logits[i] - max_logit; + const float exp_logit = expf(logit); + sum_exp += exp_logit; + probs[i] = exp_logit; + } + for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp; + return probs; +} + +int perplexity(mpt_params params) { ggml_time_init(); const int64_t t_main_start_us = ggml_time_us(); - gpt_params params; - params.model = ""; + printf("%s: n_threads = %d\n", __func__, params.n_threads); + printf("%s: n_batch = %d\n", __func__, params.n_batch); + printf("%s: n_ctx = %d\n", __func__, params.n_ctx); + printf("\n"); + + int64_t t_load_us = 0; + + gpt_vocab vocab; + mpt_model model; + + model.hparams.n_ctx = params.n_ctx; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!mpt_model_load(params.model, model, vocab)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + } + + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); + + printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + + // determine the required inference memory per token: + size_t mem_per_token = 0; + mpt_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, false, mem_per_token); + + int count = 0; + + const int n_chunk = embd_inp.size() / params.n_ctx; + + const int n_vocab = model.hparams.n_vocab; + const int n_batch = params.n_batch; + + double nll = 0.0; + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); + + for (int i = 0; i < n_chunk; ++i) { + + const int start = i * params.n_ctx; + const int end = start + params.n_ctx; + + const int num_batches = (params.n_ctx + n_batch - 1) / n_batch; + + std::vector logits; + + const auto t_start = std::chrono::high_resolution_clock::now(); + + for (int j = 0; j < num_batches; ++j) { + + const int batch_start = start + j * n_batch; + const int batch_size = std::min(end - batch_start, n_batch); - if (gpt_params_parse(argc, argv, params) == false) { + std::vector embd; + + for(int p=0;p batch_logits;// = llama_get_logits(ctx); + + const int64_t t_start_us = ggml_time_us(); + + if (!mpt_eval(model, params.n_threads, j * batch_size, embd, batch_logits, true, mem_per_token)) { + printf("Failed to predict\n"); + return 1; + } + + t_predict_us += ggml_time_us() - t_start_us; + + logits.insert(logits.end(), batch_logits.data(), batch_logits.data() + batch_size * n_vocab); + + } + + const auto t_end = std::chrono::high_resolution_clock::now(); + + if (i == 0) { + const float t_total = std::chrono::duration(t_end - t_start).count(); + fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + int total_seconds = (int)(t_total * n_chunk); + if (total_seconds >= 60*60) { + fprintf(stderr, "%d hours ", total_seconds / (60*60)); + total_seconds = total_seconds % (60*60); + } + fprintf(stderr, "%d minutes\n", total_seconds / 60); + + printf("\nChunk\tPPL cumulative\tPPL chunk\n"); + } + + // We get the logits for all the tokens in the context window (params.n_ctx) + // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, + // calculate the perplexity over the last half of the window (so the model always has + // some context to predict the token). + // + // We rely on the fact that attention in the forward pass only looks at previous + // tokens here, so the logits returned for each token are an accurate representation + // of what the model would have predicted at that point. + // + // Example, we have a context window of 512, we will compute perplexity for each of the + // last 256 tokens. Then, we split the input up into context window size chunks to + // process the entire prompt. + + double nllchunk = 0.0; + int countchunk = 0; + + for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) { + // Calculate probability of next token, given the previous ones. + const std::vector tok_logits( + logits.begin() + (j + 0) * n_vocab, + logits.begin() + (j + 1) * n_vocab); + + const float prob = softmax(tok_logits)[embd_inp[ start+ j + 1]]; + + nllchunk += -std::log(prob); + ++countchunk; + } + + nll += nllchunk; + count += countchunk; + + // perplexity is e^(average negative log-likelihood) + printf("%d\t%.8lf\t%.8lf\n", i + 1, std::exp(nll / count), std::exp(nllchunk/countchunk) ); + fflush(stdout); + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n\n"); + printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f); + printf("%s: eval time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f, + t_predict_us / 1000.0f / (n_chunk * params.n_ctx) ); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); + } + + ggml_free(model.ctx); + + return 0; +} + +int main(int argc, char ** argv) { + mpt_params params; + + if (mpt_params_parse(argc, argv, params) == false) { return 1; } + if (params.perplexity) { + return perplexity(params); + } + + ggml_time_init(); + + const int64_t t_main_start_us = ggml_time_us(); + if (params.seed < 0) { params.seed = time(NULL); } - printf("%s: seed = %d\n", __func__, params.seed); + if (params.n_predict < 0) { + params.n_predict = 0; + } + + printf("%s: seed = %d\n", __func__, params.seed); + printf("%s: n_threads = %d\n", __func__, params.n_threads); + printf("%s: n_batch = %d\n", __func__, params.n_batch); + printf("%s: n_ctx = %d\n", __func__, params.n_ctx); + printf("%s: n_predict = %d\n\n", __func__, params.n_predict); + printf("\n"); std::mt19937 rng(params.seed); if (params.prompt.empty()) { @@ -588,6 +893,8 @@ int main(int argc, char ** argv) { gpt_vocab vocab; mpt_model model; + model.hparams.n_ctx = params.n_ctx; + // load the model { const int64_t t_start_us = ggml_time_us(); @@ -600,82 +907,111 @@ int main(int argc, char ** argv) { t_load_us = ggml_time_us() - t_start_us; } - int n_past = 0; + if (params.top_k == 0) { + params.top_k = model.hparams.n_vocab; + } + + if (params.repeat_last_n == -1) { + params.repeat_last_n = params.n_ctx; + } + + printf("\n"); + printf("%s: temp = %.3f\n", __func__, params.temp); + printf("%s: top_k = %d\n", __func__, params.top_k); + printf("%s: top_p = %.3f\n", __func__, params.top_p); + printf("%s: repeat_last_n = %d\n", __func__, params.repeat_last_n); + printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty); int64_t t_sample_us = 0; int64_t t_predict_us = 0; - std::vector logits; + std::vector last_n_tokens(params.n_ctx); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); // tokenize the prompt std::vector embd_inp = ::gpt_tokenize(vocab, params.prompt); + printf("\n"); printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); - for (int i = 0; i < embd_inp.size(); i++) { - printf("%s: token[%d] = %6d\n", __func__, i, embd_inp[i]); + for (size_t i = 0; i < embd_inp.size(); i++) { + printf("%s: token[%lu] = %6d\n", __func__, i, embd_inp[i]); // vocab.id_to_token.at(embd_inp[i]).c_str() } printf("\n"); - params.n_predict = std::min(params.n_predict, n_ctx - (int)embd_inp.size()); - std::vector embd; + std::vector logits; // determine the required inference memory per token: size_t mem_per_token = 0; - mpt_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token); + mpt_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, false, mem_per_token); - for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { + int n_past = 0; + int n_consumed = 0; + int n_sampled = 0; + + while (n_sampled < params.n_predict) { // predict if (embd.size() > 0) { const int64_t t_start_us = ggml_time_us(); - if (!mpt_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) { + if (!mpt_eval(model, params.n_threads, n_past, embd, logits, false, mem_per_token)) { printf("Failed to predict\n"); return 1; } t_predict_us += ggml_time_us() - t_start_us; - } - n_past += embd.size(); - embd.clear(); + n_past += embd.size(); + embd.clear(); + } - if (i >= embd_inp.size()) { + if ((int)embd_inp.size() <= n_consumed) { // sample next token + const int top_k = params.top_k; const float top_p = params.top_p; const float temp = params.temp; - - const int n_vocab = model.hparams.n_vocab; + const int repeat_last_n = params.repeat_last_n; + const float repeat_penalty = params.repeat_penalty; gpt_vocab::id id = 0; { const int64_t t_start_sample_us = ggml_time_us(); - id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng); + id = gpt_sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - model.hparams.n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_last_n, repeat_penalty, rng); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(id); t_sample_us += ggml_time_us() - t_start_sample_us; } // add it to the context embd.push_back(id); + ++n_sampled; + } else { // if here, it means we are still processing the input prompt - for (int k = i; k < embd_inp.size(); k++) { - embd.push_back(embd_inp[k]); - if (embd.size() > params.n_batch) { + while ((int) embd_inp.size() > n_consumed) { + embd.push_back(embd_inp[n_consumed]); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(embd_inp[n_consumed]); + + ++n_consumed; + if ((int) embd.size() >= params.n_batch) { break; } } - i += embd.size() - 1; } // display text for (auto id : embd) { - printf("%s", vocab.id_to_token[id].c_str()); + printf("%s", vocab.id_to_token[id].c_str()); +// printf("[%i]%s", id, vocab.id_to_token[id].c_str()); } fflush(stdout); @@ -689,13 +1025,15 @@ int main(int argc, char ** argv) { { const int64_t t_main_end_us = ggml_time_us(); - printf("\n\n"); - printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); - printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f); - printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f); - printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f, + printf("\n\n\n"); + printf("%s: sampled tokens = %8d\n", __func__, n_sampled); + printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f); + printf("%s: sample time = %8.2f ms / %.2f ms per token\n", __func__, t_sample_us / 1000.0f, + t_sample_us / 1000.0f / n_sampled); + printf("%s: eval time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f, t_predict_us / 1000.0f / n_past); - printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); } ggml_free(model.ctx);