]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add embedding mode with arg flag. Currently working (#282)
authorLuciano <redacted>
Fri, 24 Mar 2023 15:05:13 +0000 (08:05 -0700)
committerGitHub <redacted>
Fri, 24 Mar 2023 15:05:13 +0000 (17:05 +0200)
* working but ugly

* add arg flag, not working on embedding mode

* typo

* Working! Thanks to @nullhook

* make params argument instead of hardcoded boolean. remove useless time check

* start doing the instructions but not finished. This probably doesnt compile

* Embeddings extraction support

---------

Co-authored-by: Georgi Gerganov <redacted>
llama.cpp
llama.h
main.cpp
utils.cpp
utils.h

index d55219256932a49c4f5f7a720719be39df3050c4..d8c771529ad310ea4bcea4b80ef399230ce043db 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -102,6 +102,9 @@ struct llama_context {
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
+
+    // input embedding (1-dimensional array: [n_embd])
+    std::vector<float> embedding;
 };
 
 struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
         /*.f16_kv     =*/ false,
         /*.logits_all =*/ false,
         /*.vocab_only =*/ false,
+        /*.embedding  =*/ false,
     };
 
     return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
         fin.close();
     }
 
-    lctx.logits.reserve(lctx.model.hparams.n_ctx);
-
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -791,6 +793,9 @@ static bool llama_eval_internal(
         inpL = cur;
     }
 
+    // used at the end to optionally extract the embeddings
+    struct ggml_tensor * embeddings = NULL;
+
     // norm
     {
         inpL = ggml_rms_norm(ctx0, inpL);
@@ -799,6 +804,8 @@ static bool llama_eval_internal(
         inpL = ggml_mul(ctx0,
                     ggml_repeat(ctx0, model.norm, inpL),
                     inpL);
+
+        embeddings = inpL;
     }
 
     // lm_head
@@ -821,15 +828,26 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    auto & logits_out = lctx.logits;
+    // extract logits
+    {
+        auto & logits_out = lctx.logits;
+
+        if (lctx.logits_all) {
+            logits_out.resize(n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+        } else {
+            // return result for just the last token
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        }
+    }
+
+    // extract embeddings
+    if (lctx.embedding.size()) {
+        auto & embedding_out = lctx.embedding;
 
-    if (lctx.logits_all) {
-        logits_out.resize(n_vocab * N);
-        memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
-    } else {
-        // return result for just the last token
-        logits_out.resize(n_vocab);
-        memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        embedding_out.resize(n_embd);
+        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
     }
 
     if (mem_per_token == 0) {
@@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
         return nullptr;
     }
 
+    // reserve memory for context buffers
+    {
+        const auto & hparams = ctx->model.hparams;
+        if (params.logits_all) {
+            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+        } else {
+            ctx->logits.reserve(hparams.n_ctx);
+        }
+
+        if (params.embedding){
+            ctx->embedding.reserve(hparams.n_embd);
+        }
+    }
+
     return ctx;
 }
 
@@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
 
+float * llama_get_embeddings(struct llama_context * ctx) {
+    return ctx->embedding.data();
+}
+
 const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
     if (token >= llama_n_vocab(ctx)) {
         return nullptr;
diff --git a/llama.h b/llama.h
index 3df9ed1fdd82cdbcc59e8b4400dfa46257758089..209b4dbe81d6a9464e1d4ba84cd6ce76bd54d53a 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -53,6 +53,7 @@ extern "C" {
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
+        bool embedding;  // embedding mode only
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
@@ -108,6 +109,10 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
+    // Get the embeddings for the input
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
+
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
 
index 5ba6d5a7561dcff65732b8737b94605961a38994..46a80ff876eafda8b58ec8817abd6f7364045db6 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
         lparams.seed       = params.seed;
         lparams.f16_kv     = params.memory_f16;
         lparams.logits_all = params.perplexity;
+        lparams.embedding  = params.embedding;
 
         ctx = llama_init_from_file(params.model.c_str(), lparams);
 
@@ -292,6 +293,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
+
     int last_n_size = params.repeat_last_n;
     std::vector<llama_token> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -324,6 +326,27 @@ int main(int argc, char ** argv) {
     // the first thing we will do is to output the prompt, so set color accordingly
     set_console_state(CONSOLE_STATE_PROMPT);
 
+    if (params.embedding){
+        embd = embd_inp;
+
+        if (embd.size() > 0) {
+            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return 1;
+            }
+        }
+
+        const auto embeddings = llama_get_embeddings(ctx);
+
+        // TODO: print / use the embeddings
+
+        if (params.use_color) {
+            printf(ANSI_COLOR_RESET);
+        }
+
+        return 0;
+    }
+
     while (remaining_tokens > 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {
index 45c9cabb1b168931098cce639d7fe3b3810e5f84..0df89af0bacdbbc5eecdbfff1143a558e861ac2d 100644 (file)
--- a/utils.cpp
+++ b/utils.cpp
@@ -117,6 +117,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.model = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
+        } else if (arg == "--embedding") {
+            params.embedding = true;
+        } else if (arg == "--interactive-start") {
+            params.interactive = true;
         } else if (arg == "--interactive-first") {
             params.interactive_start = true;
         } else if (arg == "-ins" || arg == "--instruct") {
diff --git a/utils.h b/utils.h
index b0de556c953708c284c2e43bd55a292c937dac78..8120c123bd5df3851443c922b1f4a4af025ebc86 100644 (file)
--- a/utils.h
+++ b/utils.h
@@ -32,13 +32,17 @@ struct gpt_params {
     std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
 
+
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     bool memory_f16        = false; // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
     bool interactive       = false; // interactive mode
+
+    bool embedding         = false; // get only sentence embedding
     bool interactive_start = false; // wait for user input immediately
+
     bool instruct          = false; // instruction mode (used for Alpaca models)
     bool ignore_eos        = false; // do not stop generating after eos
     bool perplexity        = false; // compute perplexity over the prompt