LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
+ LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+ LLM_KV_FINAL_LOGIT_SOFTCAPPING,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
+ { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
+ { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
float f_norm_eps;
float f_norm_rms_eps;
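+ // logit soft-cap defaults match Gemma2 (attention: 50.0, final: 30.0); overridden from GGUF metadata when present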
+ float f_attn_logit_softcapping = 50.0f;
+ float f_final_logit_softcapping = 30.0f;
+
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;
- bool causal_attn = true;
- bool use_alibi = false;
+ bool causal_attn = true;
+ bool use_alibi = false;
+ bool attn_soft_cap = false;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
case LLM_ARCH_GEMMA2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+ ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
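+ // attention soft-capping is always enabled for Gemma2; other architectures leave the flag off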
+ hparams.attn_soft_cap = true;
switch (hparams.n_layer) {
case 42: model.type = e_model::MODEL_9B; break;
kq = ggml_scale(ctx, kq, 30);
}
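+ // Gemma2-style attention logit soft-capping: kq = cap * tanh(kq / cap)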
+ if (hparams.attn_soft_cap) {
+ kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+ kq = ggml_tanh(ctx, kq);
+ kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+ }
+
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
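+ // Gemma2 scales queries by 1/sqrt(n_embd/n_head) rather than 1/sqrt(n_embd_head_k)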
- Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+ Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
cb(Qcur, "Qcur_scaled", il);
Kcur = ggml_rope_ext(
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
+
+ // final logit soft-capping: cur = cap * tanh(cur / cap)
+ cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+ cur = ggml_tanh(ctx0, cur);
+ cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
params.flash_attn = false;
}
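+ // the flash attention path does not apply the tanh soft-cap, so fall back to the regular attention path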
+ if (params.flash_attn && model->hparams.attn_soft_cap) {
+ LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+ params.flash_attn = false;
+ }
+
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;