self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
+ self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
+ # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
+ if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
+ if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None:
+ self.rope_parameters["rope_theta"] = rope_theta
+ if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
+ self.rope_parameters["rope_type"] = rope_type
+
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count)
- if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None:
+ if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length", "max_sequence_length", "model_max_length"], optional=True)) is not None:
self.gguf_writer.add_context_length(n_ctx)
logger.info(f"gguf: context length = {n_ctx}")
self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}")
- if (rope_theta := self.hparams.get("rope_theta")) is not None:
+ rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
+ if (rope_type := rope_params.get("rope_type")) is not None:
+ rope_factor = rope_params.get("factor")
+ rope_gguf_type = gguf.RopeScalingType.NONE
+ if rope_type == "linear" and rope_factor is not None:
+ rope_gguf_type = gguf.RopeScalingType.LINEAR
+ self.gguf_writer.add_rope_scaling_type(rope_gguf_type)
+ self.gguf_writer.add_rope_scaling_factor(rope_factor)
+ elif rope_type == "yarn" and rope_factor is not None:
+ rope_gguf_type = gguf.RopeScalingType.YARN
+ self.gguf_writer.add_rope_scaling_type(rope_gguf_type)
+ self.gguf_writer.add_rope_scaling_factor(rope_factor)
+ self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"])
+ if (yarn_ext_factor := rope_params.get("extrapolation_factor")) is not None:
+ self.gguf_writer.add_rope_scaling_yarn_ext_factor(yarn_ext_factor)
+ if (yarn_attn_factor := rope_params.get("attention_factor", rope_params.get("attn_factor"))) is not None:
+ self.gguf_writer.add_rope_scaling_yarn_attn_factor(yarn_attn_factor)
+ if (yarn_beta_fast := rope_params.get("beta_fast")) is not None:
+ self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_beta_fast)
+ if (yarn_beta_slow := rope_params.get("beta_slow")) is not None:
+ self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_beta_slow)
+ # self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
+ elif rope_type == "su" or rope_type == "longrope":
+ rope_gguf_type = gguf.RopeScalingType.LONGROPE
+ self.gguf_writer.add_rope_scaling_type(rope_gguf_type)
+ elif rope_type == "dynamic":
+ # HunYuan, handled in model class
+ pass
+ elif rope_type.lower() == "llama3":
+ # Handled in generate_extra_tensors
+ pass
+ else:
+ logger.warning(f"Unknown RoPE type: {rope_type}")
+ logger.info(f"gguf: rope scaling type = {rope_gguf_type.name}")
+
+ if (rope_theta := rope_params.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
logger.info(f"gguf: rope theta = {rope_theta}")
if (f_rms_eps := self.find_hparam(["rms_norm_eps", "norm_eps"], optional=True)) is not None:
self._set_vocab_sentencepiece()
def set_gguf_parameters(self):
- head_count = self.hparams["num_attention_heads"]
- head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-
- ctx_length = 0
- if "max_sequence_length" in self.hparams:
- ctx_length = self.hparams["max_sequence_length"]
- elif "max_position_embeddings" in self.hparams:
- ctx_length = self.hparams["max_position_embeddings"]
- elif "model_max_length" in self.hparams:
- ctx_length = self.hparams["model_max_length"]
- else:
- raise ValueError("gguf: can not find ctx length parameter.")
+ super().set_gguf_parameters()
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
- self.gguf_writer.add_context_length(ctx_length)
- self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
- self.gguf_writer.add_block_count(self.block_count)
- self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
- self.gguf_writer.add_head_count(head_count)
- self.gguf_writer.add_head_count_kv(head_count_kv)
- self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
- self.gguf_writer.add_file_type(self.ftype)
-
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
head_count = self.hparams["num_attention_heads"]
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
- head_count = self.hparams["num_attention_heads"]
- head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-
- ctx_length = 0
- if "max_sequence_length" in self.hparams:
- ctx_length = self.hparams["max_sequence_length"]
- elif "max_position_embeddings" in self.hparams:
- ctx_length = self.hparams["max_position_embeddings"]
- elif "model_max_length" in self.hparams:
- ctx_length = self.hparams["model_max_length"]
- else:
- raise ValueError("gguf: can not find ctx length parameter.")
+ super().set_gguf_parameters()
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
- self.gguf_writer.add_context_length(ctx_length)
- self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
- self.gguf_writer.add_block_count(self.block_count)
- self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
- self.gguf_writer.add_head_count(head_count)
- self.gguf_writer.add_head_count_kv(head_count_kv)
- self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
- self.gguf_writer.add_file_type(self.ftype)
-
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
-
@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
return [(self.map_tensor_name(name), data_torch)]
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
- if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
- if rope_scaling.get("rope_type", '').lower() == "llama3":
- base = self.hparams.get("rope_theta", 10000.0)
+ if rope_params := self.rope_parameters.get("full_attention", self.rope_parameters):
+ if rope_params.get("rope_type", '').lower() == "llama3":
+ base = rope_params.get("rope_theta", 10000.0)
if (dim := self.hparams.get("head_dim")) is None:
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
- factor = rope_scaling.get("factor", 8.0)
- low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
- high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+ factor = rope_params.get("factor", 8.0)
+ low_freq_factor = rope_params.get("low_freq_factor", 1.0)
+ high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
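+ # Worked numbers for the defaults above (illustration only): factor = 8.0, low_freq_factor = 1.0,
+ # high_freq_factor = 4.0 and old_context_len = 8192 give low_freq_wavelen = 8192 / 1.0 = 8192;
+ # the remainder of this method (outside this hunk) uses these thresholds to build the per-frequency
+ # rope factors, applying the full factor to wavelengths above low_freq_wavelen.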
def set_gguf_parameters(self):
super().set_gguf_parameters()
self._try_set_pooling_type()
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
@ModelBase.register("AfmoeForCausalLM")
def set_gguf_parameters(self):
super().set_gguf_parameters()
- rope_params = self.hparams.get("rope_parameters")
+ rope_params = self.rope_parameters
if self.hparams.get("model_type") == "ministral3":
- assert rope_params is not None, "ministral3 must have 'rope_parameters' config"
+ assert rope_params, "ministral3 must have 'rope_parameters' config"
assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_params["factor"])
- self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"])
- self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"])
self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"])
- self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"])
self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
assert self.block_count == len(self._num_kv_heads)
assert self.block_count == len(self._num_heads)
assert self.block_count == len(self._ffn_dims)
- if (rope_theta := self.hparams.get("rope_theta")) is not None:
+ if (rope_theta := self.rope_parameters.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
self.gguf_writer.add_head_count(self._num_heads)
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
-
@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
return [(self.map_tensor_name(name), data_torch)]
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
- if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
- if rope_scaling.get("rope_type", '').lower() == "llama3":
- base = self.hparams.get("rope_theta", 10000.0)
+ if rope_params := self.rope_parameters.get("full_attention", self.rope_parameters):
+ if rope_params.get("rope_type", '').lower() == "llama3":
+ base = rope_params.get("rope_theta", 10000.0)
if (dim := self.hparams.get("head_dim")) is None:
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
- factor = rope_scaling.get("factor", 8.0)
- low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
- high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+ factor = rope_params.get("factor", 8.0)
+ low_freq_factor = rope_params.get("low_freq_factor", 1.0)
+ high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
self.gguf_writer.add_logit_scale(logit_scale)
logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}")
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "longrope":
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
- logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
def set_vocab(self):
self._set_vocab_qwen()
- def set_gguf_parameters(self):
- self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
- self.gguf_writer.add_block_count(self.block_count)
- self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
- self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
- self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
- self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
- self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
- self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
- self.gguf_writer.add_file_type(self.ftype)
-
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration")
class Qwen2Model(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self._try_set_pooling_type()
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if self.hf_arch == "Qwen2Model":
# Dream models use non-causal attention for diffusion
self.gguf_writer.add_causal_attention(False)
- # Handle RoPE scaling similar to Qwen2
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
# Add Dream-specific parameters
mask_token_id = self.hparams.get("mask_token_id")
if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None:
self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size)
logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}")
- # YaRN is not enabled by default
- # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
_experts: list[dict[str, Tensor]] | None = None
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
self.gguf_writer.add_rope_dimension_count(rope_dims)
- self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
+ self.gguf_writer.add_rope_freq_base(self.rope_parameters.get("full_attention", self.rope_parameters)["rope_theta"])
self.gguf_writer.add_file_type(self.ftype)
sliding_window = self.hparams.get("sliding_window")
# use zero value of sliding_window to distinguish Phi-4 from other PHI3 models
self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
- self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
+ self.gguf_writer.add_rope_freq_base(self.rope_parameters.get("rope_theta", 10000))
# Mamba parameters
self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
special_vocab.add_to_gguf(self.gguf_writer)
- def set_gguf_parameters(self):
- self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
- self.gguf_writer.add_block_count(self.block_count)
- self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
- self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
- self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
- self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
- self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
- self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
- self.gguf_writer.add_file_type(self.ftype)
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
-
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
num_heads = self.hparams["num_attention_heads"]
num_kv_heads = self.hparams["num_key_value_heads"]
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
-
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
def set_gguf_parameters(self):
super().set_gguf_parameters()
- self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
if self.is_moe:
self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
super().set_gguf_parameters()
# jina-embeddings-v3
- if rotary_emb_base := self.hparams.get("rotary_emb_base"):
- self.gguf_writer.add_rope_freq_base(rotary_emb_base)
lora_alpha = self.hparams.get("lora_alpha")
if lora_prompt_prefixes := self.hparams.get("task_instructions"):
assert self._lora_files and all(lora_name in lora_prompt_prefixes for lora_name in self._lora_files.keys())
self._set_vocab_gpt2()
def set_gguf_parameters(self):
+ super().set_gguf_parameters()
hparams = self.hparams
# some default values are not specified in the hparams
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
- self.gguf_writer.add_embedding_length(hparams["hidden_size"])
- self.gguf_writer.add_block_count(self.block_count)
- self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
- self.gguf_writer.add_file_type(self.ftype)
- self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
+ self.gguf_writer.add_rope_freq_base(self.rope_parameters.get("full_attention", self.rope_parameters).get("rope_theta", 1_000_000.0)) # for global layers
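+ # Hypothetical example (comment only): rope_parameters == {"full_attention": {"rope_theta": 1e6},
+ # "sliding_attention": {"rope_theta": 10000.0}} selects 1e6 here for the global layers.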
# attn_logit_softcapping is removed in Gemma3
assert hparams.get("attn_logit_softcapping") is None
if (final_logit_softcap := hparams.get("final_logit_softcapping")):
if hparams.get("sliding_window_pattern") != 1:
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
- if hparams.get("rope_scaling") is not None:
- rope_scaling = hparams["rope_scaling"]
- if rope_scaling["rope_type"] == "linear":
- # important: this rope_scaling is only applied for global layers, and not used by 1B model
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- elif rope_scaling["rope_type"] == "yarn":
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
- self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
- self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
- self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
def set_gguf_parameters(self):
super().set_gguf_parameters()
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_attn_factors(rope_scaling["attention_factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
-
if "sliding_window" in self.hparams:
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
-
+ if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
# ref https://github.com/ggml-org/llama.cpp/pull/17945
- self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
+ self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_mscale_all)
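+ # For illustration: a config with mscale_all_dim = 0.707 would store yarn_log_mul ≈ 0.0707 here.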
_experts: list[dict[str, Tensor]] | None = None
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.visual."): # ignore visual part of Glm4v
model_arch = gguf.MODEL_ARCH.EXAONE
def set_gguf_parameters(self):
+ super().set_gguf_parameters()
hparams = self.hparams
assert (hparams["activation_function"] == "silu")
- max_position_embeddings = hparams["max_position_embeddings"]
- embed_dim = hparams["hidden_size"]
- num_heads = hparams["num_attention_heads"]
- num_kv_heads = hparams.get("num_key_value_heads", num_heads)
- layer_norm_eps = hparams["layer_norm_epsilon"]
- intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
- # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
- # attention_dropout_rate = hparams["attention_dropout"]
- # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
- # embed_dropout_rate = hparams["embed_dropout"]
- self.gguf_writer.add_embedding_length(embed_dim)
- self.gguf_writer.add_head_count(num_heads)
- self.gguf_writer.add_head_count_kv(num_kv_heads)
- self.gguf_writer.add_context_length(max_position_embeddings)
- self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
- self.gguf_writer.add_feed_forward_length(intermediate_size)
- self.gguf_writer.add_block_count(self.block_count)
- self.gguf_writer.add_file_type(self.ftype)
-
- if (rope_theta := self.hparams.get("rope_theta")) is not None:
- self.gguf_writer.add_rope_freq_base(rope_theta)
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
- if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
- if rope_scaling.get("rope_type", '').lower() == "llama3":
- base = self.hparams.get("rope_theta", 10000.0)
+ if rope_params := self.rope_parameters.get("full_attention", self.rope_parameters):
+ if rope_params.get("rope_type", '').lower() == "llama3":
+ base = self.rope_parameters.get("rope_theta", 10000.0)
if (dim := self.hparams.get("head_dim")) is None:
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
- factor = rope_scaling.get("factor", 8.0)
- low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
- high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+ factor = rope_params.get("factor", 8.0)
+ low_freq_factor = rope_params.get("low_freq_factor", 1.0)
+ high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
if len(sliding_window_pattern) == hparams["num_hidden_layers"]:
self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
-
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
- if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
- if rope_scaling.get("rope_type", '').lower() == "llama3":
- base = self.hparams.get("rope_theta", 10_000.0)
+ if rope_params := self.rope_parameters.get("full_attention", self.rope_parameters):
+ if rope_params.get("rope_type", '').lower() == "llama3":
+ base = rope_params.get("rope_theta", 10_000.0)
if (dim := self.hparams.get("head_dim")) is None:
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
- factor = rope_scaling.get("factor", 16.0)
- low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
- high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+ factor = rope_params.get("factor", 16.0)
+ low_freq_factor = rope_params.get("low_freq_factor", 1.0)
+ high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(rope_dim)
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
- else:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
- else:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_experts_per_group(2)
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L376
self.gguf_writer.add_expert_group_scale(0.05)
- # YaRN is not enabled by default
- # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
_experts: list[dict[str, Tensor]] | None = None
_chunk_experts: list[dict[str, Tensor]] | None = None
assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"
# Add any other Falcon Mamba2 specific configuration
- self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
+ self.gguf_writer.add_rope_freq_base(self.rope_parameters["rope_theta"])
@ModelBase.register("HunYuanMoEV1ForCausalLM")
self.gguf_writer.add_expert_shared_count(moe_shared_expert[0])
# Rope
- rope_scaling = hparams.get("rope_scaling", {})
- if rope_scaling.get("type") == "dynamic":
+ if self.rope_parameters.get("rope_type") == "dynamic":
# HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
- alpha = rope_scaling.get("alpha", 1000)
- base = hparams.get("rope_theta", 10000.0)
+ alpha = self.rope_parameters.get("alpha", 1000)
+ base = self.rope_parameters.get("rope_theta", 10000.0)
dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128
scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251
self.gguf_writer.add_rope_freq_base(scaled_base)
hparams = self.hparams
# Rope
- rope_scaling = hparams.get("rope_scaling", {})
- if rope_scaling.get("type") == "dynamic":
+ if self.rope_parameters.get("rope_type") == "dynamic":
# HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
- alpha = rope_scaling.get("alpha", 50)
- base = hparams.get("rope_theta", 10000.0)
+ alpha = self.rope_parameters.get("alpha", 50)
+ base = self.rope_parameters.get("rope_theta", 10000.0)
dim = hparams["head_dim"]
scaled_base = base * (alpha ** (dim / (dim - 2)))
self.gguf_writer.add_rope_freq_base(scaled_base)
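+ # Illustration only, assuming head_dim = 128 with the default alpha of 50:
+ # 10000 * (50 ** (128 / 126)) ≈ 532000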
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])
- rope_scaling = self.hparams.get("rope_scaling") or {}
- rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
- assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
-
@ModelBase.register("Lfm2ForCausalLM", "LFM2ForCausalLM")
class LFM2Model(TextModel):
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
- # YaRN is not enabled by default
- # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
- rope_scaling = self.hparams.get("rope_scaling") or {}
- if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
- self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
- self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
- self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
sliding_window_layout = self.hparams.get("sliding_window_layout")
if sliding_window_layout: