]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
gpt : fix sampling to use the temperature (close #16)
authorGeorgi Gerganov <redacted>
Sun, 15 Jan 2023 13:53:08 +0000 (15:53 +0200)
committerGeorgi Gerganov <redacted>
Sun, 15 Jan 2023 13:53:08 +0000 (15:53 +0200)
examples/gpt-2/main.cpp
examples/gpt-j/main.cpp
examples/utils.cpp

index 333d93b8b5e62cc6da13379eb7ba1502bb2f429c..134a9304daa29c3380327d1fb298834954b27416 100644 (file)
@@ -347,7 +347,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 //   - n_threads: number of threads to use
 //   - n_past:    the context size so far
 //   - embd_inp:  the embeddings of the tokens in the context
-//   - embd_w:    the predicted probabilities of the next token
+//   - embd_w:    the predicted logits for the next token
 //
 bool gpt2_eval(
         const gpt2_model & model,
@@ -627,7 +627,7 @@ bool gpt2_eval(
     inpL = ggml_mul_mat(ctx0, model.wte, inpL);
 
     // logits -> probs
-    inpL = ggml_soft_max(ctx0, inpL);
+    //inpL = ggml_soft_max(ctx0, inpL);
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
@@ -641,7 +641,7 @@ bool gpt2_eval(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    // return result for just the last token
+    // return result just for the last token
     embd_w.resize(n_vocab);
     memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
 
@@ -698,7 +698,7 @@ int main(int argc, char ** argv) {
     int64_t t_sample_us  = 0;
     int64_t t_predict_us = 0;
 
-    std::vector<float> embd_w;
+    std::vector<float> logits;
 
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
@@ -714,14 +714,14 @@ int main(int argc, char ** argv) {
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
+    gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gpt2_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
+            if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
                 printf("Failed to predict\n");
                 return 1;
             }
@@ -745,7 +745,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
 
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
index c55157dc53fb7a785572ce03c2a11ead194acc92..63248d7d522f7d082bbfd46a7329363c6fa55032 100644 (file)
@@ -355,7 +355,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
 //   - n_threads: number of threads to use
 //   - n_past:    the context size so far
 //   - embd_inp:  the embeddings of the tokens in the context
-//   - embd_w:    the predicted probabilities of the next token
+//   - embd_w:    the predicted logits for the next token
 //
 // The GPT-J model requires about 16MB of memory per input token.
 //
@@ -559,7 +559,7 @@ bool gptj_eval(
     }
 
     // logits -> probs
-    inpL = ggml_soft_max(ctx0, inpL);
+    //inpL = ggml_soft_max(ctx0, inpL);
 
     // run the computation
     ggml_build_forward_expand(&gf, inpL);
@@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
     int64_t t_sample_us  = 0;
     int64_t t_predict_us = 0;
 
-    std::vector<float> embd_w;
+    std::vector<float> logits;
 
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
@@ -644,14 +644,14 @@ int main(int argc, char ** argv) {
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, embd_w, mem_per_token);
+    gptj_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
 
     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!gptj_eval(model, params.n_threads, n_past, embd, embd_w, mem_per_token)) {
+            if (!gptj_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
                 printf("Failed to predict\n");
                 return 1;
             }
@@ -675,7 +675,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = gpt_sample_top_k_top_p(vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, rng);
+                id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);
 
                 t_sample_us += ggml_time_us() - t_start_sample_us;
             }
index 1cd59d90353a7115bdebeec1c399ea9744dc7278..30057b7cadaeb1fcfc2dfa1e18541bae3fcea933 100644 (file)
@@ -261,8 +261,11 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     std::vector<std::pair<double, gpt_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
 
-    for (int i = 0; i < n_logits; i++) {
-        logits_id.push_back(std::make_pair(logits[i], i));
+    {
+        const double scale = 1.0/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            logits_id.push_back(std::make_pair(logits[i]*scale, i));
+        }
     }
 
     // find the top K tokens
@@ -275,59 +278,51 @@ gpt_vocab::id gpt_sample_top_k_top_p(
 
     logits_id.resize(top_k);
 
-    // normalize
-    {
-        double sum = 0.0f;
-        for (int i = 0; i < (int)logits_id.size(); i++) {
-            sum += logits_id[i].first;
-        }
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
 
-        sum = 1.0/sum;
-        for (int i = 0; i < (int)logits_id.size(); i++) {
-            logits_id[i].first *= sum;
-        }
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
     }
 
     if (top_p < 1.0f) {
-        {
-            double cumsum = 0.0f;
-            for (int i = 0; i < top_k; i++) {
-                cumsum += logits_id[i].first;
-                if (cumsum >= top_p) {
-                    logits_id.resize(i+1);
-                    break;
-                }
+        double cumsum = 0.0f;
+        for (int i = 0; i < top_k; i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                top_k = i + 1;
+                probs.resize(top_k);
+                logits_id.resize(top_k);
+                break;
             }
         }
 
-        // normalize again
-        {
-            double sum = 0.0f;
-            for (int i = 0; i < (int)logits_id.size(); i++) {
-                sum += logits_id[i].first;
-            }
-
-            sum = 1.0/sum;
-            for (int i = 0; i < (int)logits_id.size(); i++) {
-                logits_id[i].first *= sum;
-            }
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
         }
     }
 
     //printf("\n");
-    //for (int i = 0; i < (int)logits_id.size(); i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
+    //for (int i = 0; i < (int) probs.size(); i++) {
+    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
     //}
     //exit(0);
 
-    // sample from the obtained distribution
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
-
-    for (int i = 0; i < (int) logits_id.size(); i++) {
-        probs.push_back(logits_id[i].first);
-    }
-
     std::discrete_distribution<> dist(probs.begin(), probs.end());
     int idx = dist(rng);