]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gemma2: add sliding window mask (#8227)
authorXuan Son Nguyen <redacted>
Mon, 1 Jul 2024 16:48:34 +0000 (18:48 +0200)
committerGitHub <redacted>
Mon, 1 Jul 2024 16:48:34 +0000 (18:48 +0200)
* gemma2: add sliding window mask

* fix data_swa uninitialized

* better naming

* add co-author

Co-authored-by: Arlo Phoenix <redacted>
* replace list with single tensor

* update

* llama : minor styling

* convert : add sanity check for query_pre_attn_scalar

* fix small typo in README

---------

Co-authored-by: Arlo Phoenix <redacted>
Co-authored-by: Georgi Gerganov <redacted>
README.md
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
src/llama.cpp

index c136d4a5cb9c95bbf8a53cc6694824639b3fa1c7..daba70717312e2e00d2f002e87dc6442dcb54395 100644 (file)
--- a/README.md
+++ b/README.md
@@ -218,7 +218,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:
 **Tools:**
 
 - [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML
-[crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
+[crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption
 
 ---
 
index 3ef2f69e7c0dfaa8aada0eac9dc39940b923b76c..4a7f500ff7d5ca312b5ee660e9ecdfbe24398c36 100755 (executable)
@@ -2369,6 +2369,12 @@ class Gemma2Model(Model):
         self.gguf_writer.add_final_logit_softcapping(
             self.hparams["final_logit_softcapping"]
         )
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+
+        # sanity check
+        attn_scalar = self.hparams["query_pre_attn_scalar"]
+        if attn_scalar != hparams["hidden_size"] / hparams["num_attention_heads"]:
+            raise ValueError("query_pre_attn_scalar must be equal to n_embd / n_head")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unusem
index 9bfa891d5dc52922c881af6899a9875c32870f9d..e87c58266158a1a7c70eb4c051b934b90868ffa1 100644 (file)
@@ -66,6 +66,7 @@ class Keys:
         Q_LORA_RANK       = "{arch}.attention.q_lora_rank"
         KV_LORA_RANK      = "{arch}.attention.kv_lora_rank"
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
+        SLIDING_WINDOW    = "{arch}.attention.sliding_window"
 
     class Rope:
         DIMENSION_COUNT         = "{arch}.rope.dimension_count"
index 1aeb0d9b08685a1d21f7652666c73e7d07ff75ef..75a8b2636a6a216b79b5348e2ca030daabe34ddc 100644 (file)
@@ -552,6 +552,9 @@ class GGUFWriter:
     def add_relative_attn_buckets_count(self, value: int) -> None:
         self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)
 
+    def add_sliding_window(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
 
index 2a4d73856fcd93a72c95c49da045d707c7654f7b..eea532f6ac2ff32115b6277cb2155d16d9e0f724 100644 (file)
@@ -317,6 +317,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
+    LLM_KV_ATTENTION_SLIDING_WINDOW,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -409,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
     { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
@@ -2085,6 +2087,7 @@ struct llama_hparams {
     uint32_t n_head_kv;
     uint32_t n_layer;
     uint32_t n_rot;
+    uint32_t n_swa = 0; // sliding window attention (SWA)
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_ff;
@@ -2139,6 +2142,7 @@ struct llama_hparams {
         if (this->n_head_kv     != other.n_head_kv)     return true;
         if (this->n_layer       != other.n_layer)       return true;
         if (this->n_rot         != other.n_rot)         return true;
+        if (this->n_swa         != other.n_swa)         return true;
         if (this->n_embd_head_k != other.n_embd_head_k) return true;
         if (this->n_embd_head_v != other.n_embd_head_v) return true;
         if (this->n_ff          != other.n_ff)          return true;
@@ -2649,17 +2653,18 @@ struct llama_context {
     void *              abort_callback_data = nullptr;
 
     // input tensors
-    struct ggml_tensor * inp_tokens;    // I32 [n_batch]
-    struct ggml_tensor * inp_embd;      // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;       // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids;   // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask;   // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift;   // I32 [kv_size]
-    struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;       // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;    // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;    // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;     // I32 [n_kv, n_batch]
+    struct ggml_tensor * inp_tokens;      // I32 [n_batch]
+    struct ggml_tensor * inp_embd;        // F32 [n_embd, n_batch]
+    struct ggml_tensor * inp_pos;         // I32 [n_batch]
+    struct ggml_tensor * inp_out_ids;     // I32 [n_outputs]
+    struct ggml_tensor * inp_KQ_mask;     // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_K_shift;     // I32 [kv_size]
+    struct ggml_tensor * inp_mean;        // F32 [n_batch, n_batch]
+    struct ggml_tensor * inp_cls;         // I32 [n_batch]
+    struct ggml_tensor * inp_s_copy;      // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;      // F32 [1, n_kv]
+    struct ggml_tensor * inp_s_seq;       // I32 [n_kv, n_batch]
 
     // control vectors
     struct llama_control_vector cvec;
@@ -4709,6 +4714,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.n_swa = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -5419,6 +5426,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_head_kv        = %u\n",     __func__, hparams.n_head_kv);
     LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
     LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
+    LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
     LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
     LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
     LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
@@ -7775,17 +7783,18 @@ struct llm_build_context {
 
         ctx0 = ggml_init(params);
 
-        lctx.inp_tokens  = nullptr;
-        lctx.inp_embd    = nullptr;
-        lctx.inp_pos     = nullptr;
-        lctx.inp_out_ids = nullptr;
-        lctx.inp_KQ_mask = nullptr;
-        lctx.inp_K_shift = nullptr;
-        lctx.inp_mean    = nullptr;
-        lctx.inp_cls     = nullptr;
-        lctx.inp_s_copy  = nullptr;
-        lctx.inp_s_mask  = nullptr;
-        lctx.inp_s_seq   = nullptr;
+        lctx.inp_tokens      = nullptr;
+        lctx.inp_embd        = nullptr;
+        lctx.inp_pos         = nullptr;
+        lctx.inp_out_ids     = nullptr;
+        lctx.inp_KQ_mask     = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
+        lctx.inp_K_shift     = nullptr;
+        lctx.inp_mean        = nullptr;
+        lctx.inp_cls         = nullptr;
+        lctx.inp_s_copy      = nullptr;
+        lctx.inp_s_mask      = nullptr;
+        lctx.inp_s_seq       = nullptr;
     }
 
     void free() {
@@ -7804,7 +7813,6 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * rope_factors = build_rope_factors(il);
             struct ggml_tensor * tmp =
@@ -7939,16 +7947,27 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        }
+        lctx.inp_KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
+
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
+    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        lctx.inp_KQ_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+    }
+
     struct ggml_tensor * build_inp_mean() {
         lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
         cb(lctx.inp_mean, "inp_mean", -1);
@@ -11029,9 +11048,14 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
 
         for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
@@ -11067,7 +11091,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             cur = llm_build_norm(ctx0, cur, hparams,
@@ -12670,7 +12694,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-            float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data     = (float *) lctx.inp_KQ_mask->data;
+            float * data_swa = nullptr;
+
+            if (lctx.inp_KQ_mask_swa) {
+                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+            }
 
             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
@@ -12692,6 +12721,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             }
                         }
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
                     }
                 }