]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : allow custom list of swa_layers (#13726)
authorXuan-Son Nguyen <redacted>
Fri, 23 May 2025 15:07:04 +0000 (17:07 +0200)
committerGitHub <redacted>
Fri, 23 May 2025 15:07:04 +0000 (17:07 +0200)
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-model.cpp

index 4f84e56b3d3aac90cb1e7adbfcca1c51f18f4073..ac05db46e69ae44699b86c948f07ed5ca7cbabcc 100644 (file)
@@ -2,6 +2,26 @@
 
 #include "ggml.h"
 
+llama_hparams::llama_hparams() {
+    swa_layers.fill(false);
+}
+
+void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+    }
+}
+
+bool llama_hparams::is_swa_any() const {
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (swa_layers[il]) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
 uint32_t llama_hparams::n_head(uint32_t il) const {
     if (il < n_layer) {
         return n_head_arr[il];
@@ -72,7 +92,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
 
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
-        return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1));
+        return swa_layers[il];
     }
 
     GGML_ABORT("fatal error");
index 5222eedcfb0997509b510640d8b7dbd206eb6e77..fb638dce7fd259991f3cdd6a0a2796b74c6cdaaa 100644 (file)
@@ -102,20 +102,12 @@ struct llama_hparams {
 
     // Sliding Window Attention (SWA)
     llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
-    uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
-    uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
-                                // by default n == 1, all layers are dense
-                                // note that if n_swa_pattern == 0, all layers are SWA
-                                // example: n_swa_pattern = 3
-                                //   il == 0: swa
-                                //   il == 1: swa
-                                //   il == 2: dense
-                                //   il == 3: swa
-                                //   il == 4: swa
-                                //   il == 5: dense
-                                //   il == 6: swa
-                                //   etc ...
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
 
     // for State Space Models
     uint32_t ssm_d_conv  = 0;
@@ -153,6 +145,25 @@ struct llama_hparams {
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    llama_hparams();
+
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    //           if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
     uint32_t n_head(uint32_t il = 0) const;
 
     uint32_t n_head_kv(uint32_t il = 0) const;
index 3735e3c16f0d84bf309b52b78026493909fa8a99..227a1bcbaf0eb6355520e002fac3e42698b074ef 100644 (file)
@@ -574,7 +574,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
                 hparams.swa_type      = LLAMA_SWA_TYPE_CHUNKED;
                 hparams.n_swa         = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.n_swa_pattern = 4;    // pattern: 3 chunked - 1 full
+                hparams.set_swa_pattern(4);   // pattern: 3 chunked - 1 full
 
                 switch (hparams.n_expert) {
                     case 16:  type = LLM_TYPE_17B_16E; break;
@@ -863,7 +863,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;
 
                     hparams.n_swa         = 0;
-                    hparams.n_swa_pattern = 1;
+                    hparams.set_swa_pattern(1);
                 }
             } break;
         case LLM_ARCH_PHIMOE:
@@ -935,7 +935,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa = 4096; // default value of gemma 2
-                hparams.n_swa_pattern = 2;
+                hparams.set_swa_pattern(2);
                 hparams.attn_soft_cap = true;
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
@@ -953,7 +953,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_GEMMA3:
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa_pattern = 6;
+                hparams.set_swa_pattern(6);
 
                 hparams.rope_freq_base_train_swa  = 10000.0f;
                 hparams.rope_freq_scale_train_swa = 1.0f;
@@ -1038,7 +1038,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_COHERE2:
             {
                 hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa_pattern = 4;
+                hparams.set_swa_pattern(4);
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
                 ml.get_key(LLM_KV_LOGIT_SCALE,              hparams.f_logit_scale);
@@ -4320,7 +4320,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
         LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
         LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
-        LLAMA_LOG_INFO("%s: n_swa_pattern    = %u\n",     __func__, hparams.n_swa_pattern);
+        LLAMA_LOG_INFO("%s: is_swa_any       = %u\n",     __func__, hparams.is_swa_any());
         LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
         LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
         LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
@@ -13216,7 +13216,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
                 if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.n_swa_pattern != 1);
+                    GGML_ASSERT(hparams.is_swa_any());
 
                     res = new llama_kv_cache_unified_iswa(
                             *this,
@@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                             cparams.n_batch,
                             padding);
                 } else {
-                    GGML_ASSERT(hparams.n_swa_pattern == 1);
+                    GGML_ASSERT(!hparams.is_swa_any());
 
                     res = new llama_kv_cache_unified(
                             *this,