]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : streamline embeddings from "non-embedding" models (#8087)
authorDouglas Hanley <redacted>
Fri, 5 Jul 2024 07:05:56 +0000 (02:05 -0500)
committerGitHub <redacted>
Fri, 5 Jul 2024 07:05:56 +0000 (10:05 +0300)
common/common.cpp
common/common.h
include/llama.h
src/llama.cpp

index 1f7359980a94db84a51f3933b55c68150b07937f..42ae85f5f11c214ab56d090caa28e5026e9fc0ed 100644 (file)
@@ -472,6 +472,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
+    if (arg == "--attention") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
+        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; }
+        else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--defrag-thold" || arg == "-dt") {
         CHECK_ARG
         params.defrag_thold = std::stof(argv[i]);
@@ -1468,8 +1476,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                                         "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" });
 
     options.push_back({ "embedding" });
-    options.push_back({ "embedding",   "       --pooling {none,mean,cls}",
+    options.push_back({ "embedding",   "       --pooling {none,mean,cls,last}",
                                                                         "pooling type for embeddings, use model default if unspecified" });
+    options.push_back({ "embedding",   "       --attention {causal,non-causal}",
+                                                                        "attention type for embeddings, use model default if unspecified" });
 
     options.push_back({ "context hacking" });
     options.push_back({ "*",           "       --rope-scaling {none,linear,yarn}",
@@ -2175,6 +2185,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
+    cparams.attention_type    = params.attention_type;
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
index 65c0ef81adf7cac183384781d402de56a9fab19a..ac86483ffa1c946d6e5b43bd9f64726df3accf67 100644 (file)
@@ -99,6 +99,7 @@ struct gpt_params {
     enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+    enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
     // // sampling parameters
     struct llama_sampling_params sparams;
index aca79d2b864b8b7323babda8b2bb65f8a7717737..7a9a25609c9c1949794f79bc772a533153da26c2 100644 (file)
@@ -180,6 +180,12 @@ extern "C" {
         LLAMA_POOLING_TYPE_LAST = 3,
     };
 
+    enum llama_attention_type {
+        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
+        LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
+        LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
+    };
+
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE    = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER   = 1, // split layers and KV across GPUs
@@ -297,6 +303,7 @@ extern "C" {
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
+        enum llama_attention_type    attention_type;    // attention type to use for embeddings
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float    rope_freq_base;   // RoPE base frequency, 0 = from model
index b7ef8297563833256484cc2a3fa1197c7d967018..132055888b9f16c9fbd8677c932a98ee8325dd2b 100644 (file)
@@ -13840,7 +13840,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_mean);
@@ -13872,7 +13872,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_cls);
@@ -13893,7 +13893,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_cls);
@@ -14181,14 +14181,15 @@ static int llama_decode_internal(
     std::vector<llama_seq_id *>            seq_id_arr;
     std::vector<std::vector<llama_seq_id>> seq_id;
 
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
     // count outputs
-    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
-        n_outputs = n_tokens_all;
-    } else if (batch_all.logits) {
+    if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all) {
+    } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
@@ -14234,7 +14235,7 @@ static int llama_decode_internal(
         {
             int32_t n_outputs_new = 0;
 
-            if (u_batch.logits) {
+            if (u_batch.logits && !embd_pooled) {
                 for (uint32_t i = 0; i < n_tokens; i++) {
                     n_outputs_new += u_batch.logits[i] != 0;
                 }
@@ -18533,6 +18534,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -18785,7 +18787,6 @@ struct llama_context * llama_new_context_with_model(
     }
 
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-    cparams.causal_attn = hparams.causal_attn;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -18795,6 +18796,12 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
+    } else {
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }