]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama: Add attention and final logit soft-capping, update scaling factor to Gemma2...
authorAndrei <redacted>
Sun, 30 Jun 2024 03:44:08 +0000 (20:44 -0700)
committerGitHub <redacted>
Sun, 30 Jun 2024 03:44:08 +0000 (23:44 -0400)
* Add attention and final logit softcapping.

* fix

* Add custom add_ functions

* Disable flash attention for Gemma2

* Update src/llama.cpp

Co-authored-by: slaren <redacted>
* Add default value for attention and final logit softcap value

* Add custom kq scaling from Gemma2Attention

* Remove custom pre attention scaling and use computed value instead.

---------

Co-authored-by: slaren <redacted>
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
src/llama.cpp

index 5bcc849db999d49ca962e86f018d354e985e3220..3ef2f69e7c0dfaa8aada0eac9dc39940b923b76c 100755 (executable)
@@ -2363,6 +2363,12 @@ class Gemma2Model(Model):
         self.gguf_writer.add_key_length(hparams["head_dim"])
         self.gguf_writer.add_value_length(hparams["head_dim"])
         self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_attn_logit_softcapping(
+            self.hparams["attn_logit_softcapping"]
+        )
+        self.gguf_writer.add_final_logit_softcapping(
+            self.hparams["final_logit_softcapping"]
+        )
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unusem
index cf3d09e70d3e76477ac959d8b338b1d9e1ea3c6f..9bfa891d5dc52922c881af6899a9875c32870f9d 100644 (file)
@@ -50,6 +50,8 @@ class Keys:
         POOLING_TYPE                      = "{arch}.pooling_type"
         LOGIT_SCALE                       = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
+        ATTN_LOGIT_SOFTCAPPING            = "{arch}.attn_logit_softcapping"
+        FINAL_LOGIT_SOFTCAPPING           = "{arch}.final_logit_softcapping"
 
     class Attention:
         HEAD_COUNT        = "{arch}.attention.head_count"
index 9869f6fe3445ab15cba2849684666dc3688210b2..1aeb0d9b08685a1d21f7652666c73e7d07ff75ef 100644 (file)
@@ -516,6 +516,12 @@ class GGUFWriter:
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
 
+    def add_attn_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
+    def add_final_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_expert_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
 
index 3edaa98e8d01b2255804fad52faed8eeb8996ed2..2a4d73856fcd93a72c95c49da045d707c7654f7b 100644 (file)
@@ -302,6 +302,8 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE ,                     "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
+    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            },
+    { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
@@ -2099,6 +2103,9 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
+    float f_attn_logit_softcapping = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
     float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
     float    rope_freq_scale_train;
@@ -2115,8 +2122,9 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale    = 0.0f;
 
-    bool causal_attn = true;
-    bool use_alibi   = false;
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;
 
     enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
@@ -4702,6 +4710,9 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
                     case 42: model.type = e_model::MODEL_9B; break;
@@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv(
             kq = ggml_scale(ctx, kq, 30);
         }
 
+        if (hparams.attn_soft_cap) {
+            kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            kq = ggml_tanh(ctx, kq);
+            kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+        }
+
         kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);
 
@@ -11039,7 +11056,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
@@ -11106,6 +11123,12 @@ struct llm_build_context {
 
         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -17379,6 +17402,12 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;