]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add reranking support (#9510)
authorGeorgi Gerganov <redacted>
Sat, 28 Sep 2024 14:42:03 +0000 (17:42 +0300)
committerGitHub <redacted>
Sat, 28 Sep 2024 14:42:03 +0000 (17:42 +0300)
* py : add XLMRobertaForSequenceClassification [no ci]

* py : fix scalar-tensor conversion [no ci]

* py : fix position embeddings chop [no ci]

* llama : read new cls tensors [no ci]

* llama : add classigication head (wip) [no ci]

* llama : add "rank" pooling type

ggml-ci

* server : add rerank endpoint

ggml-ci

* llama : aboud ggml_repeat during classification

* rerank : cleanup + comments

* server : accept /rerank endpoint in addition to /v1/rerank [no ci]

* embedding : parse special tokens

* jina : support v1 reranker

* vocab : minor style

ggml-ci

* server : initiate tests for later

ggml-ci

* server : add docs

* llama : add comment [no ci]

* llama : fix uninitialized tensors

* ci : add rerank tests

ggml-ci

* add reranking test

* change test data

* Update examples/server/server.cpp

Co-authored-by: Xuan Son Nguyen <redacted>
* add `--reranking` argument

* update server docs

* llama : fix comment [no ci]

ggml-ci

---------

Co-authored-by: Xuan Son Nguyen <redacted>
Co-authored-by: Xuan Son Nguyen <redacted>
18 files changed:
ci/run.sh
common/arg.cpp
common/common.cpp
common/common.h
convert_hf_to_gguf.py
convert_hf_to_gguf_update.py
examples/embedding/embedding.cpp
examples/server/README.md
examples/server/server.cpp
examples/server/tests/features/embeddings.feature
examples/server/tests/features/rerank.feature [new file with mode: 0644]
examples/server/tests/features/steps/steps.py
examples/server/utils.hpp
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
include/llama.h
src/llama-vocab.cpp
src/llama.cpp

index 1ac08ee4e19a8601c2117fdce598fdf6faeaa57f..7d241ecc0ea063aeea8d9eb014f19ff8f66896e6 100755 (executable)
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -712,6 +712,81 @@ function gg_run_embd_bge_small {
     set +e
 }
 
+function gg_sum_embd_bge_small {
+    gg_printf '### %s\n\n' "${ci}"
+
+    gg_printf 'BGE Small (BERT):\n'
+    gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+    gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
+    gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
+}
+
+# rerank_tiny
+
+function gg_run_rerank_tiny {
+    cd ${SRC}
+
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
+
+    gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json
+
+    path_models="../models-mnt/rerank-tiny"
+
+    rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release
+
+    set -e
+
+    (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
+    (time make -j$(nproc)                                    ) 2>&1 | tee -a $OUT/${ci}-make.log
+
+    python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf
+
+    model_f16="${path_models}/ggml-model-f16.gguf"
+
+    (time ./bin/llama-embedding --model ${model_f16}  -p "what is panda?</s><s>hi\nwhat is panda?</s><s>it's a bear\nwhat is panda?</s><s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
+
+    # sample output
+    # rerank score 0:    0.029
+    # rerank score 1:    0.029
+    # rerank score 2:    0.135
+
+    # check that the score is in the range [$3, $4]
+    function check_score {
+        qnt="$1"
+        score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)
+
+        if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then
+            printf '  - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4"
+            return 20
+        fi
+
+        printf '  - %s @ %s OK\n' "$qnt" "$score"
+        return 0
+    }
+
+    check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
+    check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
+    check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log
+
+    set +e
+}
+
+function gg_sum_rerank_tiny {
+    gg_printf '### %s\n\n' "${ci}"
+
+    gg_printf 'Rerank Tiny (Jina):\n'
+    gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+    gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)"
+}
+
 function gg_check_build_requirements {
     if ! command -v cmake &> /dev/null; then
         gg_printf 'cmake not found, please install'
@@ -726,15 +801,6 @@ function gg_check_build_requirements {
     fi
 }
 
-function gg_sum_embd_bge_small {
-    gg_printf '### %s\n\n' "${ci}"
-
-    gg_printf 'BGE Small (BERT):\n'
-    gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
-    gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
-    gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
-}
-
 ## main
 
 export LLAMA_LOG_PREFIX=1
@@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release
 
 if [ -z ${GG_BUILD_LOW_PERF} ]; then
     test $ret -eq 0 && gg_run embd_bge_small
+    test $ret -eq 0 && gg_run rerank_tiny
 
     if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
         test $ret -eq 0 && gg_run test_scripts_debug
index 6880117ed800129d3524de8fe8ded072e6912210..8266a16c261c51a9ace18916978956eb49de14cd 100644 (file)
@@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
         params.kv_overrides.back().key[0] = 0;
     }
 
+    if (params.reranking && params.embedding) {
+        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
+    }
+
     return true;
 }
 
