self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+ # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+ # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
+ # ref https://github.com/ggml-org/llama.cpp/pull/17945
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
_experts: list[dict[str, Tensor]] | None = None
MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
yarn_params = self.hparams["yarn"]
self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
+
+ # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+ # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
+ # ref https://github.com/ggml-org/llama.cpp/pull/17945
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor),
- attn_factor (cparams.yarn_attn_factor),
+ attn_factor (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
beta_fast (cparams.yarn_beta_fast),
beta_slow (cparams.yarn_beta_slow),
norm_eps (hparams.f_norm_eps),
#include "llama-hparams.h"
#include "ggml.h"
+
#include <cassert>
+#include <cmath>
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
if (dense_first) {
return false;
}
+
+float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
+ GGML_ASSERT(ext_factor >= 0.0f);
+
+ if (ext_factor != 0.0f) {
+ attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+ }
+
+ return attn_factor;
+}
float rope_freq_base_train_swa;
float rope_freq_scale_train;
float rope_freq_scale_train_swa;
+
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul = 0.0f;
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+
+ // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
+ // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+ //
+ // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+ // https://github.com/ggml-org/llama.cpp/pull/17945
+ static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
-
float freq_scale) const {
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
- const auto & yarn_ext_factor = cparams.yarn_ext_factor;
- const auto & yarn_beta_fast = cparams.yarn_beta_fast;
- const auto & yarn_beta_slow = cparams.yarn_beta_slow;
+ const auto & yarn_ext_factor = cparams.yarn_ext_factor;
+ const auto & yarn_beta_fast = cparams.yarn_beta_fast;
+ const auto & yarn_beta_slow = cparams.yarn_beta_slow;
+ const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
const auto & n_rot = hparams.n_rot;
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
? LLAMA_ROPE_TYPE_NEOX
: hparams.rope_type;
- // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
- // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
- const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
- ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
- : cparams.yarn_attn_factor;
-
ggml_tensor * tmp;
if (ggml_is_quantized(cur->type)) {
// that have no expert_gating_func model parameter set
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
}
- ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+
+ if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
+ // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+ // cancel the factor from the convert script
+ hparams.rope_yarn_log_mul /= 0.1f;
+ }
// (optional) temperature tuning - used by mistral-large
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
- ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
- ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
- ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
+ ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
+ ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
+ ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);
// TODO: maybe add n_attn_temp_floor_scale as a separate KV?
if (hparams.f_attn_temp_scale != 0.0f) {
}
}
- // TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
- // but may need further verification with other values
- if (hparams.rope_yarn_log_mul != 0.0f) {
- float factor = 1.0f / hparams.rope_freq_scale_train;
- float mscale = 1.0f;
- float mscale_all_dims = hparams.rope_yarn_log_mul;
- static auto get_mscale = [](float scale, float mscale) {
- return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
- };
- hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
- }
-
switch (hparams.n_layer) {
case 26: type = LLM_TYPE_3B; break;
case 34: type = LLM_TYPE_8B; break;
default: throw std::runtime_error("unsupported model architecture");
}
+ // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+ if (hparams.rope_yarn_log_mul != 0.0f) {
+ const float factor = 1.0f / hparams.rope_freq_scale_train;
+
+ // note: here we assume `mscale == 1.0f`
+ // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+ float mscale = 1.0f;
+ const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+ // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+ // special-case DEEPSEEK v2:
+ // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+ if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+ mscale = mscale_all_dims;
+ }
+
+ static auto get_mscale = [](float scale, float mscale) {
+ return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+ };
+
+ hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+ LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+ __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
+ }
+
pimpl->n_bytes = ml.n_bytes;
pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
+ LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul);
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
// MRoPE (Multi-axis Rotary Position Embedding) sections
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
- LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
}
if (arch == LLM_ARCH_QWEN2MOE) {
#include "models.h"
-
-
llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
- const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
- const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
- const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+ // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+
+ // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
+ GGML_ASSERT(ext_factor >= 0.0f);
+ const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));
+
+ // use the original attn_factor to pre-scale the kq_scale
+ const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
+ const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
ggml_tensor * cur;
ggml_tensor * inpL;