]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
mpt : utf-8 support, perplexity testing, repeat penalty sampling (#184)
authorklosax <redacted>
Wed, 24 May 2023 07:27:36 +0000 (09:27 +0200)
committerGitHub <redacted>
Wed, 24 May 2023 07:27:36 +0000 (10:27 +0300)
* common: utf-8 decoder, reverted gpt_toeknize utf-8 convert

* Update common.h

* main: decode utf-8 tokens on load

* mpt import: bug fix

* common: style fixes

* common: style fix

* Update common.h

* common: revert gpt_tokenize utf-8 convert

* Update common.cpp

* Update common.cpp

* Update common.cpp

* Add perplexity to mpt

* Update CMakeLists: perplexity

* mpt-perplexity: fixes

* Update perplexity.cpp

* common: add sampling with repeat penalty

* mpt-main: add repeat penalty sampling, add commandline parameters

* Update common.h

* mpt-main: style fixes

* Update perplexity.cpp

* Delete perplexity.cpp

* mpt: move perplexity to main

* mpt: move perplexity to main

* common.cpp: Use codecvt utf-8 converter

* main.cpp: Use codecvt utf-8 converter

* mpt : code style changes

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/common.cpp
examples/common.h
examples/mpt/convert-h5-to-ggml.py
examples/mpt/main.cpp

index eaaaa606f224fd2bdeff86077827b50fa0923d64..bb98d2c66dd2cb5a3bd6ceb6324cdcea932d4c09 100644 (file)
@@ -8,6 +8,8 @@
 #include <cmath>
 #include <fstream>
 #include <regex>
+#include <locale>
+#include <codecvt>
 
 #ifndef M_PI
 #define M_PI 3.14159265358979323846
@@ -212,38 +214,22 @@ void gpt_vocab::add_special_token(const std::string & token) {
     special_tokens.push_back(token);
 }
 
-static void append_utf8(char32_t ch, std::string & out) {
-    if (ch <= 0x7F) {
-        out.push_back(static_cast<unsigned char>(ch));
-    } else if (ch <= 0x7FF) {
-        out.push_back(static_cast<unsigned char>(0xC0 | ((ch >> 6) & 0x1F)));
-        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
-    } else if (ch <= 0xFFFF) {
-        out.push_back(static_cast<unsigned char>(0xE0 | ((ch >> 12) & 0x0F)));
-        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
-        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
-    } else if (ch <= 0x10FFFF) {
-        out.push_back(static_cast<unsigned char>(0xF0 | ((ch >> 18) & 0x07)));
-        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 12) & 0x3F)));
-        out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
-        out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
-    } else {
-        printf("Invalid Unicode code point\n");
-    }
+std::string convert_to_utf8(const std::wstring & input) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    return converter.to_bytes(input);
+}
+
+std::wstring convert_to_wstring(const std::string & input) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    return converter.from_bytes(input);
 }
 
 std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
     std::vector<std::string> words;
-
-    // Convert input to utf-8
-    std::string utf8conv;
-    for (int w = 0; w < text.size(); w++) {
-        append_utf8( uint8_t(text[w]), utf8conv);
-    }
     
     // first split the text into words
     {
-        std::string str = utf8conv;
+        std::string str = text;
         std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
 
         // Generate the subpattern from the special_tokens vector if it's not empty
@@ -407,6 +393,122 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     return logits_id[idx].second;
 }
 
