]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix embeddings (#5796)
authorGeorgi Gerganov <redacted>
Mon, 4 Mar 2024 20:31:20 +0000 (22:31 +0200)
committerGitHub <redacted>
Mon, 4 Mar 2024 20:31:20 +0000 (22:31 +0200)
* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list

README.md
common/common.cpp
examples/embedding/embedding.cpp
examples/server-embd.py [new file with mode: 0644]
examples/server/server.cpp
llama.cpp
llama.h

index 45c5d06f3e10ec4e378436b8ae7bd24f7590fd5d..f754022de894d83206c33535e5449302b1bc03ae 100644 (file)
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 ### Recent API changes
 
+- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
 - [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849
 
 ### Hot topics
index 036a981349a690d201034830a612ee653248a21b..c244db6443eaa075f6a92976f0ffc119904fe4d5 100644 (file)
@@ -1292,7 +1292,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.seed              = params.seed;
     cparams.logits_all        = params.logits_all;
-    cparams.embedding         = params.embedding;
+    cparams.embeddings        = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
     cparams.rope_freq_base    = params.rope_freq_base;
     cparams.rope_freq_scale   = params.rope_freq_scale;
index acff715e99d05b25ff2ca9d6abba4bde40728cca..ff5883da6ba27a7b5bbd3326ee7b53f7b50665c1 100644 (file)
@@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) {
 
 static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
     for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
     }
 }
 
-static void normalize(float * vec, float * out, int n) {
+static void normalize(const float * vec, float * out, int n) {
     float norm = 0;
     for (int i = 0; i < n; i++) {
         norm += vec[i] * vec[i];
@@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     // normalize on copy
-    for (int k = 0; k < n_seq; k++) {
-        float * emb = llama_get_embeddings_ith(ctx, k);
-        float * out = output + k * n_embd;
-        normalize(emb, out, n_embd);
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (!batch.logits[i]) {
+            continue;
+        }
+
+        // try to get sequence embeddings - supported only when pooling_type is not NONE
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+                continue;
+            }
+        }
+
+        float * out = output + batch.seq_id[i][0] * n_embd;
+        normalize(embd, out, n_embd);
     }
 }
 
@@ -132,7 +145,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
     // allocate output
     const int n_embd = llama_n_embd(model);
@@ -145,6 +158,7 @@ int main(int argc, char ** argv) {
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens
         auto & inp = inputs[k];
+
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
diff --git a/examples/server-embd.py b/examples/server-embd.py
new file mode 100644 (file)
index 0000000..c5c4ea8
--- /dev/null
@@ -0,0 +1,34 @@
+import asyncio
+import requests
+import numpy as np
+
+n = 8
+
+result = []
+
+async def requests_post_async(*args, **kwargs):
+    return await asyncio.to_thread(requests.post, *args, **kwargs)
+
+async def main():
+    model_url = "http://127.0.0.1:6900"
+    responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
+        url= f"{model_url}/embedding",
+        json= {"content": str(i)*1024}
+    ) for i in range(n)])
+
+    for response in responses:
+        embedding = response.json()["embedding"]
+        print(embedding[-8:])
+        result.append(embedding)
+
+asyncio.run(main())
+
+# compute cosine similarity
+
+for i in range(n-1):
+    for j in range(i+1, n):
+        embedding1 = np.array(result[i])
+        embedding2 = np.array(result[j])
+        similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+        print(f"Similarity between {i} and {j}: {similarity:.2f}")
+
index 208edd571cb0ed5c459e49b75a794aa50be930ad..8fe5e0b19668f30e827712c0a72889af3c791ff8 100644 (file)
@@ -1210,7 +1210,7 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void send_embedding(server_slot &slot)
+    void send_embedding(server_slot & slot, const llama_batch & batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1219,6 +1219,7 @@ struct llama_server_context
         res.stop = true;
 
         const int n_embd = llama_n_embd(model);
+
         if (!params.embedding)
         {
             LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@@ -1229,12 +1230,29 @@ struct llama_server_context
         }
         else
         {
-            const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
-            res.result_json = json
-            {
-                {"embedding", embedding},
-            };
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                    if (embd == NULL) {
+                        LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+                        res.result_json = json
+                        {
+                            {"embedding", std::vector<float>(n_embd, 0.0f)},
+                        };
+                        continue;
+                    }
+                }
+
+                res.result_json = json
+                {
+                    {"embedding", std::vector<float>(embd, embd + n_embd)},
+                };
+            }
         }
         queue_results.send(res);
     }
