]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : allow pooled embeddings on any model (#7477)
authorDouglas Hanley <redacted>
Fri, 21 Jun 2024 05:38:22 +0000 (00:38 -0500)
committerGitHub <redacted>
Fri, 21 Jun 2024 05:38:22 +0000 (08:38 +0300)
* create append_pooling operation; allow to specify attention_type; add last token pooling; update examples

* find result_norm/result_embd tensors properly; update output allocation logic

* only use embd output for pooling_type NONE

* get rid of old causal_attn accessor

* take out attention_type; add in llama_set_embeddings

* bypass logits when doing non-NONE pooling

common/common.cpp
examples/embedding/embedding.cpp
examples/gritlm/gritlm.cpp
examples/retrieval/retrieval.cpp
llama.cpp
llama.h

index 9c23d001bfba97040fb677614b9c5c2049ac5eca..64f160af1c18c2fefc79777e21b4765cd6440ef4 100644 (file)
@@ -541,6 +541,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
         else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
         else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+        else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
         else { invalid_param = true; }
         return true;
     }
@@ -1869,6 +1870,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
 
     options.push_back({ "backend" });
     options.push_back({ "*",           "       --rpc SERVERS",          "comma separated list of RPC servers" });
+
     if (llama_supports_mlock()) {
         options.push_back({ "*",           "       --mlock",                "force system to keep model in RAM rather than swapping or compressing" });
     }
index 244751e003d9e1ad15ba05d345c3a8f3a57d272e..b4b73c0175cda0beea53a24375ee080f67da6d2a 100644 (file)
@@ -17,9 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
     return lines;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -40,13 +41,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
 
         // try to get sequence embeddings - supported only when pooling_type is not NONE
         const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        if (embd == NULL) {
-            embd = llama_get_embeddings_ith(ctx, i);
-            if (embd == NULL) {
-                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
-                continue;
-            }
-        }
+        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
         float * out = output + batch.seq_id[i][0] * n_embd;
         //TODO: I would also add a parameter here to enable normalization or not.
@@ -97,6 +92,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);
index 2135157916c977a779d28c4d047ce4d961d2faab..2c61c2e1eb3bc80002ead3172fa0052f11f133c2 100644 (file)
@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_kv_cache_clear(ctx);
+        llama_set_embeddings(ctx, true);
         llama_set_causal_attn(ctx, false);
 
         // run model
@@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_token eos_token = llama_token_eos(mdl);
 
     llama_kv_cache_clear(ctx);
+    llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
+
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -166,8 +169,7 @@ int main(int argc, char * argv[]) {
 
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
-    // create new context - set to embedding mode
-    cparams.embeddings = true;
+    // create generation context
     llama_context * ctx = llama_new_context_with_model(mdl, cparams);
 
     // ### Embedding/Representation ###
index 55b7b2f70ae2a1de654102d87085b5cdda980690..eb89d16daf18d7c750ab854401fc69e92cc101f0 100644 (file)
@@ -73,9 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
-    for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+    size_t n_tokens = tokens.size();
+    for (size_t i = 0; i < n_tokens; i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, true);
     }
 }
 
@@ -160,6 +161,12 @@ int main(int argc, char ** argv) {
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
 
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+        return 1;
+    }
+
     if (n_ctx > n_ctx_train) {
         fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, n_ctx);
index 8818c69280178744ddea8521228db4baae2f7041..9ca0b7479619d6b22a6b4764aefa2de09846ee74 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -7649,6 +7649,50 @@ struct llm_build_context {
         return lctx.inp_s_seq;
     }
 