@@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.verbose_prompt = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ));
     add_opt(llama_arg(
         {"--no-display-prompt"},
         format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
@@ -1093,13 +1097,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         }
     ).set_sparam());
     add_opt(llama_arg(
-        {"--pooling"}, "{none,mean,cls,last}",
+        {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
         [](gpt_params & params, const std::string & value) {
             /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
             else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
-            else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS;  }
             else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
+            else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
             else { throw std::invalid_argument("invalid value"); }
         }
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
@@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.embedding = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    add_opt(llama_arg(
+        {"--reranking", "--rerank"},
+        format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+        [](gpt_params & params) {
+            params.reranking = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
     add_opt(llama_arg(
         {"--api-key"}, "KEY",
         "API key to use for authentication (default: none)",
index 8d0ed4f95a7374f4cada4f101cf6458592a41e03..e2b8574bf77d7ad9e2c8b2a4bbd615276d4e915c 100644 (file)
@@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.flash_attn        = params.flash_attn;
     cparams.no_perf           = params.no_perf;
 
+    if (params.reranking) {
+        cparams.embeddings    = true;
+        cparams.pooling_type  = LLAMA_POOLING_TYPE_RANK;
+    }
+
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
 
index cb87c4479ed0a029c8e29e062260c020d94d347d..8b84cf9ad45ee698eeeb31e9e2b959bb4bf83c5b 100644 (file)
@@ -271,6 +271,7 @@ struct gpt_params {
     int32_t embd_normalize = 2;     // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out   = "";    // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep   = "\n";  // separator of embendings
+    bool reranking         = false; // enable reranking support on server
 
     // server params
     int32_t port           = 8080;         // server listens on this network port
index 2cd5a8c11bc18501a38e36b96e18daf5fa7d11df..96a8830e9e7a3efa4ecfa1a8a022bd49d822c77f 100755 (executable)
@@ -291,8 +291,13 @@ class Model:
                     bid = int(part)
                     break
 
-            for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
-                data: np.ndarray  # type hint
+            for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
+                data = data_torch.squeeze().numpy()
+
+                # if data ends up empty, it means data_torch was a scalar tensor -> restore
+                if len(data.shape) == 0:
+                    data = data_torch.numpy()
+
                 n_dims = len(data.shape)
                 data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
 
@@ -592,6 +597,9 @@ class Model:
         if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
             # ref: https://huggingface.co/databricks/dbrx-base
             res = "dbrx"
+        if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448":
+            # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+            res = "jina-v1-en"
         if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
             res = "jina-v2-en"
@@ -2601,7 +2609,7 @@ class NomicBertModel(BertModel):
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
 
 
-@Model.register("XLMRobertaModel")
+@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
 class XLMRobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
@@ -2699,6 +2707,11 @@ class XLMRobertaModel(BertModel):
         self.gguf_writer.add_add_eos_token(True)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # if name starts with "roberta.", remove the prefix
+        # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
+        if name.startswith("roberta."):
+            name = name[8:]
+
         # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
         if name == "embeddings.position_embeddings.weight":
             if self._position_offset is not None:
@@ -3110,6 +3123,14 @@ class JinaBertV2Model(BertModel):
         self.gguf_writer.add_add_bos_token(True)
         self.gguf_writer.add_add_eos_token(True)
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # if name starts with "bert.", remove the prefix
+        # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+        if name.startswith("bert."):
+            name = name[5:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
 
 @Model.register("OpenELMForCausalLM")
 class OpenELMModel(Model):
index 4d11059f374d262af01ef7abc9d10f281284d244..022354a3b624ecdd1fd6da536ebb7f77ec7405d0 100755 (executable)
@@ -81,6 +81,7 @@ models = [
     {"name": "qwen2",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
     {"name": "olmo",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
     {"name": "dbrx",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
+    {"name": "jina-v1-en",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", },
     {"name": "jina-v2-en",     "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
     {"name": "jina-v2-es",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
     {"name": "jina-v2-de",     "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
index a438dcb5adf34f71085e8c72306ae789e4e100c5..7349268223827727770506cd3435594ff06e1745 100644 (file)
@@ -135,7 +135,7 @@ int main(int argc, char ** argv) {
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
     for (const auto & prompt : prompts) {
-        auto inp = ::llama_tokenize(ctx, prompt, true, false);
+        auto inp = ::llama_tokenize(ctx, prompt, true, true);
         if (inp.size() > n_batch) {
             LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                     __func__, (long long int) inp.size(), (long long int) n_batch);
@@ -234,6 +234,11 @@ int main(int argc, char ** argv) {
                 }
                 LOG("\n");
             }
+        } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+            for (int j = 0; j < n_embd_count; j++) {
+                // NOTE: if you change this log - update the tests in ci/run.sh
+                LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+            }
         } else {
             // print the first part of the embeddings or for a single prompt, the full embedding
             for (int j = 0; j < n_prompts; j++) {
index dfca07f9888247822f3e5fecef247dcec42b781a..951c4a44c6058e781a8eccde1b25af3f3737dfbc 100644 (file)
@@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
 **Features:**
  * LLM inference of F16 and quantized models on GPU and CPU
  * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
+ * Reranking endoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510)
  * Parallel decoding with multi-user support
  * Continuous batching
  * Multimodal (wip)
@@ -23,6 +24,7 @@ The project is under active development, and we are [looking for feedback and co
 | -------- | ----------- |
 | `-h, --help, --usage` | print usage and exit |
 | `--version` | show version and build info |
+| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
 | `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
 | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
 | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
@@ -130,7 +132,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
 | `-sp, --special` | special tokens output enabled (default: false) |
 | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
-| `--pooling {none,mean,cls,last}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
+| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
 | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
 | `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
 | `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
@@ -138,6 +140,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
 | `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
 | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
+| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
 | `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
 | `--api-key-file FNAME` | path to file containing API keys (default: none) |
 | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
@@ -152,6 +155,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
 
+
 Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.
 
 Example usage of docker compose with environment variables:
@@ -478,6 +482,39 @@ The same as [the embedding example](../embedding) does.
 
     `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
 
+### POST `/reranking`: Rerank documents according to a given query
+
+Similar to https://jina.ai/reranker/ but might change in the future.
+Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options.
+
+    *Options:*
+
+    `query`: The query against which the documents will be ranked.
+
+    `documents`: An array strings representing the documents to be ranked.
+
+    *Aliases:*
+      - `/rerank`
+      - `/v1/rerank`
+      - `/v1/reranking`
+
+    *Examples:*
+
+    ```shell
+    curl http://127.0.0.1:8012/v1/rerank \
+        -H "Content-Type: application/json" \
+        -d '{
+            "model": "some-model",
+                "query": "What is panda?",
+                "top_n": 3,
+                "documents": [
+                    "hi",
+                "it is a bear",
+                "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
+                ]
+        }' | jq
+    ```
+
 ### POST `/infill`: For code infilling.
 
 Takes a prefix and a suffix and returns the predicted completion as stream.
index 61ff09bb2b40f90a4d354d7deda2c64a9da06d94..f343cc252f89a0d7e141dee15a1c21798a6246cb 100644 (file)
@@ -92,6 +92,7 @@ enum server_task_type {
 enum server_task_cmpl_type {
     SERVER_TASK_CMPL_TYPE_NORMAL,
     SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_RERANK,
     SERVER_TASK_CMPL_TYPE_INFILL,
 };
 
@@ -172,6 +173,7 @@ struct server_slot {
     std::vector<completion_token_output> generated_token_probs;
 
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
     bool has_next_token = true;
     bool truncated      = false;
     bool stopped_eos    = false;
@@ -954,8 +956,17 @@ struct server_context {
                 slot.prompt = *prompt;
             } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
                 slot.prompt = prompt->at(0);
+            } else if (prompt->is_array() && prompt->size() > 1) {
+                // array of strings
+                for (const auto & el : *prompt) {
+                    if (!el.is_string()) {
+                        send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                        return false;
+                    }
+                }
+                slot.prompt = *prompt;
             } else {
-                send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         }
@@ -1389,6 +1400,7 @@ struct server_context {
 
                 res.data = json {
                     {"embedding", std::vector<float>(n_embd, 0.0f)},
+                    {"index",     slot.index},
                 };
 
                 continue;
@@ -1407,6 +1419,44 @@ struct server_context {
         queue_results.send(res);
     }
 
+    void send_rerank(const server_slot & slot, const llama_batch & batch) {
+        server_task_result res;
+        res.id       = slot.id_task;
+        res.error    = false;
+        res.stop     = true;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+                continue;
+            }
+
+            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            if (embd == NULL) {
+                embd = llama_get_embeddings_ith(ctx, i);
+            }
+
+            if (embd == NULL) {
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+                res.data = json {
+                    {"index", slot.index},
+                    {"score", -1e6},
+                };
+
+                continue;
+            }
+
+            res.data = json {
+                {"index", slot.index},
+                {"score", embd[0]},
+            };
+        }
+
+        SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
+
+        queue_results.send(res);
+    }
+
     //
     // Functions to create new task(s) and receive result(s)
     //
@@ -1442,13 +1492,27 @@ struct server_context {
         // otherwise, it's a multiple-prompt task, we break it into smaller tasks
         else if (prompt.is_array()) {
             std::vector<json> prompts = prompt;
-            for (size_t i = 0; i < prompts.size(); i++) {
-                const auto & e = prompts[i];
-                if (e.is_string() || json_is_array_of_numbers(e)) {
-                    data["index"] = i;
-                    create_task(data, true, e);
-                } else {
-                    throw std::runtime_error(error_msg);
+            if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                // prompts[0] is the question
+                // the rest are the answers/documents
+                SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
+                for (size_t i = 1; i < prompts.size(); i++) {
+                    json qd;
+                    qd.push_back(prompts[0]);
+                    qd.push_back(prompts[i]);
+                    data["index"] = i - 1;
+                    create_task(data, true, qd);
+                }
+            } else {
+                SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
+                for (size_t i = 0; i < prompts.size(); i++) {
+                    const auto & e = prompts[i];
+                    if (e.is_string() || json_is_array_of_numbers(e)) {
+                        data["index"] = i;
+                        create_task(data, true, e);
+                    } else {
+                        throw std::runtime_error(error_msg);
+                    }
                 }
             }
         }
@@ -1492,7 +1556,9 @@ struct server_context {
                 return;
             }
 
-            size_t idx = result.data["index"];
+            const size_t idx = result.data["index"];
+            GGML_ASSERT(idx < results.size() && "index out of range");
+
             results[idx] = result;
         }
         result_handler(results);
@@ -1903,6 +1969,7 @@ struct server_context {
         // track if this is an embedding or non-embedding batch
         // if we've added sampled tokens above, we are in non-embedding mode
         // -1: none, 0: non-embedding, 1: embedding
+        // TODO: make enum
         int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
 
         // next, batch any pending prompts without exceeding n_batch
@@ -1951,6 +2018,29 @@ struct server_context {
                             }
 
                             prompt_tokens = embd_inp;
+                        } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                            // require slot.prompt to be array of 2 strings
+                            if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                                SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                                slot.release();
+                                send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                                continue;
+                            }
+
+                            // prompt: <s>query</s><s>doc</s>
+                            prompt_tokens.clear();
+                            prompt_tokens.push_back(llama_token_bos(model));
+                            {
+                                const auto part = tokenize(slot.prompt[0], false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
+                            prompt_tokens.push_back(llama_token_bos(model));
+                            {
+                                const auto part = tokenize(slot.prompt[1], false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
                         } else {
                             prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                         }
@@ -1970,7 +2060,7 @@ struct server_context {
                             continue;
                         }
 
-                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                             // this prompt is too large to process - discard it
                             if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.release();
@@ -2048,7 +2138,8 @@ struct server_context {
                         slot.n_prompt_tokens_processed = 0;
                     }
 
-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                    // non-causal tasks require to fit the entire prompt in the physical batch
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                         // cannot fit the prompt in the current batch - will try next iter
                         if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                             continue;
@@ -2056,7 +2147,10 @@ struct server_context {
                     }
 
                     // check that we are in the right batch_type, if not defer the slot
-                    bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
+                    const bool slot_type =
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK     ? 1 : 0;
+
                     if (batch_type == -1) {
                         batch_type = slot_type;
                     } else if (batch_type != slot_type) {
@@ -2229,6 +2323,13 @@ struct server_context {
                         continue; // continue loop of slots
                     }
 
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        send_rerank(slot, batch_view);
+                        slot.release();
+                        slot.i_batch = -1;
+                        continue; // continue loop of slots
+                    }
+
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
                 } else if (slot.state != SLOT_STATE_GENERATING) {
@@ -2787,8 +2888,8 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -2848,8 +2949,8 @@ int main(int argc, char ** argv) {
 
     // TODO: maybe merge this function with "handle_completions_generic"
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
@@ -2973,6 +3074,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        // TODO: somehow clean up this checks in the future
+        if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);
         bool is_openai = false;
 
@@ -3023,6 +3129,79 @@ int main(int argc, char ** argv) {
         res_ok(res, root);
     };
 
+    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        if (!ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+        const json body = json::parse(req.body);
+
+        // TODO: implement
+        //int top_n = 1;
+        //if (body.count("top_n") != 1) {
+        //    top_n = body.at("top_n");
+        //} else {
+        //    res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+        //    return;
+        //}
+
+        json query;
+        if (body.count("query") == 1) {
+            query = body.at("query");
+            if (!query.is_string()) {
+                res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+        if (documents.empty()) {
+            res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // construct prompt object: array of ["query", "doc0", "doc1", ...]
+        json prompt;
+        prompt.push_back(query);
+        for (const auto & doc : documents) {
+            prompt.push_back(doc);
+        }
+
+        LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+        // create and queue the task
+        json responses = json::array();
+        bool error = false;
+        {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            // get the result
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                for (const auto & res : results) {
+                    responses.push_back(res.data);
+                }
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+                error = true;
+            });
+        }
+
+        if (error) {
+            return;
+        }
+
+        // write JSON response
+        json root = format_response_rerank(body, responses);
+        res_ok(res, root);
+    };
+
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
         for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@@ -3119,6 +3298,10 @@ int main(int argc, char ** argv) {
     svr->Post("/embedding",           handle_embeddings); // legacy
     svr->Post("/embeddings",          handle_embeddings);
     svr->Post("/v1/embeddings",       handle_embeddings);
+    svr->Post("/rerank",              handle_rerank);
+    svr->Post("/reranking",           handle_rerank);
+    svr->Post("/v1/rerank",           handle_rerank);
+    svr->Post("/v1/reranking",        handle_rerank);
     svr->Post("/tokenize",            handle_tokenize);
     svr->Post("/detokenize",          handle_detokenize);
     // LoRA adapters hotswap
index 818ea3beb90cd3131f54270e499de9dd8b37b27d..f4fe2ee4335ff8993ee08dec88606addd131be21 100644 (file)
@@ -15,7 +15,7 @@ Feature: llama.cpp server
     And   128 as batch size
     And   128 as ubatch size
     And   512 KV cache size
-    And   embeddings extraction
+    And   enable embeddings endpoint
     Then  the server is starting
     Then  the server is healthy
 
diff --git a/examples/server/tests/features/rerank.feature b/examples/server/tests/features/rerank.feature
new file mode 100644 (file)
index 0000000..c36cc8e
--- /dev/null
@@ -0,0 +1,42 @@
+@llama.cpp
+@rerank
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model url https://huggingface.co/ggml-org/models/resolve/main/jina-reranker-v1-tiny-en/ggml-model-f16.gguf
+    And   a model file jina-reranker-v1-tiny-en.gguf
+    And   a model alias jina-reranker-v1-tiny-en
+    And   42 as server seed
+    And   2 slots
+    And   512 as batch size
+    And   512 as ubatch size
+    And   512 KV cache size
+    And   enable reranking endpoint
+    Then  the server is starting
+    Then  the server is healthy
+
+  Scenario: Rerank
+    Given a rerank query:
+      """
+      Machine learning is
+      """
+    And   a rerank document:
+      """
+      A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.
+      """
+    And   a rerank document:
+      """
+      Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.
+      """
+    And   a rerank document:
+      """
+      Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.
+      """
+    And   a rerank document:
+      """
+      Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.
+      """
+    When  reranking request
+    Then  reranking results are returned
+    Then  reranking highest score is index 2 and lowest score is index 3
index 0fea0fe87b79950bee4f7cac5cbf7fea31955247..2611614ba36333b37af0e343f2e9254e838ead8e 100644 (file)
@@ -68,6 +68,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.server_api_key = None
     context.server_continuous_batching = False
     context.server_embeddings = False
+    context.server_reranking = False
     context.server_metrics = False
     context.server_process = None
     context.seed = None
@@ -83,6 +84,10 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.concurrent_tasks = []
     context.prompts = []
 
+    context.reranking_query = None
+    context.reranking_documents = []
+    context.reranking_results = None
+
 
 @step('a model file {hf_file} from HF repo {hf_repo}')
 def step_download_hf_model(context, hf_file: str, hf_repo: str):
@@ -172,10 +177,13 @@ def step_server_continuous_batching(context):
     context.server_continuous_batching = True
 
 
-@step('embeddings extraction')
+@step('enable embeddings endpoint')
 def step_server_embeddings(context):
     context.server_embeddings = True
 
+@step('enable reranking endpoint')
+def step_server_reranking(context):
+    context.server_reranking = True
 
 @step('prometheus compatible metrics exposed')
 def step_server_metrics(context):
@@ -452,6 +460,14 @@ def step_impl(context, n_ga_w):
 def step_prompt_passkey(context):
     context.prompt_passkey = context_text(context)
 
+@step('a rerank query')
+def step_set_rerank_query(context):
+    context.reranking_query = context_text(context)
+    context.reranking_documents = []
+
+@step('a rerank document')
+def step_set_rerank_document(context):
+    context.reranking_documents.append(context_text(context))
 
 @step('{n_prompts:d} fixed prompts')
 def step_fixed_prompts(context, n_prompts):
@@ -619,6 +635,22 @@ async def step_compute_embedding(context):
     context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url)
 
 
+@step('reranking request')
+@async_run_until_complete
+async def step_compute_reranking(context):
+    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
+        async with session.post(f'{context.base_url}/reranking',
+                                json={
+                                    "query": context.reranking_query,
+                                    "documents": context.reranking_documents,
+                                }) as response:
+            if response.status == 200:
+                response_json = await response.json()
+                context.reranking_results = response_json['results']
+            else:
+                context.reranking_results = response.status
+
+
 @step('all embeddings are the same')
 @async_run_until_complete
 async def step_all_embeddings_are_the_same(context):
@@ -704,6 +736,24 @@ async def all_embeddings_are_generated(context):
     for i in range(n_embedding_requests):
         assert_embeddings(context.tasks_result.pop().pop())
 
+@step('reranking results are returned')
+def reranking_results_are_returned(context):
+    assert len(context.reranking_results) == len(context.reranking_documents)
+
+@step('reranking highest score is index {idx_high:d} and lowest score is index {idx_low:d}')
+def reranking_results_are_returned(context, idx_high: int, idx_low: int):
+    max_score, max_idx = 0, 0
+    min_score, min_idx = 0, 0
+    for res in context.reranking_results:
+        if max_score < res['relevance_score']:
+            max_score = res['relevance_score']
+            max_idx   = res['index']
+        if min_score > res['relevance_score']:
+            min_score = res['relevance_score']
+            min_idx   = res['index']
+    print(context.reranking_results)
+    assert max_idx == idx_high
+    assert min_idx == idx_low
 
 @step('adding special tokens')
 def step_tokenize_set_add_special(context):
@@ -1362,6 +1412,8 @@ def start_server_background(context):
         server_args.append('--cont-batching')
     if context.server_embeddings:
         server_args.append('--embedding')
+    if context.server_reranking:
+        server_args.append('--reranking')
     if context.server_metrics:
         server_args.append('--metrics')
     if context.model_alias:
index f093f547ff2c1324c304523187cd662f61353c0d..47dfdfde512dc45e26110e9e9f2b58cf896e9fc4 100644 (file)
@@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json {
+        {"usage", json { // TODO: fill
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
@@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     return res;
 }
 
+static json format_response_rerank(const json & request, const json & ranks) {
+    json data = json::array();
+    int i = 0;
+    for (const auto & rank : ranks) {
+        data.push_back(json{
+            {"index",    i++},
+            {"relevance_score", json_value(rank, "score", 0.0)},
+        });
+    }
+
+    json res = json {
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", "list"},
+        {"usage", json { // TODO: fill
+            {"prompt_tokens", 0},
+            {"total_tokens", 0}
+        }},
+        {"results", data}
+    };
+
+    return res;
+}
+
 static bool is_valid_utf8(const std::string & str) {
     const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
     const unsigned char* end = bytes + str.length();
index 2fd2e9d2be82828fc51085fe8b9102aed034a9b1..ebe66a4a39f5f2b9d29e32c46509e62bb14e2c69 100644 (file)
@@ -345,6 +345,8 @@ class MODEL_TENSOR(IntEnum):
     ENC_FFN_DOWN         = auto()
     ENC_FFN_UP           = auto()
     ENC_OUTPUT_NORM      = auto()
+    CLS                  = auto() # classifier
+    CLS_OUT              = auto() # classifier output projection
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -504,6 +506,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ENC_FFN_DOWN:              "enc.blk.{bid}.ffn_down",
     MODEL_TENSOR.ENC_FFN_UP:                "enc.blk.{bid}.ffn_up",
     MODEL_TENSOR.ENC_OUTPUT_NORM:           "enc.output_norm",
+    MODEL_TENSOR.CLS:                       "cls",
+    MODEL_TENSOR.CLS_OUT:                   "cls.output",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -613,6 +617,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
+        MODEL_TENSOR.CLS,
+        MODEL_TENSOR.CLS_OUT,
     ],
     MODEL_ARCH.NOMIC_BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -644,6 +650,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.LAYER_OUT_NORM,
+        MODEL_TENSOR.CLS,
     ],
     MODEL_ARCH.MPT: [
         MODEL_TENSOR.TOKEN_EMBD,
index 5ef91f11d312f8b224b6f7b1cc378d2c7d34cf20..e7e9b6fd5efbceb034c15ad4daa6abf45630a4cf 100644 (file)
@@ -679,6 +679,15 @@ class TensorNameMap:
         MODEL_TENSOR.ENC_OUTPUT_NORM: (
             "encoder.final_layer_norm", # t5
         ),
+
+        MODEL_TENSOR.CLS: (
+            "classifier",       # jina
+            "classifier.dense", # roberta
+        ),
+
+        MODEL_TENSOR.CLS_OUT: (
+            "classifier.out_proj", # roberta
+        ),
     }
 
     # architecture-specific block mappings
index 4ea8a2c2b664b552f320f29706bb014611529856..7cae1bbe2e5b84b2dabb0f3dc6d46786d6555a2e 100644 (file)
@@ -193,6 +193,7 @@ extern "C" {
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
         LLAMA_POOLING_TYPE_LAST = 3,
+        LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph
     };
 
     enum llama_attention_type {
@@ -202,9 +203,9 @@ extern "C" {
     };
 
     enum llama_split_mode {
-        LLAMA_SPLIT_MODE_NONE    = 0, // single GPU
-        LLAMA_SPLIT_MODE_LAYER   = 1, // split layers and KV across GPUs
-        LLAMA_SPLIT_MODE_ROW     = 2, // split rows across GPUs
+        LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
+        LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
+        LLAMA_SPLIT_MODE_ROW   = 2, // split rows across GPUs
     };
 
     // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
@@ -872,7 +873,8 @@ extern "C" {
 
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // shape: [n_embd] (1-dimensional)
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
     //
index e4d844a73c216c67db9ebb6b30e5b96f68c1440d..d2f34ddd6b339c2ccdd78c7f8a5ba3501ba3d40f 100644 (file)
@@ -1554,7 +1554,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
             } break;
         case LLAMA_VOCAB_TYPE_UGM:
             {
-                if (add_special && vocab.tokenizer_add_bos != 0) {
+                if (add_special && vocab.tokenizer_add_bos) {
                     GGML_ASSERT(vocab.special_bos_id != -1);
                     output.push_back(vocab.special_bos_id);
                 }
@@ -1572,14 +1572,14 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
                     }
                 }
 
-                if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
                     LLAMA_LOG_WARN(
                         "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                         "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                         "Are you sure this is what you want?\n", __FUNCTION__);
                 }
 
-                if (add_special && vocab.tokenizer_add_eos == 1) {
+                if (add_special && vocab.tokenizer_add_eos) {
                     GGML_ASSERT(vocab.special_eos_id != -1);
                     output.push_back(vocab.special_eos_id);
                 }
@@ -1791,11 +1791,13 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                 // suppressing them like CONTROL tokens.
                 if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
                     return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
                     std::string result = token_text;
                     llama_unescape_whitespace(result);
                     return _try_copy(result.data(), result.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                }
+                if (attr & LLAMA_TOKEN_ATTR_BYTE) {
                     char byte = (char) llama_token_to_byte(vocab, token);
                     return _try_copy((char*) &byte, 1);
                 }
@@ -1806,7 +1808,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                 // suppressing them like CONTROL tokens.
                 if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
                     return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
                     std::string result = llama_decode_text(token_text);
                     return _try_copy(result.data(), result.size());
                 }
index 44afb31d74e531e54621bd1dd53a0c4f722f6632..c466cd88b7c1428da7ad0df6e819209634b338f7 100644 (file)
@@ -606,6 +606,8 @@ enum llm_tensor {
     LLM_TENSOR_ENC_FFN_DOWN,
     LLM_TENSOR_ENC_FFN_UP,
     LLM_TENSOR_ENC_OUTPUT_NORM,
+    LLM_TENSOR_CLS,
+    LLM_TENSOR_CLS_OUT,
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -793,6 +795,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_CLS,             "cls" },
+            { LLM_TENSOR_CLS_OUT,         "cls.output" },
         },
     },
     {
@@ -828,6 +832,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_CLS,             "cls" },
         },
     },
     {
@@ -2894,6 +2899,7 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab   vocab;
 
+    // TODO: should init all tensors to nullptr
     struct ggml_tensor * tok_embd;
     struct ggml_tensor * type_embd;
     struct ggml_tensor * pos_embd;
@@ -2906,6 +2912,12 @@ struct llama_model {
     struct ggml_tensor * output_b;
     struct ggml_tensor * output_norm_enc;
 
+    // classifier
+    struct ggml_tensor * cls;
+    struct ggml_tensor * cls_b;
+    struct ggml_tensor * cls_out   = nullptr;
+    struct ggml_tensor * cls_out_b = nullptr;
+
     std::vector<llama_layer> layers;
 
     llama_split_mode split_mode;
@@ -5604,11 +5616,11 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
                 hparams.f_max_alibi_bias = 8.0f;
 
                 switch (hparams.n_layer) {
-                    case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small
+                    case 4:  model.type = e_model::MODEL_33M;  break; // jina-embeddings-small
                     case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base
                 }
             } break;
@@ -6313,6 +6325,7 @@ static void llm_load_vocab(
                     tokenizer_pre == "phi-2"   ||
                     tokenizer_pre == "jina-es" ||
                     tokenizer_pre == "jina-de" ||
+                    tokenizer_pre == "jina-v1-en" ||
                     tokenizer_pre == "jina-v2-es" ||
                     tokenizer_pre == "jina-v2-de" ||
                     tokenizer_pre == "jina-v2-code") {
@@ -6439,7 +6452,12 @@ static void llm_load_vocab(
 
     for (uint32_t i = 0; i < n_vocab; i++) {
         std::string word = gguf_get_arr_str(ctx, token_idx, i);
-        GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
+
+        //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
+        if (word.empty()) {
+            LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
+            word = "[EMPTY_" + std::to_string(i) + "]";
+        }
 
         vocab.token_to_id[word] = i;
         vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
@@ -6520,8 +6538,14 @@ static void llm_load_vocab(
         vocab.linefeed_id = ids[0];
     } else {
         const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
-        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        vocab.linefeed_id = ids[0];
+
+        //GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        if (ids.empty()) {
+            LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
+            vocab.linefeed_id = vocab.special_pad_id;
+        } else {
+            vocab.linefeed_id = ids[0];
+        }
     }
 
     // special tokens
@@ -7394,6 +7418,12 @@ static bool llm_load_tensors(
 
                     if (model.arch == LLM_ARCH_BERT) {
                         model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train});
+
+                        model.cls   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        model.cls_out   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.cls_out_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS_OUT, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
                     }
 
                     model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
@@ -7446,6 +7476,8 @@ static bool llm_load_tensors(
                     model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
                     model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}); //LayerNorm bias
 
+                    model.cls   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.cls_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_CLS, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
                         ggml_context * ctx_split = ctx_for_layer_split(i);
@@ -10279,6 +10311,10 @@ struct llm_build_context {
         struct ggml_tensor * cur;
 
         switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
             case LLAMA_POOLING_TYPE_MEAN:
                 {
                     struct ggml_tensor * inp_mean = build_inp_mean();
@@ -10290,9 +10326,26 @@ struct llm_build_context {
                     struct ggml_tensor * inp_cls = build_inp_cls();
                     cur = ggml_get_rows(ctx0, inp, inp_cls);
                 } break;
-            case LLAMA_POOLING_TYPE_NONE:
+            case LLAMA_POOLING_TYPE_RANK:
                 {
-                    cur = inp;
+                    struct ggml_tensor * inp_cls = build_inp_cls();
+                    inp = ggml_get_rows(ctx0, inp, inp_cls);
+
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    GGML_ASSERT(model.cls       != nullptr);
+                    GGML_ASSERT(model.cls_b     != nullptr);
+
+                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
+                    cur = ggml_tanh(ctx0, cur);
+
+                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                    if (model.cls_out) {
+                        GGML_ASSERT(model.cls_out_b != nullptr);
+
+                        cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+                    }
                 } break;
             default:
                 {
@@ -11521,8 +11574,8 @@ struct llm_build_context {
             inpL = cur;
         }
 
-        // final output
         cur = inpL;
+
         cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -16682,7 +16735,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         }
     }
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
         const int64_t n_tokens     = batch.n_tokens;
         const int64_t n_seq_tokens = batch.n_seq_tokens;
         const int64_t n_seqs       = batch.n_seqs;
@@ -16697,7 +16752,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             const llama_seq_id seq_id = batch.seq_id[s][0];
 
             // TODO: adapt limits to n_seqs when batch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = batch.pos[s*n_seq_tokens + i];
@@ -17237,6 +17292,20 @@ static int llama_decode_internal(
                             ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                         }
                     } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rerank score - a single float per sequence
+                        auto & embd_seq_out = lctx.embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
                         GGML_ABORT("unknown pooling type");
@@ -17443,6 +17512,13 @@ static int llama_encode_internal(
                             ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                         }
                     } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
+                        //       wait for an encoder model that requires this pooling type in order to test it
+                        //       https://github.com/ggerganov/llama.cpp/pull/9510
+                        GGML_ABORT("RANK pooling not implemented yet");
+                    }
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
                         GGML_ABORT("unknown pooling type");