ok = false;
}
- if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
- LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
- ok = false;
- }
+ bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
+ bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
- if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
- LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
+ if (!has_eos && !has_sep) {
+ LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
+ ok = false;
+ } else if (!has_eos) {
+ LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
+ } else if (!has_sep) {
+ LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}
return [(self.map_tensor_name(name), data_torch)]
-@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel")
+@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification")
class BertModel(TextModel):
model_arch = gguf.MODEL_ARCH.BERT
if name.startswith("cls.seq_relationship"):
return []
+ # For BertForSequenceClassification (direct projection layer)
+ if name == "classifier.weight":
+ name = "classifier.out_proj.weight"
+
+ if name == "classifier.bias":
+ name = "classifier.out_proj.bias"
+
return [(self.map_tensor_name(name), data_torch)]
def _xlmroberta_tokenizer_init(self) -> None:
ggml_tensor * inp_cls = build_inp_cls();
inp = ggml_get_rows(ctx0, inp, inp_cls);
- // classification head
- // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
- GGML_ASSERT(cls != nullptr);
- GGML_ASSERT(cls_b != nullptr);
-
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
- cur = ggml_tanh(ctx0, cur);
-
- // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
- // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
- if (cls_out) {
+ if (cls != nullptr && cls_b != nullptr) {
+ // classification head
+ // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
+ cur = ggml_tanh(ctx0, cur);
+
+ // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+ // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+ if (cls_out) {
+ GGML_ASSERT(cls_out_b != nullptr);
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+ }
+ } else if (cls_out) {
+ // Single layer classification head (direct projection)
+ // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
GGML_ASSERT(cls_out_b != nullptr);
-
- cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+ cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
+ } else {
+ GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
}
} break;
default:
// Build a rerank prompt of the form: [BOS] query [EOS] [SEP] doc [EOS].
// Falls back to the SEP token when the vocab has no EOS token, and skips any
// special token the vocab does not define (LLAMA_TOKEN_NULL) so an invalid
// token id is never inserted into the result.
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
    llama_tokens result;

    // Get EOS token - use SEP token as fallback if EOS is not available
    llama_token eos_token = llama_vocab_eos(vocab);
    if (eos_token == LLAMA_TOKEN_NULL) {
        eos_token = llama_vocab_sep(vocab);
    }

    const llama_token bos_token = llama_vocab_bos(vocab);
    const llama_token sep_token = llama_vocab_sep(vocab);

    // at most 4 special tokens are added on top of query + doc
    result.reserve(doc.size() + query.size() + 4);
    if (bos_token != LLAMA_TOKEN_NULL) {
        result.push_back(bos_token);
    }
    result.insert(result.end(), query.begin(), query.end());
    if (eos_token != LLAMA_TOKEN_NULL) {
        result.push_back(eos_token);
    }
    if (sep_token != LLAMA_TOKEN_NULL) {
        result.push_back(sep_token);
    }
    result.insert(result.end(), doc.begin(), doc.end());
    if (eos_token != LLAMA_TOKEN_NULL) {
        result.push_back(eos_token);
    }

    return result;
}