]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : support batched embeddings (#5466)
authorDouglas Hanley <redacted>
Tue, 13 Feb 2024 12:06:58 +0000 (06:06 -0600)
committerGitHub <redacted>
Tue, 13 Feb 2024 12:06:58 +0000 (14:06 +0200)
* batched embedding: pool outputs by sequence id. updated embedding example

* bring back non-causal attention

* embd : minor improvements

* llama : minor

---------

Co-authored-by: Georgi Gerganov <redacted>
convert-hf-to-gguf.py
examples/embedding/embedding.cpp
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
llama.cpp
llama.h

index cae1551a236b06798768224aab3de9da3c88a3b5..5adfdc143a41fd08e25ac8eecb72feb4a1202214 100755 (executable)
@@ -1648,6 +1648,7 @@ class BertModel(Model):
         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
         self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_pooling_layer(True)
         self.gguf_writer.add_file_type(self.ftype)
 
     def set_vocab(self):
index 27376c8f09fdcc6624c34ce1e40a36c02437e12f..b4688cf519d151bcf0299c6fb6ec925157124e60 100644 (file)
@@ -7,6 +7,51 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static std::vector<std::string> split_lines(const std::string & s) {
+    std::string line;
+    std::vector<std::string> lines;
+    std::stringstream ss(s);
+    while (std::getline(ss, line)) {
+        lines.push_back(line);
+    }
+    return lines;
+}
+
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
+    for (size_t i = 0; i < tokens.size(); i++) {
+        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+    }
+}
+
+static void normalize(float * vec, float * out, int n) {
+    float norm = 0;
+    for (int i = 0; i < n; i++) {
+        norm += vec[i] * vec[i];
+    }
+    norm = sqrt(norm);
+    for (int i = 0; i < n; i++) {
+        out[i] = vec[i] / norm;
+    }
+}
+
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+    // clear previous kv_cache values (irrelevant for embeddings)
+    llama_kv_cache_clear(ctx);
+
+    // run model
+    fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+    if (llama_decode(ctx, batch) < 0) {
+        fprintf(stderr, "%s : failed to decode\n", __func__);
+    }
+
+    // normalize on copy
+    for (int k = 0; k < n_seq; k++) {
+        float * emb = llama_get_embeddings_ith(ctx, k);
+        float * out = output + k * n_embd;
+        normalize(emb, out, n_embd);
+    }
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -55,59 +100,84 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s\n", get_system_info(params).c_str());
     }
 
-    int n_past = 0;
+    // split the prompt into lines
+    std::vector<std::string> prompts = split_lines(params.prompt);
 
-    // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    // max batch size
+    const uint64_t n_batch = params.n_batch;
+    GGML_ASSERT(params.n_batch == params.n_ctx);
 
