]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add support for BertForSequenceClassification reranker (#13858)
authorĐinh Trọng Huy <redacted>
Wed, 28 May 2025 17:01:58 +0000 (02:01 +0900)
committerGitHub <redacted>
Wed, 28 May 2025 17:01:58 +0000 (19:01 +0200)
* convert: add support for BertForSequenceClassification

* add support for reranking using BertForSequenceClassification

* merge checks of eos and sep

* fix lint

---------

Co-authored-by: dinhhuy <redacted>
common/common.cpp
convert_hf_to_gguf.py
src/llama-graph.cpp
tools/server/utils.hpp

index 2afa9b2d641d4e0f0f46a9c6a1ba260321cb599c..d80c47bc393faa074a43c6d1f21ce8d15aa4b046 100644 (file)
@@ -903,13 +903,16 @@ struct common_init_result common_init_from_params(common_params & params) {
             ok = false;
         }
 
-        if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
-            ok = false;
-        }
+        bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
+        bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
 
-        if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: vocab does not have a  SEP token, reranking will not work\n", __func__);
+        if (!has_eos && !has_sep) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
+            ok = false;
+        } else if (!has_eos) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
+        } else if (!has_sep) {
+            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
             ok = false;
         }
 
index 8ba82f7174f378581e088beea9a908b0eb1d52a8..227ae7bc235ee25a06e06d2424462f278d7e0cc7 100755 (executable)
@@ -3682,7 +3682,7 @@ class InternLM3Model(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel")
+@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification")
 class BertModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
@@ -3745,6 +3745,13 @@ class BertModel(TextModel):
         if name.startswith("cls.seq_relationship"):
             return []
 
+        # For BertForSequenceClassification (direct projection layer)
+        if name == "classifier.weight":
+            name = "classifier.out_proj.weight"
+
+        if name == "classifier.bias":
+            name = "classifier.out_proj.bias"
+
         return [(self.map_tensor_name(name), data_torch)]
 
     def _xlmroberta_tokenizer_init(self) -> None:
index cdd5887de961c6f622e6c507efcf2775ad75709a..7a91ff3df8885ccc393c809b0b39cb2d9d624789 100644 (file)
@@ -1562,20 +1562,25 @@ void llm_graph_context::build_pooling(
                 ggml_tensor * inp_cls = build_inp_cls();
                 inp = ggml_get_rows(ctx0, inp, inp_cls);
 
-                // classification head
-                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                GGML_ASSERT(cls   != nullptr);
-                GGML_ASSERT(cls_b != nullptr);
-
-                cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
-                cur = ggml_tanh(ctx0, cur);
-
-                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                if (cls_out) {
+                if (cls != nullptr && cls_b != nullptr) {
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
+                    cur = ggml_tanh(ctx0, cur);
+
+                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                    if (cls_out) {
+                        GGML_ASSERT(cls_out_b != nullptr);
+                        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                    }
+                } else if (cls_out) {
+                    // Single layer classification head (direct projection)
+                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
                     GGML_ASSERT(cls_out_b != nullptr);
-
-                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                    cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
+                } else {
+                    GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
                 }
             } break;
         default:
index 14e09f95087fb2680bb8f673ae8baff682ac4340..b64620582a402273f916b4efff5670fd1845260d 100644 (file)
@@ -264,13 +264,19 @@ static size_t validate_utf8(const std::string& text) {
 static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
     llama_tokens result;
 
+    // Get EOS token - use SEP token as fallback if EOS is not available
+    llama_token eos_token = llama_vocab_eos(vocab);
+    if (eos_token == LLAMA_TOKEN_NULL) {
+        eos_token = llama_vocab_sep(vocab);
+    }
+
     result.reserve(doc.size() + query.size() + 4);
     result.push_back(llama_vocab_bos(vocab));
     result.insert(result.end(), query.begin(), query.end());
-    result.push_back(llama_vocab_eos(vocab));
+    result.push_back(eos_token);
     result.push_back(llama_vocab_sep(vocab));
     result.insert(result.end(), doc.begin(), doc.end());
-    result.push_back(llama_vocab_eos(vocab));
+    result.push_back(eos_token);
 
     return result;
 }