+gpt_vocab::id gpt_sample_top_k_top_p_repeat(
+        const gpt_vocab & vocab,
+        const float * logits,
+        const int32_t * last_n_tokens_data,
+        size_t last_n_tokens_data_size,
+        int    top_k,
+        double top_p,
+        double temp,
+        int repeat_last_n,
+        float repeat_penalty,
+        std::mt19937 & rng) {
+
+    int n_logits = vocab.id_to_token.size();
+
+    const auto * plogits = logits;
+
+    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);
+
+    if (temp <= 0) {
+        // select the token with the highest logit directly
+        float max_logit = plogits[0];
+        gpt_vocab::id max_id = 0;
+
+        for (int i = 1; i < n_logits; ++i) {
+            if (plogits[i] > max_logit) {
+                max_logit = plogits[i];
+                max_id = i;
+            }
+        }
+        return max_id;
+    }
+
+
+    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    {
+        const float scale = 1.0f/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
+                if (plogits[i] < 0.0f) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
+                }
+            } else {
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
+            }
+        }
+    }
+
+    // find the top K tokens
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + top_k, logits_id.end(),
+            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    logits_id.resize(top_k);
+
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
+
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
+
+    if (top_p < 1.0f) {
+        double cumsum = 0.0f;
+        for (int i = 0; i < top_k; i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                top_k = i + 1;
+                probs.resize(top_k);
+                logits_id.resize(top_k);
+                break;
+            }
+        }
+
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
+        }
+    }
+
+//    printf("\n");
+//    for (int i = 0; i < (int) probs.size(); i++) {
+//    for (int i = 0; i < 10; i++) {
+//        printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+//    }
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+
+}
+
 bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
     drwav wav;
     std::vector<uint8_t> wav_data; // used for pipe input from stdin