@@ -1845,7 +1863,7 @@ struct llama_server_context
                                 ga_i += ga_w/ga_n;
                             }
                         }
-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
                         slot_npast++;
                     }
 
@@ -1881,7 +1899,7 @@ struct llama_server_context
 
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             for (auto & slot : slots)
             {
@@ -1954,7 +1972,7 @@ struct llama_server_context
                 // prompt evaluated for embedding
                 if (slot.embedding)
                 {
-                    send_embedding(slot);
+                    send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue;
@@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --yarn-attn-factor N      YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N        YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  --pooling {none,mean,cls}\n");
+    printf("                        pooling type for embeddings, use model default if unspecified\n");
     printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.yarn_beta_slow = std::stof(argv[i]);
         }
+        else if (arg == "--pooling")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+            else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+            else { invalid_param = true; break; }
+        }
         else if (arg == "--threads" || arg == "-t")
         {
             if (++i >= argc)
@@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         }
         else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
         {
index de579d9e372b4e1e5d180e48c208d02e58370013..76afcbc135f4cdcb4a3f4204dfa52cbf8c1e2f8c 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1665,7 +1665,7 @@ struct llama_hparams {
 };
 
 struct llama_cparams {
-    uint32_t n_ctx;       // context size used during inference
+    uint32_t n_ctx;           // context size used during inference
     uint32_t n_batch;
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
@@ -1682,7 +1682,9 @@ struct llama_cparams {
     float yarn_beta_slow;
     float defrag_thold;
 
+    bool embeddings;
     bool offload_kqv;
+
     enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
@@ -1972,7 +1974,7 @@ struct llama_context {
     int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
     int32_t n_eval   = 0; // number of eval calls
 
-    // decode output (2-dimensional array: [n_tokens][n_vocab])
+    // logits output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 #ifndef NDEBUG
     // guard against access to unset logits
@@ -1980,8 +1982,13 @@ struct llama_context {
 #endif
     bool logits_all = false;
 
-    // input embedding (1-dimensional array: [n_embd])
-    std::vector<float> embedding;
+    // embeddings output (2-dimensional array: [n_tokens][n_embd])
+    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
+    std::vector<float> embd;
+
+    // sequence embeddings output (map of [n_embd] vectors)
+    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
+    std::map<llama_seq_id, std::vector<float>> embd_seq;
 
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
@@ -5092,6 +5099,7 @@ static struct ggml_tensor * llm_build_kv(
     llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
 
     struct ggml_tensor * cur;
+
     cur  = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
             q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
@@ -6085,6 +6093,7 @@ struct llm_build_context {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
         struct ggml_tensor * cur;
@@ -6092,9 +6101,10 @@ struct llm_build_context {
 
         // get input vectors with right size
         const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
-        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+
+        struct ggml_tensor * inp_pos  = ggml_view_1d(ctx0, lctx.inp_pos,  n_tokens, 0);
         struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0);
-        struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0);
+        struct ggml_tensor * inp_cls  = ggml_view_1d(ctx0, lctx.inp_cls,  n_tokens, 0);
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
@@ -6112,39 +6122,38 @@ struct llm_build_context {
         cb(inpL, "inp_norm", -1);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
-        cb(KQ_mask, "KQ_mask", -1); // [n_kv, n_tokens]
+        struct ggml_tensor * KQ_mask = ggml_cont(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_tokens, n_tokens, n_tokens*ggml_type_size(lctx.inp_KQ_mask->type), 0));
+        cb(KQ_mask, "KQ_mask", -1); // [n_tokens, n_tokens]
 
         // iterate layers
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * cur = inpL;
 
+            struct ggml_tensor * Qcur;
+            struct ggml_tensor * Kcur;
+            struct ggml_tensor * Vcur;
+
             // self-attention
             if (model.arch == LLM_ARCH_BERT) {
-                struct ggml_tensor * Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
-                // seems like we just need to do this for Q?
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             } else {
                 // compute Q and K and RoPE them
                 cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -6163,12 +6172,40 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
+            }
 
-                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                cb(cur, "kqv_out", il);
+            struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+            struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            cb(kq, "kq", il);
+
+            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+            cb(kq, "kq_soft_max_ext", il);
+
+            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+            cb(v, "v", il);
+
+            struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
+            cb(kqv, "kqv", il);
+
+            struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cb(kqv_merged, "kqv_merged", il);
+
+            cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+            cb(cur, "kqv_merged_cont", il);
+
+            ggml_build_forward_expand(gf, cur);
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            if (model.layers[il].bo) {
+                cb(cur, "kqv_wo", il);
+            }
+
+            if (model.layers[il].bo) {
+                cur = ggml_add(ctx0, cur, model.layers[il].bo);
             }
+            cb(cur, "kqv_out", il);
 
             // re-add the layer input
             cur = ggml_add(ctx0, cur, inpL);
@@ -6209,16 +6246,29 @@ struct llm_build_context {
 
         // final output
         cur = inpL;
+        cb(cur, "result_embd", -1);
 
         // pooling layer
-        if (pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-            cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-        } else if (pooling_type == LLAMA_POOLING_TYPE_CLS) {
-            cur = ggml_get_rows(ctx0, cur, inp_cls);
-        } else {
-            GGML_ASSERT(pooling_type == LLAMA_POOLING_TYPE_NONE && "Invalid pooling type");
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // nop
+                } break;
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
+                    cb(cur, "result_embd_pooled", -1);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+                {
+                    cur = ggml_get_rows(ctx0, cur, inp_cls);
+                    cb(cur, "result_embd_pooled", -1);
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "Invalid pooling type");
+                } break;
         }
-        cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
 
@@ -7980,7 +8030,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
-    {
+    if (hparams.causal_attn) {
         const int64_t n_kv     = kv_self.n;
         const int64_t n_tokens = batch.n_tokens;
 
@@ -7995,16 +8045,40 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
                 for (int i = 0; i < n_kv; ++i) {
                     float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
-                        (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
                         f = -INFINITY;
                     } else {
-                        f = 0;
+                        f = 0.0f;
                     }
                     data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                 }
             }
         }
+    } else {
+        // non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
+        const int64_t n_tokens = batch.n_tokens;
+
+        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+
+        float * data = (float *) lctx.inp_KQ_mask->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_tokens; ++i) {
+                    float f = -INFINITY;
+                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                        if (batch.seq_id[i][s] == seq_id) {
+                            f = 0.0f;
+                            break;
+                        }
+                    }
+
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
+                }
+            }
+        }
     }
 
     if (hparams.need_kq_pos) {
@@ -8023,13 +8097,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
-        float * data = (float *) lctx.inp_mean->data;
 
+        float * data = (float *) lctx.inp_mean->data;
         memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
 
         std::vector<uint64_t> sum(n_tokens, 0);
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+
             sum[seq_id] += 1;
         }
 