-    if (params.verbose_prompt) {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-        fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
-        for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
+    // tokenize the prompts and trim
+    std::vector<std::vector<int32_t>> inputs;
+    for (const auto & prompt : prompts) {
+        auto inp = ::llama_tokenize(ctx, prompt, true);
+        if (inp.size() > n_batch) {
+            inp.resize(n_batch);
         }
-        fprintf(stderr, "\n");
+        inputs.push_back(inp);
     }
 
-    if (embd_inp.size() > (size_t)n_ctx) {
-        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
-                __func__, embd_inp.size(), n_ctx);
-        return 1;
-    }
-
-    while (!embd_inp.empty()) {
-        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return 1;
+    // tokenization stats
+    if (params.verbose_prompt) {
+        for (int i = 0; i < (int) inputs.size(); i++) {
+            fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
+            fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
+            for (int j = 0; j < (int) inputs[i].size(); j++) {
+                fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
+            }
+            fprintf(stderr, "\n\n");
         }
-        n_past += n_tokens;
-        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
     }
 
+    // initialize batch
+    const int n_prompts = prompts.size();
+    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+
+    // allocate output
     const int n_embd = llama_n_embd(model);
-    auto * embeddings = llama_get_embeddings(ctx);
+    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    float * emb = embeddings.data();
+
+    // break into batches
+    int p = 0; // number of prompts processed already
+    int s = 0; // number of prompts in current batch
+    for (int k = 0; k < n_prompts; k++) {
+        // clamp to n_batch tokens
+        auto & inp = inputs[k];
+        const uint64_t n_toks = inp.size();
+
+        // encode if at capacity
+        if (batch.n_tokens + n_toks > n_batch) {
+            float * out = emb + p * n_embd;
+            batch_decode(ctx, batch, out, s, n_embd);
+            llama_batch_clear(batch);
+            p += s;
+            s = 0;
+        }
 
-    // l2-normalize embeddings
-    float norm = 0;
-    for (int i = 0; i < n_embd; i++) {
-        norm += embeddings[i] * embeddings[i];
-    }
-    norm = sqrt(norm);
-    for (int i = 0; i < n_embd; i++) {
-        embeddings[i] /= norm;
+        // add to batch
+        batch_add_seq(batch, inp, s);
+        s += 1;
     }
 
-    for (int i = 0; i < n_embd; i++) {
-        printf("%f ", embeddings[i]);
+    // final batch
+    float * out = emb + p * n_embd;
+    batch_decode(ctx, batch, out, s, n_embd);
+
+    // print first 3 embeddings
+    for (int j = 0; j < std::min(3, n_prompts); j++) {
+        fprintf(stderr, "embedding %d: ", j);
+        for (int i = 0; i < n_embd; i++) {
+            fprintf(stderr, "%f ", emb[j * n_embd + i]);
+        }
+        fprintf(stderr, "\n\n");
     }
-    printf("\n");
+    fprintf(stderr, "\n");
 
+    // clean up
     llama_print_timings(ctx);
     llama_free(ctx);
     llama_free_model(model);
-
     llama_backend_free();
 
     return 0;
index a9c13dd3826b872adf7cf8d7fac11e01cd1fa965..644e1589c830d10635a672de07e6dc6b7c462099 100644 (file)
@@ -40,6 +40,7 @@ class Keys:
         TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
         EXPERT_COUNT          = "{arch}.expert_count"
         EXPERT_USED_COUNT     = "{arch}.expert_used_count"
+        POOLING_LAYER         = "{arch}.pooling_layer"
 
     class Attention:
         HEAD_COUNT        = "{arch}.attention.head_count"
index 7af58a46c2cb73332d8301e0ecab1c642582a154..d87bd8e88696c973e792c5ac90278746b6d3a08f 100644 (file)
@@ -360,6 +360,9 @@ class GGUFWriter:
     def add_causal_attention(self, value: bool) -> None:
         self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
 
+    def add_pooling_layer(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.POOLING_LAYER.format(arch=self.arch), value)
+
     def add_rope_dimension_count(self, count: int) -> None:
         self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
 
index 6dce392df005347398a74fd73a3390da3726e9e9..eb6c46f3672f973f472ab66370c3810edb4aacbf 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -254,6 +254,7 @@ enum llm_kv {
     LLM_KV_TENSOR_DATA_LAYOUT,
     LLM_KV_EXPERT_COUNT,
     LLM_KV_EXPERT_USED_COUNT,
+    LLM_KV_POOLING_LAYER,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -311,6 +312,7 @@ static std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TENSOR_DATA_LAYOUT,            "%s.tensor_data_layout"    },
     { LLM_KV_EXPERT_COUNT,                  "%s.expert_count"          },
     { LLM_KV_EXPERT_USED_COUNT,             "%s.expert_used_count"     },
+    { LLM_KV_POOLING_LAYER,                 "%s.pooling_layer"         },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,          "%s.attention.head_count"             },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,       "%s.attention.head_count_kv"          },
@@ -1539,6 +1541,7 @@ struct llama_hparams {
     float f_max_alibi_bias;
 
     bool causal_attn = true;
+    bool pooling_layer = false;
 
 
     bool operator!=(const llama_hparams & other) const {
@@ -1601,6 +1604,7 @@ struct llama_cparams {
 
     bool mul_mat_q;
     bool offload_kqv;
+    bool do_pooling;
 
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
@@ -1896,7 +1900,7 @@ struct llama_context {
     struct ggml_tensor * inp_pos;       // I32 [n_batch]
     struct ggml_tensor * inp_KQ_mask;   // F32 [n_ctx, n_batch]
     struct ggml_tensor * inp_K_shift;   // I32 [n_ctx]
-    struct ggml_tensor * inp_sum;       // F32 [1, n_batch]
+    struct ggml_tensor * inp_sum;       // F32 [n_batch, n_batch]
 
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
@@ -3053,6 +3057,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
+                ml.get_key(LLM_KV_POOLING_LAYER, hparams.pooling_layer);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -4859,7 +4864,7 @@ struct llm_build_context {
     const int32_t n_orig_ctx;
 
     const bool do_rope_shift;
-    const bool causal_attn;
+    const bool do_pooling;
 
     const llm_build_cb & cb;
 
@@ -4903,7 +4908,7 @@ struct llm_build_context {
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
         do_rope_shift    (worst_case || kv_self.has_shift),
-        causal_attn      (hparams.causal_attn),
+        do_pooling       (hparams.pooling_layer && cparams.do_pooling),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
             // all initializations should be done in init()
@@ -5752,17 +5757,18 @@ struct llm_build_context {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
         // get input vectors with right size
+        const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type);
         struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
-        struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0);
+        struct ggml_tensor * inp_sum = ggml_view_2d(ctx0, lctx.inp_sum, n_tokens, n_tokens, stride1, 0);
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+
         // token types are hardcoded to zero ("Sentence A")
         struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
         inpL = ggml_add(ctx0, inpL, type_row0);
@@ -5832,9 +5838,11 @@ struct llm_build_context {
         // final output
         cur = inpL;
 
-        // pooling
-        cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
-        cb(cur, "result_embed", -1);
+        // pooling layer
+        if (do_pooling) {
+            cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum);
+        }
+        cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
 
@@ -7367,7 +7375,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
                 for (int i = 0; i < n_kv; ++i) {
                     float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
+                        (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
                         f = -INFINITY;
                     } else {
                         f = 0;
@@ -7378,7 +7387,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-
     {
         assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
         float * data = (float *) lctx.inp_sum->data;
@@ -7399,6 +7407,20 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             data[i] = lctx.kv_self.cells[i].delta;
         }
     }
+
+    if (hparams.pooling_layer && cparams.do_pooling) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
+        float * data = (float *) lctx.inp_sum->data;
+
+        memset(lctx.inp_sum->data, 0, batch.n_tokens * batch.n_tokens * ggml_element_size(lctx.inp_sum));
+
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            data[seq_id*n_tokens + i] = 1.0f;
+        }
+    }
 }
 
 // decode a batch of tokens by evaluating the transformer
@@ -7510,7 +7532,7 @@ static int llama_decode_internal(
             embeddings = gf->nodes[gf->n_nodes - 3];
             GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
         }
-    } else if (strcmp(res->name, "result_embed") == 0) {
+    } else if (strcmp(res->name, "result_embd") == 0) {
         embeddings = res;
         res = nullptr;
     } else {
@@ -7630,11 +7652,12 @@ static int llama_decode_internal(
     if (!lctx.embedding.empty()) {
         auto & embedding_out = lctx.embedding;
 
-        const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0;
+        const int64_t embd_pos  = res ? n_embd * (n_tokens-1) : 0;
+        const int64_t embd_size = res ? n_embd : n_embd * n_tokens;
 
-        embedding_out.resize(n_embd);
+        embedding_out.resize(embd_size);
         ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings);
-        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float));
+        ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float));
         ggml_backend_synchronize(embeddings_backend);
     }
 