index 29d0792af2d356945fb19b70709473963de3fccb..73b4a5818091fbeaa7f8b50b5e2a7ad3fe071386 100644 (file)
@@ -61,6 +61,9 @@ struct gpt_vocab {
 // poor-man's JSON parsing
 std::map<std::string, int32_t> json_parse(const std::string & fname);
 
+// handle utf-8 coding
+void utf8_to_string(std::string const & in, std::string & out);
+
 // split text into tokens
 //
 // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
@@ -92,6 +95,18 @@ gpt_vocab::id gpt_sample_top_k_top_p(
         double temp,
         std::mt19937 & rng);
 
+gpt_vocab::id gpt_sample_top_k_top_p_repeat(
+        const gpt_vocab & vocab,
+        const float * logits,
+        const int32_t * last_n_tokens_data,
+        size_t last_n_tokens_data_size,
+        int    top_k,
+        double top_p,
+        double temp,
+        int repeat_last_n,
+        float repeat_penalty,
+        std::mt19937 & rng);
+
 //
 // Audio utils
 //
index b61ec8749255f74d4f4bb8b555078e51b72237df..0765011ccc9fad14a69d13660eefe89dc2366a2c 100644 (file)
@@ -99,11 +99,11 @@ byte_decoder = {v:k for k, v in byte_encoder.items()}
 counter = 0
 # sort by value
 for key in sorted(encoder, key=encoder.get):
-    # workaround for key error when c = whitespace
+    # workaround for key error when c not found
     text=""
     for c in key:
-        if c == " ":
-            text += " "
+        if c not in byte_decoder:
+            text += c
         else:
             text += chr(byte_decoder[c] )
     text = bytearray( text, encoding="utf-8" )
index 94cb44dcbab9efc569fa418f3335d56bca868526..2890884cd702e1e6cfe483b5019a12025699e3c4 100644 (file)
@@ -18,8 +18,6 @@
 #include <utility>
 #include <vector>
 
-int n_ctx = 4096;
-
 // no defaults for now
 struct mpt_hparams {
     int32_t d_model      = 0;
@@ -30,6 +28,8 @@ struct mpt_hparams {
     float alibi_bias_max = 0;
     float clip_qkv       = 0;
     int32_t ftype        = 0;
+    int32_t n_ctx        = 0;
+
 };
 
 struct mpt_layer {
@@ -64,6 +64,111 @@ struct mpt_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
+struct mpt_params {
+    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+
+    int32_t seed           = -1; // RNG seed
+    int32_t n_predict      = 200; // new tokens to predict
+    int32_t n_batch        = 8; // batch size for prompt processing
+    int32_t n_ctx          = 512;
+
+    std::string model      = ""; // model path
+    std::string prompt     = "";
+
+    bool    perplexity     = false;
+
+    // sampling parameters
+    int32_t top_k          = 0;
+    float   top_p          = 1.0f;
+    float   temp           = 0.8f;
+    int32_t repeat_last_n  = 64;
+    float   repeat_penalty = 1.02f;
+
+};
+
+void mpt_print_usage(int /*argc*/, char ** argv, const mpt_params & params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
+    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
+    fprintf(stderr, "                        prompt to start generation with (default: random)\n");
+    fprintf(stderr, "  -f FNAME, --file FNAME\n");
+    fprintf(stderr, "                        load prompt from a file\n");
+    fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict);
+    fprintf(stderr, "  --top_k N             top-k sampling (default: %d, 0 = n_vocab)\n", params.top_k);
+    fprintf(stderr, "  --top_p N             top-p sampling (default: %.2f)\n", params.top_p);
+    fprintf(stderr, "  --temp N              temperature (default: %.2f)\n", params.temp);
+    fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat-penalty N    penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
+    fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
+    fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  -m FNAME, --model FNAME\n");
+    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stderr, "\n");
+}
+
+bool mpt_params_parse(int argc, char ** argv, mpt_params & params) {
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        if (arg == "-s" || arg == "--seed") {
+            params.seed = std::stoi(argv[++i]);
+        } else if (arg == "-t" || arg == "--threads") {
+            params.n_threads = std::stoi(argv[++i]);
+        } else if (arg == "-p" || arg == "--prompt") {
+            params.prompt = argv[++i];
+        } else if (arg == "-n" || arg == "--n_predict") {
+            params.n_predict = std::stoi(argv[++i]);
+        } else if (arg == "--top_k") {
+            params.top_k = std::max(1, std::stoi(argv[++i]));
+        } else if (arg == "--top_p") {
+            params.top_p = std::stof(argv[++i]);
+        } else if (arg == "--temp") {
+            params.temp = std::stof(argv[++i]);
+        } else if (arg == "--repeat-last-n") {
+            params.repeat_last_n = std::stof(argv[++i]);
+        } else if (arg == "--repeat-penalty") {
+            params.repeat_penalty = std::stof(argv[++i]);
+        } else if (arg == "--perplexity") {
+            params.perplexity = true;
+        } else if (arg == "-c" || arg == "--ctx-size") {
+            params.n_ctx = std::stoi(argv[++i]);
+        } else if (arg == "-b" || arg == "--batch_size") {
+            params.n_batch = std::stoi(argv[++i]);
+        } else if (arg == "-m" || arg == "--model") {
+            params.model = argv[++i];
+        } else if (arg == "-h" || arg == "--help") {
+            mpt_print_usage(argc, argv, params);
+            exit(0);
+        } else if (arg == "-f" || arg == "--file") {
+            if (++i > argc) {
+                fprintf(stderr, "Invalid file param");
+                break;
+            }
+            std::ifstream file(argv[i]);
+            if (!file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                break;
+            }
+            params.prompt.clear();
+            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
+            if (params.prompt.back() == '\n') {
+                params.prompt.pop_back();
+            }
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            mpt_print_usage(argc, argv, params);
+            exit(0);
+        }
+    }
+
+    return true;
+}
+
 // load the model's weights from a file
 bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) {
     printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
@@ -127,6 +232,13 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
             fin.read((char *) buf.data(), len);
             word.assign(buf.data(), len);
 
+            // Convert token from utf-8
+            std::wstring word_multibytes = convert_to_wstring(word);
+            word.resize(word_multibytes.size());
+            for (int w = 0; w < word_multibytes.size(); w++) {
+                word[w] = uint8_t(word_multibytes[w]);
+            }
+
             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
         }
@@ -146,9 +258,10 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
 
     size_t ctx_size = 0;
 
