]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Cleanup STL headers + fix embedding examples + minor stuff
authorGeorgi Gerganov <redacted>
Sat, 25 Mar 2023 18:51:14 +0000 (20:51 +0200)
committerGeorgi Gerganov <redacted>
Sat, 25 Mar 2023 18:51:14 +0000 (20:51 +0200)
examples/embedding/embedding.cpp
examples/perplexity/perplexity.cpp
llama.cpp
llama.h

index 3015293f7f8eabb240159da88ad88573d4c8702b..d397f35fd464c3b653fb039efeb01ff7e0d61658 100644 (file)
@@ -1,15 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include <cassert>
-#include <cinttypes>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <fstream>
-#include <string>
-#include <vector>
-
 int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
@@ -94,9 +85,13 @@ int main(int argc, char ** argv) {
             }
         }
 
+        const int n_embd = llama_n_embd(ctx);
         const auto embeddings = llama_get_embeddings(ctx);
 
-        // TODO: print / use the embeddings
+        for (int i = 0; i < n_embd; i++) {
+            printf("%f ", embeddings[i]);
+        }
+        printf("\n");
     }
 
     llama_print_timings(ctx);
index f0266a01f7bedbb6e2803dc520b2ee1050f05cc3..f617ba365dd05dec9c003cc73631677de09905b3 100644 (file)
@@ -1,14 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include <cassert>
-#include <cinttypes>
-#include <cmath>
-#include <cstdio>
-#include <cstring>
-#include <string>
-#include <vector>
-
 std::vector<double> softmax(const std::vector<float>& logits) {
     std::vector<double> probs(logits.size());
     float max_logit = logits[0];
index 0015edec18df753c5fb2ef1d472c699f06cf7e60..2bd520353efda83323bfd9d5375b36d93308c8ae 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1261,10 +1261,10 @@ static llama_vocab::id llama_sample_top_p_top_k(
         double repeat_penalty) {
     auto & rng = lctx.rng;
 
-    const auto & vocab = lctx.vocab;
-    const auto & logits = lctx.logits;
+    const int n_logits = lctx.model.hparams.n_vocab;
 
-    int n_logits = vocab.id_to_token.size();
+    const auto & logits = lctx.logits;
+    const auto * plogits = logits.data() + logits.size() - n_logits;
 
     std::vector<std::pair<double, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
@@ -1276,13 +1276,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
             // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
             if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                 // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (logits[i] < 0.0) {
-                    logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
+                if (plogits[i] < 0.0) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                 } else {
-                    logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
                 }
             } else {
-                logits_id.push_back(std::make_pair(logits[i]*scale, i));
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
             }
         }
     }
@@ -1677,6 +1677,8 @@ struct llama_context * llama_init_from_file(
         }
 
         const auto & hparams = ctx->model.hparams;
+
+        // resized during inference
         if (params.logits_all) {
             ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
         } else {
@@ -1684,7 +1686,7 @@ struct llama_context * llama_init_from_file(
         }
 
         if (params.embedding){
-            ctx->embedding.reserve(hparams.n_embd);
+            ctx->embedding.resize(hparams.n_embd);
         }
 
         ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
@@ -1761,6 +1763,10 @@ int llama_n_ctx(struct llama_context * ctx) {
     return ctx->model.hparams.n_ctx;
 }
 
+int llama_n_embd(struct llama_context * ctx) {
+    return ctx->model.hparams.n_embd;
+}
+
 float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
diff --git a/llama.h b/llama.h
index 827abc1f26005baa2d2bcd27068ef60f326cfdda..ebf55f41c35ace46ad62a6fa4f1ee07b5b68c0ec 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -109,6 +109,7 @@ extern "C" {
 
     LLAMA_API int llama_n_vocab(struct llama_context * ctx);
     LLAMA_API int llama_n_ctx  (struct llama_context * ctx);
+    LLAMA_API int llama_n_embd (struct llama_context * ctx);
 
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row