@@ -10950,6 +10973,7 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all                  =*/ false,
         /*.embedding                   =*/ false,
         /*.offload_kqv                 =*/ true,
+        /*.do_pooling                  =*/ true,
     };
 
     return result;
@@ -11105,6 +11129,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.mul_mat_q        = params.mul_mat_q;
     cparams.offload_kqv      = params.offload_kqv;
+    cparams.do_pooling       = params.do_pooling;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
@@ -11252,7 +11277,7 @@ struct llama_context * llama_new_context_with_model(
         // resized during inference, reserve maximum
         ctx->logits.reserve(hparams.n_vocab*cparams.n_batch);
 
-        if (params.embedding){
+        if (params.embedding) {
             ctx->embedding.resize(hparams.n_embd);
         }
 
@@ -11270,7 +11295,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
             ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
             ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
-            ctx->inp_sum     = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch);
+            ctx->inp_sum     = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
 
             ggml_set_name(ctx->inp_tokens,  "inp_tokens");
             ggml_set_name(ctx->inp_embd,    "inp_embd");
@@ -12128,6 +12153,10 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embedding.data();
 }
 
+float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+    return ctx->embedding.data() + i*ctx->model.hparams.n_embd;
+}
+
 const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
     return model->vocab.id_to_token[token].text.c_str();
 }
diff --git a/llama.h b/llama.h
index 367e8f1a105a5f8cc9f7d8d046f29a467e6262a5..5ef78ec968b1c17b64546b4ac3d5e6e90fb1bb65 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -236,6 +236,7 @@ extern "C" {
         bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
     };
 
     // model quantization parameters
@@ -628,6 +629,10 @@ extern "C" {
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
+    // Get the embeddings for the ith sequence
+    // llama_get_embeddings(ctx) + i*n_embd
+    LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
+
     //
     // Vocab
     //