# Mistral format specifics
is_mistral_format: bool = False
disable_mistral_community_chat_template: bool = False
+ sentence_transformers_dense_modules: bool = False
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
- disable_mistral_community_chat_template: bool = False):
+ disable_mistral_community_chat_template: bool = False,
+ sentence_transformers_dense_modules: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
self.lazy = not eager or (remote_hf_model_id is not None)
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
+ self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
if remote_hf_model_id is not None:
self.is_safetensors = True
@ModelBase.register("Gemma3TextModel")
class EmbeddingGemma(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # per-instance containers: Dense module paths and their (in_features, out_features) dims
+ self.module_paths: list[str] = []
+ self.dense_features_dims: dict[str, tuple[int, int]] = {}
+ if self.sentence_transformers_dense_modules:
+ # read modules.json to determine if model has Dense layers
+ modules_file = self.dir_model / "modules.json"
+ if modules_file.is_file():
+ with open(modules_file, encoding="utf-8") as modules_json_file:
+ mods = json.load(modules_json_file)
+ for mod in mods:
+ if mod["type"] == "sentence_transformers.models.Dense":
+ mod_path = mod["path"]
+ # check if model.safetensors file for Dense layer exists
+ model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+ if model_tensors_file.is_file():
+ self.module_paths.append(mod_path)
+ # read config.json of the Dense layer to get in/out features
+ mod_conf_file = self.dir_model / mod_path / "config.json"
+ if mod_conf_file.is_file():
+ with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+ mod_conf = json.load(mod_conf_json_file)
+ # hparams dense_2_feat_out and dense_3_feat_in are required when loading the model's dense weights
+ prefix = self._get_dense_prefix(mod_path)
+ if mod_conf.get("in_features") is not None and mod_conf.get("out_features") is not None:
+ self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+ from safetensors.torch import load_file
+ for module_path in self.module_paths:
+ tensors_file = self.dir_model / module_path / "model.safetensors"
+ local_tensors = load_file(tensors_file)
+ tensor_name = self._get_dense_prefix(module_path)
+ for name, local_tensor in local_tensors.items():
+ if not name.endswith(".weight"):
+ continue
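+ # Dense modules store their projection weight as "linear.weight"; rename the
+ # "linear" prefix to dense_2/dense_3 before mapping to the GGUF tensor name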
+ orig_name = name.replace("linear", tensor_name)
+ name = self.map_tensor_name(orig_name)
+ yield name, local_tensor.clone()
+
+ @staticmethod
+ def _get_dense_prefix(module_path: str) -> str:
+ """Get the tensor name prefix for the Dense layer from module path."""
+ tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+ return tensor_name
def set_gguf_parameters(self):
super().set_gguf_parameters()
logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
f"instead of {self.hparams['sliding_window']}")
self.gguf_writer.add_sliding_window(orig_sliding_window)
+ if self.sentence_transformers_dense_modules:
+ for dense, dims in self.dense_features_dims.items():
+ logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+ self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
self._try_set_pooling_type()
)
)
+ parser.add_argument(
+ "--sentence-transformers-dense-modules", action="store_true",
+ help=("Whether to include sentence-transformers dense modules."
+ "It can be used for sentence-transformers models, like google/embeddinggemma-300m"
+ "Default these modules are not included.")
+ )
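+ # example (assuming the usual convert_hf_to_gguf.py invocation):
+ #   python convert_hf_to_gguf.py google/embeddinggemma-300m --remote --sentence-transformers-dense-modules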
+
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
parser.error("the following arguments are required: model")
if args.remote:
hf_repo_id = args.model
from huggingface_hub import snapshot_download
+ allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+ if args.sentence_transformers_dense_modules:
+ # include sentence-transformers dense modules safetensors files
+ allowed_patterns.append("*.safetensors")
local_dir = snapshot_download(
repo_id=hf_repo_id,
- allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+ allow_patterns=allowed_patterns)
dir_model = Path(local_dir)
logger.info(f"Downloaded config and tokenizer to {local_dir}")
else:
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
- remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+ remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+ sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
)
if args.vocab_only:
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
+ DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in"
+ DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
TOKEN_TYPES = auto()
POS_EMBD = auto()
OUTPUT = auto()
+ DENSE_2_OUT = auto() # embeddinggemma 2_Dense
+ DENSE_3_OUT = auto() # embeddinggemma 3_Dense
OUTPUT_NORM = auto()
ROPE_FREQS = auto()
ROPE_FACTORS_LONG = auto()
MODEL_TENSOR.POS_EMBD: "position_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
+ MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense
+ MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 3_Dense
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
MODEL_ARCH.GEMMA_EMBEDDING: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.DENSE_2_OUT,
+ MODEL_TENSOR.DENSE_3_OUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
+ def add_dense_features_dims(self, dense: str, in_f: int, out_f: int) -> None:
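+ # writes "{arch}.dense_2_feat_in" / "{arch}.dense_2_feat_out" style keys,
+ # matching the LLM_KV_DENSE_*_FEAT_* keys on the C++ side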
+ self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
+ self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f)
+
def add_logit_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
"lm_head", # llama4
"model.transformer.ff_out", # llada
),
-
+ MODEL_TENSOR.DENSE_2_OUT: (
+ "dense_2_out", # embeddinggemma
+ ),
+ MODEL_TENSOR.DENSE_3_OUT: (
+ "dense_3_out", # embeddinggemma
+ ),
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
+ // sentence-transformers dense modules feature dims
+ { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" },
+ { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" },
+ { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" },
+ { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_DENSE_2_OUT, "dense_2" },
+ { LLM_TENSOR_DENSE_3_OUT, "dense_3" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
+ {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
LLM_KV_TOKENIZER_MIDDLE_ID,
+
+ // sentence-transformers dense layers in and out features
+ LLM_KV_DENSE_2_FEAT_IN,
+ LLM_KV_DENSE_2_FEAT_OUT,
+ LLM_KV_DENSE_3_FEAT_IN,
+ LLM_KV_DENSE_3_FEAT_OUT,
};
enum llm_tensor {
LLM_TENSOR_TOKEN_EMBD_NORM,
LLM_TENSOR_TOKEN_TYPES,
LLM_TENSOR_POS_EMBD,
+ LLM_TENSOR_DENSE_2_OUT,
+ LLM_TENSOR_DENSE_3_OUT,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_ROPE_FREQS,
return nullptr;
}
+ if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED &&
+ params.pooling_type != model->hparams.pooling_type) {
+ // the user-specified pooling_type differs from the model default
+ LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
+ model->hparams.pooling_type, params.pooling_type);
+ }
+
try {
auto * ctx = new llama_context(*model, params);
return ctx;
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
+void llm_graph_context::build_dense_out(
+ ggml_tensor * dense_2,
+ ggml_tensor * dense_3) const {
+ if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
+ return;
+ }
+ ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
+ GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");
+
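+ // apply the two Dense projections to the pooled embeddings: n_embd -> dense_2_feat_out -> n_embd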
+ cur = ggml_mul_mat(ctx0, dense_2, cur);
+ cur = ggml_mul_mat(ctx0, dense_3, cur);
+ cb(cur, "result_embd_pooled", -1);
+ res->t_embd_pooled = cur;
+ ggml_build_forward_expand(gf, cur);
+}
+
void llm_graph_context::build_pooling(
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;
+
+ //
+ // dense (out)
+ //
+
+ void build_dense_out(
+ ggml_tensor * dense_2,
+ ggml_tensor * dense_3) const;
};
// TODO: better name
uint32_t laurel_rank = 64;
uint32_t n_embd_altup = 256;
+ // needed for sentence-transformers dense layers
+ uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense
+ uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense
+ uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense
+ uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense
+
// xIELU
std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_n;
std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_p;
hparams.set_swa_pattern(6);
hparams.causal_attn = false; // embeddings do not use causal attention
- hparams.rope_freq_base_train_swa = 10000.0f;
+ hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
- ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
- ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+
+ // present only if the model was converted with --sentence-transformers-dense-modules
+ ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false);
+ ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false);
+ ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false);
+ ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false);
+
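+ // dense_2 consumes and dense_3 produces n_embd-sized vectors, so the projections
+ // can be applied directly to the pooled embeddings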
+ GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd");
+ GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd");
switch (hparams.n_layer) {
case 24: type = LLM_TYPE_0_3B; break;
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
+ // Dense linear weights
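+ // (TENSOR_NOT_REQUIRED: these stay nullptr for models converted without --sentence-transformers-dense-modules)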
+ dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED);
+ dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED);
+
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
+ // if the GGUF model was converted with --sentence-transformers-dense-modules,
+ // two additional Dense projection layers are applied to the pooled embeddings
+ // TODO: move reranking logic here and generalize
+ llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
+
return llm->res->get_gf();
}
std::vector<llama_layer> layers;
+ // Dense linear projections for sentence-transformers models such as embeddinggemma
+ // for the structure of sentence-transformers models see:
+ // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models
+ struct ggml_tensor * dense_2_out_layers = nullptr;
+ struct ggml_tensor * dense_3_out_layers = nullptr;
+
llama_model_params params;
// gguf metadata