+    struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
+        // find result_norm tensor for input
+        struct ggml_tensor * inp = nullptr;
+        for (int i = gf->n_nodes - 1; i >= 0; --i) {
+            inp = gf->nodes[i];
+            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+                break;
+            } else {
+                inp = nullptr;
+            }
+        }
+        GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
+
+        struct ggml_tensor * cur;
+
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    struct ggml_tensor * inp_mean = build_inp_mean();
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_LAST:
+                {
+                    struct ggml_tensor * inp_cls = build_inp_cls();
+                    cur = ggml_get_rows(ctx0, inp, inp_cls);
+                } break;
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
+            default:
+                {
+                    GGML_ASSERT(false && "unknown pooling type");
+                } break;
+        }
+
+        cb(cur, "result_embd_pooled", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -8629,8 +8673,6 @@ struct llm_build_context {
         if (model.arch != LLM_ARCH_JINA_BERT_V2) {
             inp_pos = build_inp_pos();
         }
-        struct ggml_tensor * inp_mean = build_inp_mean();
-        struct ggml_tensor * inp_cls  = build_inp_cls();
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -8805,28 +8847,6 @@ struct llm_build_context {
         cur = inpL;
         cb(cur, "result_embd", -1);
 
-        // pooling layer
-        switch (pooling_type) {
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    // nop
-                } break;
-            case LLAMA_POOLING_TYPE_MEAN:
-                {
-                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_CLS:
-                {
-                    cur = ggml_get_rows(ctx0, cur, inp_cls);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                {
-                    GGML_ASSERT(false && "Invalid pooling type");
-                } break;
-        }
-
         ggml_build_forward_expand(gf, cur);
 
         return gf;
@@ -11911,6 +11931,11 @@ static struct ggml_cgraph * llama_build_graph(
             GGML_ASSERT(false);
     }
 
+    // add on pooling layer
+    if (lctx.cparams.embeddings) {
+        result = llm.append_pooling(result);
+    }
+
     llm.free();
 
     return result;
@@ -12000,7 +12025,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // (!a || b) is a logical implication (a -> b)
         // !hparams.causal_attn -> !cparams.causal_attn
         (hparams.causal_attn || !cparams.causal_attn) &&
-        "causal attention with embedding models is not supported"
+        "causal attention is not supported by this model"
     );
 
     if (lctx.inp_KQ_mask) {
@@ -12132,6 +12157,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(lctx.inp_cls);
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+
+        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);
+
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            const llama_pos    pos    = batch.pos[i];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+
+            if (pos >= last_pos[seq_id]) {
+                last_pos[seq_id] = pos;
+                last_row[seq_id] = i;
+            }
+        }
+
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+    }
+
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;
 
@@ -12193,8 +12249,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = cparams.causal_attn;
-    const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    const bool has_logits = !cparams.embeddings;
+    const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
@@ -12324,11 +12380,13 @@ static int llama_decode_internal(
     std::vector<std::vector<llama_seq_id>> seq_id;
 
     // count outputs
-    if (batch_all.logits) {
+    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
+        n_outputs = n_tokens_all;
+    } else if (batch_all.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+    } else if (lctx.logits_all) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
@@ -12459,30 +12517,13 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (!hparams.causal_attn) {
-            res = nullptr; // do not extract logits for embedding models such as BERT
-
-            // token or sequence embeddings
-            embd = gf->nodes[gf->n_nodes - 1];
-
-            GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
         } else if (cparams.embeddings) {
-            // the embeddings could be in the second to last tensor, or any of the previous tensors
-            int i_embd = gf->n_nodes - 2;
-            for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
-                i_embd = gf->n_nodes - i;
-                if (i_embd < 0) { break; }
-                embd = gf->nodes[i_embd];
-            }
-            GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
-
-            // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
-            if (!cparams.causal_attn) {
-                res = nullptr; // do not extract logits when not needed
-                // skip computing logits
-                // TODO: is this safe?
-                gf->n_nodes = i_embd + 1;
+            res = nullptr; // do not extract logits for embedding case
+            embd = gf->nodes[gf->n_nodes - 1];
+            if (strcmp(embd->name, "result_embd_pooled") != 0) {
+                embd = gf->nodes[gf->n_nodes - 2];
             }
+            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
@@ -12551,11 +12592,10 @@ static int llama_decode_internal(
                             ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                         }
                     } break;
-                case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
                     {
-                        GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
-
                         // extract sequence embeddings
                         auto & embd_seq_out = lctx.embd_seq;
                         embd_seq_out.clear();
@@ -18112,6 +18152,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }
diff --git a/llama.h b/llama.h
index da310ffaf9ad9989afde157c4bf55698a0888254..05d8b092b42a4ec375e1fe28c44e642a89ad3b7f 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -174,6 +174,7 @@ extern "C" {
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
+        LLAMA_POOLING_TYPE_LAST = 3,
     };
 
     enum llama_split_mode {
@@ -293,7 +294,6 @@ extern "C" {
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
-                                                        // (ignored if no pooling layer)
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float    rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -786,6 +786,10 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
 
+    // Set whether the model is in embeddings model or not
+    // If true, embeddings will be returned but logits will not
+    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);