@@ -8051,11 +8128,16 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+
         uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
-            const llama_pos pos = batch.pos[i];
+            const llama_pos    pos    = batch.pos[i];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+
             if (pos == 0) {
                 data[seq_id] = i;
             }
@@ -8169,23 +8251,26 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }
 
-    llama_kv_cache_update(&lctx);
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
 
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (kv_self.head > kv_self.used + 2*n_tokens) {
-        kv_self.head = 0;
-    }
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*n_tokens) {
+            kv_self.head = 0;
+        }
 
-    if (!llama_kv_cache_find_slot(kv_self, batch)) {
-        return 1;
-    }
+        if (!llama_kv_cache_find_slot(kv_self, batch)) {
+            return 1;
+        }
 
-    // a heuristic, to avoid attending the full cache if it is not yet utilized
-    // after enough generations, the benefit from this heuristic disappears
-    // if we start defragmenting the cache, the benefit from this will be more important
-    kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
-    //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        // a heuristic, to avoid attending the full cache if it is not yet utilized
+        // after enough generations, the benefit from this heuristic disappears
+        // if we start defragmenting the cache, the benefit from this will be more important
+        kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        //kv_self.n = llama_kv_cache_cell_max(kv_self);
+    }
 
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
@@ -8195,20 +8280,26 @@ static int llama_decode_internal(
     ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
 
     // the output is always the last tensor in the graph
-    struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-
-    if (strcmp(res->name, "result_output") == 0) {
-        // the embeddings could be the second to last tensor, or the third to last tensor
-        if (strcmp(embeddings->name, "result_norm") != 0) {
-            embeddings = gf->nodes[gf->n_nodes - 3];
-            GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
-        }
-    } else if (strcmp(res->name, "result_embd") == 0) {
-        embeddings = res;
-        res = nullptr;
+    struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+
+    if (!hparams.causal_attn) {
+        res = nullptr; // do not extract logits for embedding models such as BERT
+
+        // token or sequence embeddings
+        embd = gf->nodes[gf->n_nodes - 1];
+
+        GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
     } else {
-        GGML_ASSERT(false);
+        if (strcmp(res->name, "result_output") == 0) {
+            // the token embeddings could be the second to last tensor, or the third to last tensor
+            if (strcmp(embd->name, "result_norm") != 0) {
+                embd = gf->nodes[gf->n_nodes - 3];
+                GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+            }
+        } else {
+            GGML_ASSERT(false && "missing result_output tensor");
+        }
     }
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -8275,46 +8366,82 @@ static int llama_decode_internal(
         logits_out.clear();
 #endif
 
-        ggml_backend_t res_backend = ggml_backend_sched_get_node_backend(lctx.sched, res);
-        GGML_ASSERT(res_backend != nullptr);
+        ggml_backend_t backend_res = ggml_backend_sched_get_node_backend(lctx.sched, res);
+        GGML_ASSERT(backend_res != nullptr);
+
         if (batch.logits) {
             logits_out.resize(n_vocab * n_tokens);
             for (uint32_t i = 0; i < n_tokens; i++) {
                 if (batch.logits[i] == 0) {
                     continue;
                 }
-                ggml_backend_tensor_get_async(res_backend, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
+                ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
                 logits_valid[i] = true;
 #endif
             }
         } else if (lctx.logits_all) {
             logits_out.resize(n_vocab * n_tokens);
-            ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
+            ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
 #ifndef NDEBUG
             std::fill(logits_valid.begin(), logits_valid.end(), true);
 #endif
         } else {
             logits_out.resize(n_vocab);
-            ggml_backend_tensor_get_async(res_backend, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
+            ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
             logits_valid[0] = true;
 #endif
         }
-        ggml_backend_synchronize(res_backend);
+        ggml_backend_synchronize(backend_res);
     }
 
     // extract embeddings
-    if (!lctx.embedding.empty()) {
-        auto & embedding_out = lctx.embedding;
+    if (cparams.embeddings && embd) {
+        ggml_backend_t backend_embd = ggml_backend_sched_get_node_backend(lctx.sched, embd);
+        GGML_ASSERT(backend_embd != nullptr);
 
-        const int64_t embd_pos  = res ? n_embd * (n_tokens-1) : 0;
-        const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
+        switch (cparams.pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // extract token embeddings
+                    auto & embd_out = lctx.embd;
+
+                    if (batch.logits) {
+                        embd_out.resize(n_embd * n_tokens);
+                        for (uint32_t i = 0; i < n_tokens; i++) {
+                            if (batch.logits[i] == 0) {
+                                continue;
+                            }
 
-        embedding_out.resize(embd_size);
-        ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
-        ggml_backend_synchronize(embeddings_backend);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    }
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
+
+                    // extract sequence embeddings
+                    auto & embd_seq_out = lctx.embd_seq;
+                    embd_seq_out.clear();
+
+                    for (uint32_t i = 0; i < n_tokens; i++) {
+                        const llama_seq_id seq_id = batch.seq_id[i][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(n_embd);
+                        ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                    }
+                } break;
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ASSERT(false && "unknown pooling type");
+                } break;
+        }
+        ggml_backend_synchronize(backend_embd);
     }
 
     // measure the performance only for the single-token evals
@@ -8608,19 +8735,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
     GGML_ASSERT(llama_is_byte_token(vocab, id));
     const auto& token_data = vocab.id_to_token.at(id);
     switch (llama_vocab_get_type(vocab)) {
-    case LLAMA_VOCAB_TYPE_SPM: {
-        auto buf = token_data.text.substr(3, 2);
-        return strtol(buf.c_str(), NULL, 16);
-    }
-    case LLAMA_VOCAB_TYPE_BPE: {
-        GGML_ASSERT(false);
-        return unicode_to_bytes_bpe(token_data.text);
-    }
-    case LLAMA_VOCAB_TYPE_WPM: {
-        GGML_ASSERT(false);
-    }
-    default:
-        GGML_ASSERT(false);
+        case LLAMA_VOCAB_TYPE_SPM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ASSERT(false);
+            return unicode_to_bytes_bpe(token_data.text);
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ASSERT(false);
+        }
+        default:
+            GGML_ASSERT(false);
     }
 }
 
@@ -11864,7 +11991,7 @@ struct llama_context_params llama_context_default_params() {
         /*.type_k                      =*/ GGML_TYPE_F16,
         /*.type_v                      =*/ GGML_TYPE_F16,
         /*.logits_all                  =*/ false,
-        /*.embedding                   =*/ false,
+        /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
@@ -12015,6 +12142,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.defrag_thold     = params.defrag_thold;
+    cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.pooling_type     = params.pooling_type;
 
@@ -12192,8 +12320,8 @@ struct llama_context * llama_new_context_with_model(
         // resized during inference, reserve maximum
         ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
-        if (params.embedding) {
-            ctx->embedding.resize(hparams.n_embd);
+        if (params.embeddings) {
+            ctx->embd.reserve(hparams.n_embd*cparams.n_batch);
         }
 
         // graph inputs
@@ -12628,7 +12756,7 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     // assume worst case for logits although only currently set ones are serialized
     const size_t s_logits          = ctx->logits.capacity() * sizeof(float);
     const size_t s_embedding_size  = sizeof(size_t);
-    const size_t s_embedding       = ctx->embedding.size() * sizeof(float);
+    const size_t s_embedding       = ctx->embd.capacity() * sizeof(float);
     const size_t s_kv_buf_size     = sizeof(size_t);
     const size_t s_kv_head         = sizeof(uint32_t);
     const size_t s_kv_size         = sizeof(uint32_t);
@@ -12737,12 +12865,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
 
     // copy embeddings
     {
-        const size_t embedding_size = ctx->embedding.size();
+        const size_t embeddings_size = ctx->embd.size();
 
-        data_ctx->write(&embedding_size, sizeof(embedding_size));
+        data_ctx->write(&embeddings_size, sizeof(embeddings_size));
 
-        if (embedding_size) {
-            data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float));
+        if (embeddings_size) {
+            data_ctx->write(ctx->embd.data(), embeddings_size * sizeof(float));
         }
     }
 
@@ -12846,15 +12974,17 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     // set embeddings
     {
-        size_t embedding_size;
+        size_t embeddings_size;
+
+        memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
 
-        memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size);
+        GGML_ASSERT(ctx->embd.capacity() == embeddings_size);
 
-        GGML_ASSERT(ctx->embedding.capacity() == embedding_size);
+        if (embeddings_size) {
+            ctx->embd.resize(embeddings_size);
 
-        if (embedding_size) {
-            memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float));
-            inp += embedding_size * sizeof(float);
+            memcpy(ctx->embd.data(), inp, embeddings_size * sizeof(float));
+            inp += embeddings_size * sizeof(float);
         }
     }
 
@@ -13104,11 +13234,20 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
-    return ctx->embedding.data();
+    return ctx->embd.data();
 }
 
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    return ctx->embedding.data() + i*ctx->model.hparams.n_embd;
+    return ctx->embd.data() + i*ctx->model.hparams.n_embd;
+}
+
+float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    auto it = ctx->embd_seq.find(seq_id);
+    if (it == ctx->embd_seq.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
 }
 
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
diff --git a/llama.h b/llama.h
index 70da4cb3f0ff6685e8a1b3f9edb8cc2e477cc1f8..3dc162b078d30269a8b5d69b3a5ffe8051b7052a 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -163,7 +163,7 @@ extern "C" {
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
     // - seq_id : the sequence to which the respective token belongs
-    // - logits : if zero, the logits for the respective token will not be output
+    // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
     //
     typedef struct llama_batch {
         int32_t n_tokens;
@@ -173,7 +173,7 @@ extern "C" {
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
-        int8_t       *  logits;
+        int8_t       *  logits; // TODO: rename this to "output"
 
         // NOTE: helpers for smooth API transition - can be deprecated in the future
         //       for future-proof code, use the above fields instead and ignore everything below
@@ -260,7 +260,7 @@ extern "C" {
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-        bool embedding;   // embedding mode only
+        bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
 
         // Abort callback
@@ -655,14 +655,20 @@ extern "C" {
     // llama_get_logits(ctx) + i*n_vocab
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
-    // Get the embeddings for the input
-    // shape: [n_embd] (1-dimensional)
+    // Get all output token embeddings
+    // shape: [n_tokens*n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
-    // Get the embeddings for the ith sequence
+    // Get the embeddings for the ith token
     // llama_get_embeddings(ctx) + i*n_embd
+    // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
 
+    // Get the embeddings for a sequence id
+    // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
+
     //
     // Vocab
     //