#include "llama-model.h"
#include <cinttypes>
+#include <cmath>
#include <cstring>
#include <limits>
#include <stdexcept>
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}
+ if (cparams.yarn_ext_factor != 0.0f) {
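+ // YaRN magnitude scaling ("mscale"): for a context-extension factor `scale` > 1.0f the
+ // attention is scaled by 0.1f*mscale*logf(scale) + 1.0f, otherwise it is left at 1.0f
+ // (this is the sqrt(1/t) approximation from the YaRN paper)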
+ static auto get_mscale = [](float scale, float mscale) {
+ return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
+ };
+
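+ // context-extension factor, e.g. freq_scale = 0.25f -> factor = 4.0f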
+ const float factor = 1.0f / cparams.rope_freq_scale;
+
+ // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
+ if (hparams.rope_yarn_log_mul != 0.0f) {
+ // note: here we assume `mscale == 1.0f`
+ // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
+ float mscale = 1.0f;
+ const float mscale_all_dims = hparams.rope_yarn_log_mul;
+
+ // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
+ // special-case DEEPSEEK v2:
+ // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
+ if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
+ mscale = mscale_all_dims;
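+ // with mscale == mscale_all_dims the ratio below evaluates to 1.0f,
+ // so only the cancellation term at the end of this block remains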
+ }
+
+ cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
+
+ LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
+ __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims);
+ } else {
+ cparams.yarn_attn_factor = get_mscale(factor, 1.0f);
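+ // e.g. (hypothetical values) freq_scale = 0.25f -> factor = 4.0f -> 0.1f*logf(4.0f) + 1.0f ~= 1.139f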
+ }
+
+ // when YaRN is applied (yarn_ext_factor != 0.0f), ggml scales attention internally and that factor needs to be cancelled here:
+ // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
+ //
+ // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
+ // https://github.com/ggml-org/llama.cpp/pull/17945
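+ //
+ // ggml multiplies attn_factor by (1.0f + 0.1f*logf(1.0f/freq_scale)) in that case, so
+ // multiplying by the reciprocal below cancels it and only the mscale ratio above survives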
+ cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor));
+ }
+
cparams.yarn_attn_factor *= hparams.rope_attn_factor;
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor),
- attn_factor (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
+ attn_factor (cparams.yarn_attn_factor),
beta_fast (cparams.yarn_beta_fast),
beta_slow (cparams.yarn_beta_slow),
norm_eps (hparams.f_norm_eps),
#include "ggml.h"
#include <cassert>
-#include <cmath>
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
if (dense_first) {
return false;
}
-
-float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
- GGML_ASSERT(ext_factor >= 0.0f);
-
- if (ext_factor != 0.0f) {
- attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
- }
-
- return attn_factor;
-}
// TODO: think of a better place for this function
// TODO: pack the SWA params in a struct?
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
-
- // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
- // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
- //
- // ref: https://github.com/ggml-org/llama.cpp/discussions/7416
- // https://github.com/ggml-org/llama.cpp/pull/17945
- static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
- const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
+ const auto & yarn_attn_factor = cparams.yarn_attn_factor;
const auto & n_rot = hparams.n_rot;
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
default: throw std::runtime_error("unsupported model architecture");
}
- // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
- if (hparams.rope_yarn_log_mul != 0.0f) {
- const float factor = 1.0f / hparams.rope_freq_scale_train;
-
- // note: here we assume `mscale == 1.0f`
- // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
- float mscale = 1.0f;
- const float mscale_all_dims = hparams.rope_yarn_log_mul;
-
- // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
- // special-case DEEPSEEK v2:
- // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
- if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
- mscale = mscale_all_dims;
- }
-
- static auto get_mscale = [](float scale, float mscale) {
- return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
- };
-
- hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
-
- LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
- __func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
- }
-
pimpl->n_bytes = ml.n_bytes;
pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();