]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : allow for user specified embedding pooling type (#5849)
authorDouglas Hanley <redacted>
Sun, 3 Mar 2024 10:40:27 +0000 (04:40 -0600)
committerGitHub <redacted>
Sun, 3 Mar 2024 10:40:27 +0000 (12:40 +0200)
* allow for user specified pooling type

* llama : use enum types over int

---------

Co-authored-by: Georgi Gerganov <redacted>
common/common.cpp
common/common.h
convert-hf-to-gguf.py
llama.cpp
llama.h

index 1c0b7c403b936bec58f9c491534c273be8632443..dbe7e9229b770ed30ff9305aa52ee18f346dd227 100644 (file)
@@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.yarn_beta_slow = std::stof(argv[i]);
+        } else if (arg == "--pooling") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+            else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+            else { invalid_param = true; break; }
         } else if (arg == "--defrag-thold" || arg == "-dt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1014,6 +1024,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  --pooling {none,mean,cls}\n");
+    printf("                        pooling type for embeddings, use model default if unspecified\n");
     printf("  -dt N, --defrag-thold N\n");
     printf("                        KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -1296,6 +1308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_fast    = params.yarn_beta_fast;
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
+    cparams.pooling_type      = params.pooling_type;
     cparams.defrag_thold      = params.defrag_thold;
     cparams.offload_kqv       = !params.no_kv_offload;
 
index ab62bdb822d71b4776b37357536925aacdf5fe00..d3682b7adae70f4f652423b19c23fbef862df739 100644 (file)
@@ -76,8 +76,11 @@ struct gpt_params {
     float   yarn_beta_slow        = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx         = 0;     // YaRN original context length
     float   defrag_thold          = -1.0f; // KV cache defragmentation threshold
-    int32_t rope_scaling_type     = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
-    ggml_numa_strategy numa       = GGML_NUMA_STRATEGY_DISABLED;
+
+    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
+
+    llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
+    llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
 
     // // sampling parameters
     struct llama_sampling_params sparams;
index fa9d4f22fd33c71dde6f5dafdbc1d507fb9b458b..ffdba74441e196d4c0693c2fe109e92f4bc066a2 100755 (executable)
@@ -1644,16 +1644,17 @@ class BertModel(Model):
         self.gguf_writer.add_causal_attention(False)
 
         # get pooling path
-        with open(self.dir_model / "modules.json", encoding="utf-8") as f:
-            modules = json.load(f)
         pooling_path = None
-        for mod in modules:
-            if mod["type"] == "sentence_transformers.models.Pooling":
-                pooling_path = mod["path"]
-                break
+        module_path = self.dir_model / "modules.json"
+        if module_path.is_file():
+            with open(module_path, encoding="utf-8") as f:
+                modules = json.load(f)
+            for mod in modules:
+                if mod["type"] == "sentence_transformers.models.Pooling":
+                    pooling_path = mod["path"]
+                    break
 
         # get pooling type
-        pooling_type = gguf.PoolingType.NONE
         if pooling_path is not None:
             with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
                 pooling = json.load(f)
@@ -1663,8 +1664,7 @@ class BertModel(Model):
                 pooling_type = gguf.PoolingType.CLS
             else:
                 raise NotImplementedError("Only MEAN and CLS pooling types supported")
-
-        self.gguf_writer.add_pooling_type(pooling_type)
+            self.gguf_writer.add_pooling_type(pooling_type)
 
     def set_vocab(self):
         path = self.dir_model
index 41d0000da7f9eeb4d54e2f41b1d19d04715983f6..c1f015791e826e303bc8d39fd6a9b5f1b8a86125 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -873,16 +873,16 @@ struct LLM_TN {
 // gguf helpers
 //
 
-static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
+static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
 };
 
-static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
+static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
-            return kv.first;
+            return (llama_rope_scaling_type) kv.first;
         }
     }
 
@@ -1612,7 +1612,6 @@ struct llama_hparams {
     float    rope_freq_base_train;
     float    rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
-    int32_t  rope_scaling_type_train;
 
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
@@ -1620,8 +1619,9 @@ struct llama_hparams {
     bool causal_attn = true;
     bool need_kq_pos = false;
 
-    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
+    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
+    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
+    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only    != other.vocab_only)    return true;
@@ -1670,8 +1670,8 @@ struct llama_cparams {
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
 
-    float    rope_freq_base;
-    float    rope_freq_scale;
+    float rope_freq_base;
+    float rope_freq_scale;
 
     uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
@@ -1683,7 +1683,7 @@ struct llama_cparams {
     float defrag_thold;
 
     bool offload_kqv;
-    bool do_pooling;
+    enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
@@ -2933,7 +2933,11 @@ template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
     const bool found = get_key(kid, tmp, required);
-    result = (enum llama_pooling_type) tmp;
+    if (found) {
+        result = (enum llama_pooling_type) tmp;
+    } else {
+        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
+    }
     return found;
 }
 
@@ -3210,7 +3214,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -5175,7 +5179,7 @@ struct llm_build_context {
         n_kv             (worst_case ? n_ctx            : kv_self.n),
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
-        pooling_type     (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
+        pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
@@ -8015,7 +8019,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -8043,7 +8047,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -11846,6 +11850,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -11861,7 +11866,6 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all                  =*/ false,
         /*.embedding                   =*/ false,
         /*.offload_kqv                 =*/ true,
-        /*.do_pooling                  =*/ true,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
@@ -12012,7 +12016,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
     cparams.defrag_thold     = params.defrag_thold;
     cparams.offload_kqv      = params.offload_kqv;
-    cparams.do_pooling       = params.do_pooling;
+    cparams.pooling_type     = params.pooling_type;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
@@ -12038,6 +12042,14 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        } else {
+            cparams.pooling_type = hparams.pooling_type;
+        }
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
diff --git a/llama.h b/llama.h
index 6406b52705e7d9d63861eaaca08c9a4b6ba70505..70da4cb3f0ff6685e8a1b3f9edb8cc2e477cc1f8 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -129,6 +129,7 @@ extern "C" {
     };
 
     enum llama_pooling_type {
+        LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
@@ -236,7 +237,10 @@ extern "C" {
         uint32_t n_batch;           // prompt processing maximum batch size
         uint32_t n_threads;         // number of threads to use for generation
         uint32_t n_threads_batch;   // number of threads to use for batch processing
-        int32_t  rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+
+        enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+        enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
+                                                        // (ignored if no pooling layer)
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float    rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -258,7 +262,6 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted