]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add support for qwen3 reranker (#15824)
authorDouglas Hanley <redacted>
Thu, 25 Sep 2025 08:53:09 +0000 (03:53 -0500)
committerGitHub <redacted>
Thu, 25 Sep 2025 08:53:09 +0000 (11:53 +0300)
common/common.cpp
convert_hf_to_gguf.py
examples/embedding/embedding.cpp
src/llama-arch.cpp
src/llama-graph.cpp
src/llama-graph.h
src/llama-model.cpp
tools/server/server.cpp
tools/server/utils.hpp

index 7dcc02443536a15a7e4498f406daf16685be0490..6ebd934154eb5b64b8bab586c98f8f12af181a14 100644 (file)
@@ -961,15 +961,13 @@ struct common_init_result common_init_from_params(common_params & params) {
 
         bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
         bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+        bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
 
-        if (!has_eos && !has_sep) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
+        if (!has_eos && !has_sep && !has_rerank_prompt) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
             ok = false;
         } else if (!has_eos) {
             LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
-        } else if (!has_sep) {
-            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
-            ok = false;
         }
 
         if (!ok) {
index 9ebd8567ad23fc169aa6499239940f619e245b99..54578a295b6453a38c783a318afdb54c412ab71e 100755 (executable)
@@ -3717,11 +3717,29 @@ class Qwen2MoeModel(TextModel):
 class Qwen3Model(Qwen2Model):
     model_arch = gguf.MODEL_ARCH.QWEN3
 
+    # extra logic for rerank models
+    is_rerank: bool = False
+    is_tied_embeddings: bool = False
+    token_false_id: int | None = None
+    token_true_id: int | None = None
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+
+        # track for intern-s1-mini
         hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
         self.origin_hf_arch = hparams.get('architectures', [None])[0]
 
+        # a bit hacky, but currently the only way to detect if this is a rerank model
+        # ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
+        readme_path = self.dir_model / "README.md"
+        readme_text = ""
+        if readme_path.exists():
+            with readme_path.open("r", encoding="utf-8") as f:
+                readme_text = f.read()
+        if "# Qwen3-Reranker" in readme_text:
+            self._find_rerank_config()
+
     def set_vocab(self):
         # deal with intern-s1-mini
         if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
@@ -3730,6 +3748,53 @@ class Qwen3Model(Qwen2Model):
 
         super().set_vocab()
 
+    def _find_rerank_config(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+
+        self.is_rerank = True
+        self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
+        self.token_false_id = tokenizer.convert_tokens_to_ids("no")
+        self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
+        self.sep_token_id = tokenizer.convert_tokens_to_ids("|")
+
+        assert self.token_false_id is not None and self.token_true_id is not None
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if self.is_rerank:
+            self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
+            self.gguf_writer.add_classifier_output_labels(["yes", "no"])
+            self.gguf_writer.add_chat_template([{
+                "name": "rerank",
+                "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
+                            "<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
+                            "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+            }])
+
+    def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
+        # extract "yes" and "no" tokens from the output lm_head tensor
+        false_row = data_torch[self.token_false_id]
+        true_row = data_torch[self.token_true_id]
+        return torch.stack([true_row, false_row], dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if self.is_rerank:
+            is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
+            is_real_head = not self.is_tied_embeddings and "lm_head" in name
+            if is_tied_head or is_real_head:
+                cls_out_head = (
+                    gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight",
+                    self._get_cls_out_tensor(data_torch),
+                )
+                if is_tied_head:
+                    embed = (self.map_tensor_name(name), data_torch)
+                    return [cls_out_head, embed]
+                if is_real_head:
+                    return [cls_out_head]
+
+        return super().modify_tensors(data_torch, name, bid)
+
 
 @ModelBase.register("Qwen3MoeForCausalLM")
 class Qwen3MoeModel(Qwen2MoeModel):
index 9ae7e4dbb05928ca6921f5a76c2cef548c02aed5..388908bc4d70a6adf608cc5fd24f866d555ecfd7 100644 (file)
@@ -95,8 +95,13 @@ int main(int argc, char ** argv) {
         params.n_batch = params.n_ctx;
     }
 
-    // For non-causal models, batch size must be equal to ubatch size
-    params.n_ubatch = params.n_batch;
+    // for non-causal models, batch size must be equal to ubatch size
+    if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
+        params.n_ubatch = params.n_batch;
+    }
+
+    // get max number of sequences per batch
+    const int n_seq_max = llama_max_parallel_sequences();
 
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -144,6 +149,7 @@ int main(int argc, char ** argv) {
     // get added sep and eos token, if any
     const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
     const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
 
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
@@ -153,21 +159,28 @@ int main(int argc, char ** argv) {
         // split classification pairs and insert expected separator tokens
         if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
             std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
-            std::string final_prompt;
-
-            for (size_t i = 0; i < pairs.size(); i++) {
-                final_prompt += pairs[i];
-                if (i != pairs.size() - 1) {
-                    if (!added_eos_token.empty()) {
-                        final_prompt += added_eos_token;
-                    }
-                    if (!added_sep_token.empty()) {
-                        final_prompt += added_sep_token;
+            if (rerank_prompt != nullptr) {
+                const std::string query = pairs[0];
+                const std::string doc = pairs[1];
+                std::string final_prompt = rerank_prompt;
+                string_replace_all(final_prompt, "{query}"   , query);
+                string_replace_all(final_prompt, "{document}", doc  );
+                inp = common_tokenize(vocab, final_prompt, true, true);
+            } else {
+                std::string final_prompt;
+                for (size_t i = 0; i < pairs.size(); i++) {
+                    final_prompt += pairs[i];
+                    if (i != pairs.size() - 1) {
+                        if (!added_eos_token.empty()) {
+                            final_prompt += added_eos_token;
+                        }
+                        if (!added_sep_token.empty()) {
+                            final_prompt += added_sep_token;
+                        }
                     }
                 }
+                inp = common_tokenize(ctx, final_prompt, true, true);
             }
-
-            inp = common_tokenize(ctx, final_prompt, true, true);
         } else {
             inp = common_tokenize(ctx, prompt, true, true);
         }
@@ -229,7 +242,7 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
             float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
index a4d2973ada5dc9a8289eb236b8a79373312b4f09..f992cdde86693e35f42d2377c25e0f9d9fe6a7b2 100644 (file)
@@ -721,6 +721,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
             { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
             { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_CLS_OUT,         "cls.output" },
             { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
             { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
             { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
index 9f2e417f1ff4b19be80e6371ff2048b7bad29c7c..d4faa2a63935045f1fb7b621222df33eba8c1a03 100644 (file)
@@ -204,7 +204,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         std::vector<int> target_pos(n_seqs_unq, -1);
         std::vector<int> target_row(n_seqs_unq, -1);
 
-        bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
+        const bool last = (
+             cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+            (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
+        );
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_pos pos = ubatch->pos[i];
@@ -1177,7 +1180,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_cls() const {
-    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
 
     auto & cur = inp->cls;
 
@@ -1877,34 +1880,32 @@ void llm_graph_context::build_pooling(
         case LLAMA_POOLING_TYPE_RANK:
             {
                 ggml_tensor * inp_cls = build_inp_cls();
-                inp = ggml_get_rows(ctx0, inp, inp_cls);
+                cur = ggml_get_rows(ctx0, inp, inp_cls);
 
+                // classification head
+                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
                 if (cls) {
-                    // classification head
-                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                    cur = ggml_mul_mat(ctx0, cls, inp);
+                    cur = ggml_mul_mat(ctx0, cls, cur);
                     if (cls_b) {
                         cur = ggml_add(ctx0, cur, cls_b);
                     }
                     cur = ggml_tanh(ctx0, cur);
+                }
 
-                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                    if (cls_out) {
-                        cur = ggml_mul_mat(ctx0, cls_out, cur);
-                        if (cls_out_b) {
-                            cur = ggml_add(ctx0, cur, cls_out_b);
-                        }
-                    }
-                } else if (cls_out) {
-                    // Single layer classification head (direct projection)
-                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
-                    cur = ggml_mul_mat(ctx0, cls_out, inp);
+                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                // Single layer classification head (direct projection)
+                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                if (cls_out) {
+                    cur = ggml_mul_mat(ctx0, cls_out, cur);
                     if (cls_out_b) {
                         cur = ggml_add(ctx0, cur, cls_out_b);
                     }
-                } else {
-                    GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
+                }
+
+                // softmax for qwen3 reranker
+                if (arch == LLM_ARCH_QWEN3) {
+                    cur = ggml_soft_max(ctx0, cur);
                 }
             } break;
         default:
index ca90fdf613f6de0074e465ca414e7b80c070e85a..34b984afeb04379e1e7b5aa4c792931fd3b03b9e 100644 (file)
@@ -206,7 +206,7 @@ public:
 
 class llm_graph_input_cls : public llm_graph_input_i {
 public:
-    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
     virtual ~llm_graph_input_cls() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -214,6 +214,7 @@ public:
     ggml_tensor * cls; // I32 [n_batch]
 
     const llama_cparams cparams;
+    const llm_arch arch;
 };
 
 class llm_graph_input_rs : public llm_graph_input_i {
index 48d9859c7d0be448adc96b21040fa171b9a60b5f..ce372133a93ebb7401c04a1d7d95912f369b4fa0 100644 (file)
@@ -3167,6 +3167,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
+                    // output rerank head
+                    cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
index d6072e5ece266a14bd49fa3b0241800324147513..129801fe06daacf0722fef81e8e618c71bc07a99 100644 (file)
@@ -5093,21 +5093,15 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        std::vector<server_tokens> tokenized_queries = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true);
-        if (tokenized_queries.size() != 1) {
-            res_error(res, format_error_response("\"query\" must contain only a single prompt", ERROR_TYPE_INVALID_REQUEST));
-        }
-
         // create and queue the task
         json responses = json::array();
         bool error = false;
         std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
-            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
-            tasks.reserve(tokenized_docs.size());
-            for (size_t i = 0; i < tokenized_docs.size(); i++) {
-                auto tmp = format_rerank(ctx_server.vocab, tokenized_queries[0], tokenized_docs[i]);
+            tasks.reserve(documents.size());
+            for (size_t i = 0; i < documents.size(); i++) {
+                auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
                 server_task task   = server_task(SERVER_TASK_TYPE_RERANK);
                 task.id            = ctx_server.queue_tasks.get_new_id();
                 task.index         = i;
index 64d702930ce96a2cb18595e8d4d0cab057c37455..4ca1423aaf2d4bac5b4938cf08b3692e5ad14a6b 100644 (file)
@@ -1368,34 +1368,6 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
     return std::to_string(hash);
 }
 
-
-// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
-static server_tokens format_rerank(const struct llama_vocab * vocab, server_tokens & query, server_tokens & doc) {
-    server_tokens result = {};
-
-    // Get EOS token - use SEP token as fallback if EOS is not available
-    llama_token eos_token = llama_vocab_eos(vocab);
-    if (eos_token == LLAMA_TOKEN_NULL) {
-        eos_token = llama_vocab_sep(vocab);
-    }
-    if (llama_vocab_get_add_bos(vocab)) {
-        result.push_back(llama_vocab_bos(vocab));
-    }
-    result.push_back(query);
-    if (llama_vocab_get_add_eos(vocab)) {
-        result.push_back(eos_token);
-    }
-    if (llama_vocab_get_add_sep(vocab)) {
-        result.push_back(llama_vocab_sep(vocab));
-    }
-    result.push_back(doc);
-    if (llama_vocab_get_add_eos(vocab)) {
-        result.push_back(eos_token);
-    }
-    return result;
-}
-
-
 static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
     mtmd::bitmaps bitmaps;
     for (auto & file : files) {
@@ -1501,3 +1473,43 @@ static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * voc
     }
     return result;
 }
+
+// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
+static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) {
+    server_tokens result = {};
+
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
+
+    if (rerank_prompt != nullptr) {
+        std::string prompt = rerank_prompt;
+        string_replace_all(prompt, "{query}"   , query);
+        string_replace_all(prompt, "{document}", doc  );
+        server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
+        result.push_back(tokens);
+    } else {
+        // Get EOS token - use SEP token as fallback if EOS is not available
+        server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false);
+        server_tokens doc_tokens   = tokenize_input_subprompt(vocab, mctx, doc,   false, false);
+        llama_token eos_token = llama_vocab_eos(vocab);
+        if (eos_token == LLAMA_TOKEN_NULL) {
+            eos_token = llama_vocab_sep(vocab);
+        }
+
+        if (llama_vocab_get_add_bos(vocab)) {
+            result.push_back(llama_vocab_bos(vocab));
+        }
+        result.push_back(query_tokens);
+        if (llama_vocab_get_add_eos(vocab)) {
+            result.push_back(eos_token);
+        }
+        if (llama_vocab_get_add_sep(vocab)) {
+            result.push_back(llama_vocab_sep(vocab));
+        }
+        result.push_back(doc_tokens);
+        if (llama_vocab_get_add_eos(vocab)) {
+            result.push_back(eos_token);
+        }
+    }
+
+    return result;
+}