-    {
-        const auto & hparams = model.hparams;
+    const auto & hparams = model.hparams;
+    const size_t n_ctx = hparams.n_ctx;
 
+    {
         const size_t n_embd = hparams.d_model;
         const size_t n_layer = hparams.n_layers;
         const size_t n_vocab = hparams.n_vocab;
@@ -203,7 +316,7 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
         model.tensors["transformer.wte.weight"] = model.wte_weight;
         model.tensors["transformer.norm_f.weight"] = model.norm_f_weight;
 
-        for (int i = 0; i < n_layer; ++i) {
+        for (int i = 0; i < (int) n_layer; ++i) {
             auto & layer = model.layers[i];
 
             layer.norm_1_weight = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
@@ -239,7 +352,7 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
 
         const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
 
-        printf("%s: memory_size = %8.2f MB, n_mem = %lld\n", __func__, memory_size / 1024.0 / 1024.0, n_mem);
+        printf("%s: memory_size = %8.2f MB, n_mem = %ld\n", __func__, memory_size / 1024.0 / 1024.0, n_mem);
     }
 
     // load weights
@@ -335,15 +448,16 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
 //   - embd_w:    the predicted logits for the next token
 //
 bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
-              const std::vector<gpt_vocab::id> & embd_inp, std::vector<float> & embd_w, size_t & mem_per_token) {
+              const std::vector<gpt_vocab::id> & embd_inp, std::vector<float> & embd_w, bool logits_all, size_t & mem_per_token) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
 
-    const int n_embd = hparams.d_model;
+    const int n_embd  = hparams.d_model;
     const int n_layer = hparams.n_layers;
-    const int n_head = hparams.n_heads;
+    const int n_head  = hparams.n_heads;
     const int n_vocab = hparams.n_vocab;
+    const int n_ctx   = hparams.n_ctx;
 
     static size_t buf_size = 256u * 1024 * 1024;
     static void * buf = malloc(buf_size);
@@ -539,9 +653,15 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
     // ggml_graph_dump_dot(&gf, NULL, "mpt-model.dot");
     // }
 
-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab);
+    if (logits_all) {
+        // return result for all tokens
+        embd_w.resize(n_vocab *N);
+        memcpy(embd_w.data(), (float *)ggml_get_data(inpL) , sizeof(float) * n_vocab * N);
+    } else {
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab);
+    }
 
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0) / N;
@@ -553,23 +673,208 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
     return true;
 }
 
