]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add RobertaForSequenceClassification reranker support (#13875)
authorSigbjørn Skjæret <redacted>
Thu, 29 May 2025 06:15:01 +0000 (08:15 +0200)
committerGitHub <redacted>
Thu, 29 May 2025 06:15:01 +0000 (08:15 +0200)
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-hparams.h
src/llama-model.cpp

index 227ae7bc235ee25a06e06d2424462f278d7e0cc7..797773193563b511791af68ee0e7206b93d3b805 100755 (executable)
@@ -3695,6 +3695,10 @@ class BertModel(TextModel):
         self.gguf_writer.add_causal_attention(False)
         self._try_set_pooling_type()
 
+        if cls_out_labels := self.hparams.get("id2label"):
+            key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
+            self.gguf_writer.add_array(key_name, [v for k, v in sorted(cls_out_labels.items())])
+
     def set_vocab(self):
         tokens, toktypes, tokpre = self.get_vocab_base()
         self.vocab_size = len(tokens)
@@ -3745,12 +3749,13 @@ class BertModel(TextModel):
         if name.startswith("cls.seq_relationship"):
             return []
 
-        # For BertForSequenceClassification (direct projection layer)
-        if name == "classifier.weight":
-            name = "classifier.out_proj.weight"
+        if self.hparams.get("id2label"):
+            # For BertForSequenceClassification (direct projection layer)
+            if name == "classifier.weight":
+                name = "classifier.out_proj.weight"
 
-        if name == "classifier.bias":
-            name = "classifier.out_proj.bias"
+            if name == "classifier.bias":
+                name = "classifier.out_proj.bias"
 
         return [(self.map_tensor_name(name), data_torch)]
 
@@ -3846,7 +3851,7 @@ class BertModel(TextModel):
         self.gguf_writer.add_add_eos_token(True)
 
 
-@ModelBase.register("RobertaModel")
+@ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
 class RobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
index 31163effad8f283cf92915d019c3f68783e398a8..635b61f224b1e96e7b06557d34c614d1ec068685 100644 (file)
@@ -177,6 +177,9 @@ class Keys:
         EMBEDDING_LENGTH = "{arch}.convnext.embedding_length"
         BLOCK_COUNT      = "{arch}.convnext.block_count"
 
+    class Classifier:
+        OUTPUT_LABELS = "{arch}.classifier.output_labels"
+
     class Tokenizer:
         MODEL                = "tokenizer.ggml.model"
         PRE                  = "tokenizer.ggml.pre"
index abf436adac41665824c97c062d2a0c8349785b2d..2bb18c85fce2810309eea8e8ac1c6953d9133293 100644 (file)
@@ -174,6 +174,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" },
     { LLM_KV_CONVNEXT_BLOCK_COUNT,      "%s.convnext.block_count"      },
 
+    { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
+
     { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
     { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
     { LLM_KV_TOKENIZER_LIST,                 "tokenizer.ggml.tokens"                   },
index 41a023da3da6ee31a4637f84cac401b4f5de23b4..930cb4eca33ab7dac049d7c45b94836ff310defc 100644 (file)
@@ -213,6 +213,8 @@ enum llm_kv {
     LLM_KV_CONVNEXT_EMBEDDING_LENGTH,
     LLM_KV_CONVNEXT_BLOCK_COUNT,
 
+    LLM_KV_CLASSIFIER_OUTPUT_LABELS,
+
     // deprecated:
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
index 2d72eab180ad0c93cb797e194d73ff38e11547fd..b2bcb8b01a18b8e07476cb332e6fd356687db0f8 100644 (file)
@@ -131,6 +131,9 @@ struct llama_hparams {
     bool attn_soft_cap = false;
     bool use_kq_norm   = true;
 
+    // for Classifiers
+    uint32_t n_cls_out = 1;
+
     // llama4
     uint32_t n_moe_layer_step        = 0;
     uint32_t n_no_rope_layer_step    = 4;
index e99f5309f99044663d190cacef2cfb1b1864f71d..4a4618a2bcab7644e4e8cd79baeed4f4927a1297 100644 (file)
@@ -683,6 +683,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
+                ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -2121,8 +2122,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
                         cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         TENSOR_NOT_REQUIRED);
 
-                        cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
-                        cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {1},         TENSOR_NOT_REQUIRED);
+                        cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+                        cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {hparams.n_cls_out},         TENSOR_NOT_REQUIRED);
                     }
 
                     tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);