]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : use n_swa + n_ubatch cells for SWA cache (#13833)
authorGeorgi Gerganov <redacted>
Sat, 31 May 2025 12:57:44 +0000 (15:57 +0300)
committerGitHub <redacted>
Sat, 31 May 2025 12:57:44 +0000 (15:57 +0300)
* llama : use n_swa + n_ubatch cells for SWA cache

ggml-ci

* llama : add warning about multi-sqeuence SWA contexts

include/llama.h
src/llama-context.cpp
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-model.cpp
tools/server/server.cpp

index adc4c69288a3d1e134b75e092f59077e43aa4713..6e13358bbbd963554fd32d8a542d737c83f96b04 100644 (file)
@@ -366,6 +366,8 @@ extern "C" {
         bool no_perf;     // measure performance timings
         bool op_offload;  // offload host tensor operations to device
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+                          // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
+                          //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
     };
 
     // model quantization parameters
@@ -502,6 +504,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_swa      (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
index 57c7b42269798baf9966e3cfa2c44755cb9a4e92..d913497675d616695beb91ffee9e9e34142012ca 100644 (file)
@@ -123,6 +123,11 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
+    if (!params.swa_full && cparams.n_seq_max > 1) {
+        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+    }
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
index 4726b700ff926cf5302cd118dd0be1ff3919b362..447c09c969baa06ea8653e275f3758538a6ad189 100644 (file)
@@ -1731,14 +1731,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                      bool   swa_full,
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
-                 uint32_t   n_batch,
+                 uint32_t   n_ubatch,
                  uint32_t   n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
index d2439e13603a0344b8cffa64810e1975060c42fd..75969878e318c81516cf3c0237ee138919a52010 100644 (file)
@@ -339,7 +339,7 @@ public:
                          bool   swa_full,
                      uint32_t   kv_size,
                      uint32_t   n_seq_max,
-                     uint32_t   n_batch,
+                     uint32_t   n_ubatch,
                      uint32_t   n_pad);
 
     ~llama_kv_cache_unified_iswa() = default;
index e85becbb8f6958118acf72ae6e995871da094d6b..44c07ff457dce1525ca742e75c48ce8b3ae5662e 100644 (file)
@@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                             params.swa_full,
                             cparams.n_ctx,
                             cparams.n_seq_max,
-                            cparams.n_batch,
+                            cparams.n_ubatch,
                             padding);
                 } else {
                     GGML_ASSERT(!hparams.is_swa_any());
@@ -13593,6 +13593,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }
 
+int32_t llama_model_n_swa(const llama_model * model) {
+    return model->hparams.n_swa;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
index 46dbe5cc3951df7a61139919af21dec19ab1905a..4b92eeac9499b684633f6e46d699d01a73570234 100644 (file)
@@ -2016,11 +2016,6 @@ struct server_context {
                 params_base.n_cache_reuse = 0;
                 SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
             }
-
-            if (!params_base.speculative.model.path.empty()) {
-                SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
-                return false;
-            }
         }
 
         return true;
@@ -3215,8 +3210,14 @@ struct server_context {
 
                             if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                                 const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
-                                if (pos_min > 0) {
-                                    SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+                                if (pos_min == -1) {
+                                    SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+                                    GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
+                                }
+
+                                const auto n_swa = llama_model_n_swa(model);
+                                if (pos_min > slot.n_past - n_swa) {
+                                    SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
                                     SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                             "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                                     llama_kv_self_seq_rm(ctx, slot.id, 0, -1);