-int main(int argc, char ** argv) {
+std::vector<float> softmax(const std::vector<float> & logits) {
+    std::vector<float> probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) max_logit = std::max(max_logit, v);
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        const float logit = logits[i] - max_logit;
+        const float exp_logit = expf(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    return probs;
+}
+
+int perplexity(mpt_params params) {
     ggml_time_init();
 
     const int64_t t_main_start_us = ggml_time_us();
 
-    gpt_params params;
-    params.model = "";
+    printf("%s: n_threads = %d\n", __func__, params.n_threads);
+    printf("%s: n_batch = %d\n", __func__, params.n_batch);
+    printf("%s: n_ctx = %d\n", __func__, params.n_ctx);
+    printf("\n");
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    mpt_model model;
+
+    model.hparams.n_ctx = params.n_ctx;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+
+        if (!mpt_model_load(params.model, model, vocab)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    int64_t t_predict_us = 0;
+
+    std::vector<float> logits;
+
+    // tokenize the prompt
+    std::vector<int> embd_inp = ::gpt_tokenize(vocab, params.prompt);
+
+    printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    mpt_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, false, mem_per_token);
+
+    int count   = 0;
+
+    const int n_chunk = embd_inp.size() / params.n_ctx;
+
+    const int n_vocab = model.hparams.n_vocab;
+    const int n_batch = params.n_batch;
+
+    double nll = 0.0;
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+
+    for (int i = 0; i < n_chunk; ++i) {
+
+        const int start =     i * params.n_ctx;
+        const int end   = start + params.n_ctx;
+
+        const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
+
+        std::vector<float> logits;
+
+        const auto t_start = std::chrono::high_resolution_clock::now();
+
+        for (int j = 0; j < num_batches; ++j) {
+
+            const int batch_start = start + j * n_batch;
+            const int batch_size  = std::min(end - batch_start, n_batch);
 
-    if (gpt_params_parse(argc, argv, params) == false) {
+            std::vector<gpt_vocab::id> embd;
+
+            for(int p=0;p<batch_size;p++) {
+                embd.push_back( embd_inp[batch_start+p]  );
+            }
+
+            std::vector<float> batch_logits;// = llama_get_logits(ctx);
+
+            const int64_t t_start_us = ggml_time_us();
+
+            if (!mpt_eval(model, params.n_threads, j * batch_size, embd, batch_logits, true, mem_per_token)) {
+                printf("Failed to predict\n");
+                return 1;
+            }
+
+            t_predict_us += ggml_time_us() - t_start_us;
+
+            logits.insert(logits.end(), batch_logits.data(), batch_logits.data() + batch_size * n_vocab);
+
+        }
+
+        const auto t_end = std::chrono::high_resolution_clock::now();
+
+        if (i == 0) {
+            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
+            if (total_seconds >= 60*60) {
+                fprintf(stderr, "%d hours ", total_seconds / (60*60));
+                total_seconds = total_seconds % (60*60);
+            }
+            fprintf(stderr, "%d minutes\n", total_seconds / 60);
+
+            printf("\nChunk\tPPL cumulative\tPPL chunk\n");
+        }
+
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above.  Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half of the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // Example, we have a context window of 512, we will compute perplexity for each of the
+        // last 256 tokens.  Then, we split the input up into context window size chunks to
+        // process the entire prompt.
+
+        double nllchunk = 0.0;
+        int countchunk = 0;
+
+        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[embd_inp[ start+ j + 1]];
+
+            nllchunk += -std::log(prob);
+            ++countchunk;
+        }
+
+               nll += nllchunk;
+               count += countchunk;
+
+        // perplexity is e^(average negative log-likelihood)
+        printf("%d\t%.8lf\t%.8lf\n", i + 1, std::exp(nll / count), std::exp(nllchunk/countchunk) );
+        fflush(stdout);
+    }
+
+    // report timing
+    {
+        const int64_t t_main_end_us = ggml_time_us();
+
+        printf("\n\n");
+        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
+        printf("%s:  eval time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f,
+               t_predict_us / 1000.0f / (n_chunk * params.n_ctx) );
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
+    }
+
+    ggml_free(model.ctx);
+
+    return 0;
+}
+
+int main(int argc, char ** argv) {
+    mpt_params params;
+
+    if (mpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
 
+    if (params.perplexity) {
+        return perplexity(params);
+    }
+
+    ggml_time_init();
+
+    const int64_t t_main_start_us = ggml_time_us();
+
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
 
-    printf("%s: seed = %d\n", __func__, params.seed);
+    if (params.n_predict < 0) {
+        params.n_predict = 0;
+    }
+
+    printf("%s: seed      = %d\n",   __func__, params.seed);
+    printf("%s: n_threads = %d\n",   __func__, params.n_threads);
+    printf("%s: n_batch   = %d\n",   __func__, params.n_batch);
+    printf("%s: n_ctx     = %d\n",   __func__, params.n_ctx);
+    printf("%s: n_predict = %d\n\n", __func__, params.n_predict);
+    printf("\n");
 
     std::mt19937 rng(params.seed);
     if (params.prompt.empty()) {
@@ -588,6 +893,8 @@ int main(int argc, char ** argv) {
     gpt_vocab vocab;
     mpt_model model;
 
+    model.hparams.n_ctx = params.n_ctx;
+
     // load the model
     {
         const int64_t t_start_us = ggml_time_us();
@@ -600,82 +907,111 @@ int main(int argc, char ** argv) {
         t_load_us = ggml_time_us() - t_start_us;
     }
 
-    int n_past = 0;
+    if (params.top_k == 0) {
+        params.top_k = model.hparams.n_vocab;
+    }
+
+    if (params.repeat_last_n == -1) {
+        params.repeat_last_n = params.n_ctx;
+    }
+
+    printf("\n");
+    printf("%s: temp           = %.3f\n", __func__, params.temp);
+    printf("%s: top_k          = %d\n",   __func__, params.top_k);
+    printf("%s: top_p          = %.3f\n", __func__, params.top_p);
+    printf("%s: repeat_last_n  = %d\n",   __func__, params.repeat_last_n);
+    printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty);
 
     int64_t t_sample_us = 0;
     int64_t t_predict_us = 0;
 
-    std::vector<float> logits;
+    std::vector<int32_t> last_n_tokens(params.n_ctx);
+    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     // tokenize the prompt
     std::vector<int> embd_inp = ::gpt_tokenize(vocab, params.prompt);
 
+    printf("\n");
     printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
 
-    for (int i = 0; i < embd_inp.size(); i++) {
-        printf("%s: token[%d] = %6d\n", __func__, i, embd_inp[i]);
+    for (size_t i = 0; i < embd_inp.size(); i++) {
+        printf("%s: token[%lu] = %6d\n", __func__, i, embd_inp[i]);
         // vocab.id_to_token.at(embd_inp[i]).c_str()
     }
     printf("\n");
 
-    params.n_predict = std::min(params.n_predict, n_ctx - (int)embd_inp.size());
-
     std::vector<gpt_vocab::id> embd;
+    std::vector<float> logits;
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    mpt_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token);
+    mpt_eval(model, params.n_threads, 0, {0, 1, 2, 3}, logits, false, mem_per_token);
 
-    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+    int n_past     = 0;
+    int n_consumed = 0;
+    int n_sampled  = 0;
+
+    while (n_sampled < params.n_predict) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!mpt_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+            if (!mpt_eval(model, params.n_threads, n_past, embd, logits, false, mem_per_token)) {
                 printf("Failed to predict\n");
                 return 1;
             }
 
             t_predict_us += ggml_time_us() - t_start_us;
-        }
 
-        n_past += embd.size();
-        embd.clear();
+            n_past += embd.size();
+            embd.clear();
+        }
 
-        if (i >= embd_inp.size()) {
+        if ((int)embd_inp.size() <= n_consumed) {
             // sample next token
+
             const int top_k = params.top_k;
             const float top_p = params.top_p;
             const float temp = params.temp;
-
-            const int n_vocab = model.hparams.n_vocab;
+            const int repeat_last_n = params.repeat_last_n;
+            const float repeat_penalty = params.repeat_penalty;
 
             gpt_vocab::id id = 0;
 
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
+                id = gpt_sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - model.hparams.n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_last_n, repeat_penalty, rng);
+
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(id);
 
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
 
             // add it to the context
             embd.push_back(id);
+            ++n_sampled;
+
         } else {
             // if here, it means we are still processing the input prompt
-            for (int k = i; k < embd_inp.size(); k++) {
-                embd.push_back(embd_inp[k]);
-                if (embd.size() > params.n_batch) {
+            while ((int) embd_inp.size() > n_consumed) {
+                embd.push_back(embd_inp[n_consumed]);
+
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(embd_inp[n_consumed]);
+
+                ++n_consumed;
+                if ((int) embd.size() >= params.n_batch) {
                     break;
                 }
             }
-            i += embd.size() - 1;
         }
 
         // display text
         for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
+           printf("%s", vocab.id_to_token[id].c_str());
+//            printf("[%i]%s", id, vocab.id_to_token[id].c_str());
         }
         fflush(stdout);
 
@@ -689,13 +1025,15 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_main_end_us = ggml_time_us();
 
-        printf("\n\n");
-        printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        printf("%s:     load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
-        printf("%s:   sample time = %8.2f ms\n", __func__, t_sample_us / 1000.0f);
-        printf("%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f,
+        printf("\n\n\n");
+        printf("%s: sampled tokens = %8d\n", __func__, n_sampled);
+        printf("%s:  mem per token = %8zu bytes\n", __func__, mem_per_token);
+        printf("%s:      load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
+        printf("%s:    sample time = %8.2f ms / %.2f ms per token\n", __func__, t_sample_us / 1000.0f,
+               t_sample_us / 1000.0f / n_sampled);
+        printf("%s:   eval time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us / 1000.0f,
                t_predict_us / 1000.0f / n_past);
-        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
+        printf("%s:     total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
     }
 
     ggml_free(model.ctx);