]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : refactor k-shift implementation + KV defragmentation (#5691)
authorGeorgi Gerganov <redacted>
Sun, 25 Feb 2024 20:12:24 +0000 (22:12 +0200)
committerGitHub <redacted>
Sun, 25 Feb 2024 20:12:24 +0000 (22:12 +0200)
* llama : refactor k-shift implementation

ggml-ci

* llama : rename llama_kv_cache_seq_shift to llama_kv_cache_seq_add

* llama : cont k-shift refactoring + normalize type names

ggml-ci

* minor : fix MPI builds

* llama : reuse n_rot from the build context

ggml-ci

* llama : revert enum name changes from this PR

ggml-ci

* llama : update llama_rope_type

* llama : add comment about rope values

* llama : fix build

* passkey : apply kv cache updates explicitly

ggml-ci

* llama : change name to llama_kv_cache_update()

* llama : add llama_kv_cache_seq_pos_max()

* passkey : fix llama_kv_cache_seq_pos_max() usage

* llama : some llama_kv_cell simplifications

* llama : add llama_kv_cache_compress (EXPERIMENTAL)

* llama : add alternative KV cache merging (EXPERIMENTAL)

* llama : add llama_kv_cache_defrag

* llama : comments

* llama : remove llama_kv_cache_compress

will add in a separate PR

ggml-ci

* llama : defragment via non-overlapping moves

* llama : ggml_graph based defrag implementation

ggml-ci

* llama : switch the loop order in build_defrag

* llama : add comments

examples/infill/infill.cpp
examples/main/main.cpp
examples/passkey/passkey.cpp
examples/server/server.cpp
llama.cpp
llama.h

index 92c67b7cff5c895da6acd0b5fb1a458f0502466b..d4b8729dd0283c2e0ece17d9022c22bc967054ca 100644 (file)
@@ -447,8 +447,8 @@ int main(int argc, char ** argv) {
                 LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                     n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
                 n_past -= n_discard;
 
index 7555dffe441f0aea3847b904f0a70ef417542942..34e84d0d42f87870838338bdad10e6d1cbbd1ced 100644 (file)
@@ -548,8 +548,8 @@ int main(int argc, char ** argv) {
                     LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                             n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                    llama_kv_cache_seq_rm   (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                    llama_kv_cache_seq_shift(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
 
@@ -576,9 +576,9 @@ int main(int argc, char ** argv) {
                     LOG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
                     LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-                    llama_kv_cache_seq_shift(ctx, 0, ga_i,                n_past,              ib*bd);
-                    llama_kv_cache_seq_div  (ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
-                    llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
+                    llama_kv_cache_seq_add(ctx, 0, ga_i,                n_past,              ib*bd);
+                    llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                    llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
 
                     n_past -= bd;
 
index e12a1cdf19a793165f5a7a9c6f04ffc19747ad34..47de67a93047f6778778d56b845e4527ab70f94a 100644 (file)
@@ -126,7 +126,7 @@ int main(int argc, char ** argv) {
     const int n_batch     = ctx_params.n_batch;
     const int n_batch_grp = ctx_params.n_batch/n_grp;
 
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);
 
     // print the prompt token-by-token
 
@@ -146,10 +146,11 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);
 
-            llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
-            llama_kv_cache_seq_div  (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_update  (ctx);
 
-            n_past -= bd;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }
 
         llama_batch_clear(batch);
@@ -179,10 +180,12 @@ int main(int argc, char ** argv) {
 
         LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
 
-        llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
-        llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+        llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+        llama_kv_cache_defrag (ctx);
+        llama_kv_cache_update (ctx);
 
-        n_past -= n_discard;
+        n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
 
         llama_batch_clear(batch);
 
@@ -208,10 +211,12 @@ int main(int argc, char ** argv) {
         if (n_discard > 0) {
             LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
 
-            llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+            llama_kv_cache_defrag (ctx);
+            llama_kv_cache_update (ctx);
 
-            n_past -= n_discard;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }
     }
 
index c1eb61678c38a48c2be20cce08827a164712fbc3..8aadc95a9728f7ea24252f5d9e5397984d925ee2 100644 (file)
@@ -1636,8 +1636,8 @@ struct llama_server_context
                         {"n_system_tokens", system_tokens.size()},
                         {"n_cache_tokens",  slot.cache_tokens.size()}
                     });
-                    llama_kv_cache_seq_rm   (ctx, slot.id, n_keep            , n_keep + n_discard);
-                    llama_kv_cache_seq_shift(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                     for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
                     {
@@ -1941,9 +1941,9 @@ struct llama_server_context
                         LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                         LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
+                        llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
                         llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
+                        llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
 
                         slot.n_past_se -= bd;
 
index acd9be08a6e5e8342cbe26c83c185a413682b204..3424b1999ebdda7fd2a4d4a4e14c99c59f7a1bb7 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1550,8 +1550,9 @@ static const size_t MiB = 1024*kiB;
 static const size_t GiB = 1024*MiB;
 
 struct llama_hparams {
-    bool     vocab_only;
-    bool     rope_finetuned;
+    bool vocab_only;
+    bool rope_finetuned;
+
     uint32_t n_vocab;
     uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
@@ -1580,7 +1581,8 @@ struct llama_hparams {
     bool causal_attn = true;
     bool need_kq_pos = false;
 
-    uint32_t pooling_type = LLAMA_POOLING_TYPE_NONE;
+    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
+    enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only    != other.vocab_only)    return true;
@@ -1707,11 +1709,20 @@ struct llama_kv_cell {
     bool has_seq_id(const llama_seq_id & id) const {
         return seq_id.find(id) != seq_id.end();
     }
+
+    bool is_empty() const {
+        return seq_id.empty();
+    }
+
+    bool is_same_seq(const llama_kv_cell & other) const {
+        return seq_id == other.seq_id;
+    }
 };
 
 // ring-buffer of cached KV data
 struct llama_kv_cache {
     bool has_shift = false;
+    bool do_defrag = false;
 
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -1723,6 +1734,9 @@ struct llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;
 
+    ggml_type type_k = GGML_TYPE_F16;
+    ggml_type type_v = GGML_TYPE_F16;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
@@ -1958,8 +1972,8 @@ struct llama_context {
 static bool llama_kv_cache_init(
              struct llama_kv_cache & cache,
                  const llama_model & model,
-                         ggml_type   ktype,
-                         ggml_type   vtype,
+                         ggml_type   type_k,
+                         ggml_type   type_v,
                           uint32_t   n_ctx,
                               bool   offload) {
     const struct llama_hparams & hparams = model.hparams;
@@ -1974,6 +1988,9 @@ static bool llama_kv_cache_init(
     cache.size = n_ctx;
     cache.used = 0;
 
+    cache.type_k = type_k;
+    cache.type_v = type_v;
+
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
@@ -2014,8 +2031,8 @@ static bool llama_kv_cache_init(
 
     for (int i = 0; i < (int) n_layer; i++) {
         struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, ktype, n_embd_k_gqa*n_ctx);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, vtype, n_embd_v_gqa*n_ctx);
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
@@ -2099,7 +2116,7 @@ static bool llama_kv_cache_find_slot(
 // find how many cells are currently in use
 static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
     for (uint32_t i = cache.size - 1; i > 0; --i) {
-        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+        if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) {
             return i + 1;
         }
     }
@@ -2135,7 +2152,7 @@ static void llama_kv_cache_seq_rm(
             } else {
                 continue;
             }
-            if (cache.cells[i].seq_id.empty()) {
+            if (cache.cells[i].is_empty()) {
                 // keep count of the number of used cells
                 if (cache.cells[i].pos >= 0) cache.used--;
 
@@ -2186,7 +2203,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
 
-static void llama_kv_cache_seq_shift(
+static void llama_kv_cache_seq_add(
         struct llama_kv_cache & cache,
                  llama_seq_id   seq_id,
                     llama_pos   p0,
@@ -2204,10 +2221,14 @@ static void llama_kv_cache_seq_shift(
             cache.cells[i].delta += delta;
 
             if (cache.cells[i].pos < 0) {
-                if (!cache.cells[i].seq_id.empty()) cache.used--;
+                if (!cache.cells[i].is_empty()) {
+                    cache.used--;
+                }
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
-                if (new_head == cache.size) new_head = i;
+                if (new_head == cache.size) {
+                    new_head = i;
+                }
             }
         }
     }
@@ -2239,6 +2260,22 @@ static void llama_kv_cache_seq_div(
     }
 }
 
+static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    llama_pos result = 0;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id)) {
+            result = std::max(result, cache.cells[i].pos);
+        }
+    }
+
+    return result;
+}
+
+static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
+    cache.do_defrag = true;
+}
+
 //
 // model loading and saving
 //
@@ -2310,7 +2347,7 @@ namespace GGUFMeta {
         }
     };
 
-    struct ArrayInfo{
+    struct ArrayInfo {
         const gguf_type gt;
         const size_t length;
         const void * data;
@@ -2329,7 +2366,7 @@ namespace GGUFMeta {
     };
 
     template<typename T>
-    class GKV: public GKV_Base<T> {
+    class GKV : public GKV_Base<T> {
         GKV() = delete;
 
         public:
@@ -2352,39 +2389,39 @@ namespace GGUFMeta {
             return "unknown";
         }
 
-        static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override *override) {
-            if (!override) { return false; }
-            if (override->tag == expected_type) {
+        static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
+            if (!ovrd) { return false; }
+            if (ovrd->tag == expected_type) {
                 LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
-                    __func__, override_type_to_str(override->tag), override->key);
-                switch (override->tag) {
+                    __func__, override_type_to_str(ovrd->tag), ovrd->key);
+                switch (ovrd->tag) {
                     case LLAMA_KV_OVERRIDE_TYPE_BOOL:  {
-                        LLAMA_LOG_INFO("%s\n", override->bool_value ? "true" : "false");
+                        LLAMA_LOG_INFO("%s\n", ovrd->bool_value ? "true" : "false");
                     } break;
                     case LLAMA_KV_OVERRIDE_TYPE_INT:   {
-                        LLAMA_LOG_INFO("%" PRId64 "\n", override->int_value);
+                        LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->int_value);
                     } break;
                     case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
-                        LLAMA_LOG_INFO("%.6f\n", override->float_value);
+                        LLAMA_LOG_INFO("%.6f\n", ovrd->float_value);
                     } break;
                     default:
                         // Shouldn't be possible to end up here, but just in case...
                         throw std::runtime_error(
                             format("Unsupported attempt to override %s type for metadata key %s\n",
-                                override_type_to_str(override->tag), override->key));
+                                override_type_to_str(ovrd->tag), ovrd->key));
                 }
                 return true;
             }
             LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
-                __func__, override->key, override_type_to_str(expected_type), override_type_to_str(override->tag));
+                __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
             return false;
         }
 
         template<typename OT>
         static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override *override) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, override)) {
-                target = override->bool_value;
+        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
+                target = ovrd->bool_value;
                 return true;
             }
             return false;
@@ -2392,9 +2429,9 @@ namespace GGUFMeta {
 
         template<typename OT>
         static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override *override) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, override)) {
-                target = override->int_value;
+        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
+                target = ovrd->int_value;
                 return true;
             }
             return false;
@@ -2402,9 +2439,9 @@ namespace GGUFMeta {
 
         template<typename OT>
         static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override *override) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, override)) {
-                target = override->float_value;
+        try_override(T & target, const struct llama_model_kv_override * ovrd) {
+            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
+                target = ovrd->float_value;
                 return true;
             }
             return false;
@@ -2412,17 +2449,17 @@ namespace GGUFMeta {
 
         template<typename OT>
         static typename std::enable_if<std::is_same<OT, std::string>::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override *override) {
+        try_override(T & target, const struct llama_model_kv_override * ovrd) {
             (void)target;
-            (void)override;
-            if (!override) { return false; }
+            (void)ovrd;
+            if (!ovrd) { return false; }
             // Currently, we should never end up here so it would be a bug if we do.
             throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n",
-                override ? override->key : "NULL"));
+                ovrd ? ovrd->key : "NULL"));
         }
 
-        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) {
-            if (try_override<T>(target, override)) {
+        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
+            if (try_override<T>(target, ovrd)) {
                 return true;
             }
             if (k < 0) { return false; }
@@ -2430,12 +2467,12 @@ namespace GGUFMeta {
             return true;
         }
 
-        static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
-            return set(ctx, gguf_find_key(ctx, key), target, override);
+        static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
+            return set(ctx, gguf_find_key(ctx, key), target, ovrd);
         }
 
-        static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
-            return set(ctx, key.c_str(), target, override);
+        static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
+            return set(ctx, key.c_str(), target, ovrd);
         }
     };
 }
@@ -2846,6 +2883,15 @@ struct llama_model_loader {
     }
 };
 
+template<>
+bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
+    uint32_t tmp;
+    const bool found = get_key(kid, tmp, required);
+    result = (enum llama_pooling_type) tmp;
+    return found;
+}
+
+
 //
 // load LLaMA models
 //
@@ -2926,16 +2972,16 @@ static const char * llama_model_type_name(e_model type) {
         default:           return "?B";
     }
 }
+
 static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
     switch (type) {
-        case LLAMA_VOCAB_TYPE_SPM:         return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:         return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:         return "WPM";
-        default:                           return "unknown";
+        case LLAMA_VOCAB_TYPE_SPM: return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE: return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM: return "WPM";
+        default:                   return "unknown";
     }
 }
 
-
 static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
     model.arch = ml.get_arch();
     if (model.arch == LLM_ARCH_UNKNOWN) {
@@ -3112,10 +3158,10 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_BERT:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -3133,10 +3179,10 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_NOMIC_BERT:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
 
                 if (hparams.n_layer == 12 && hparams.n_embd == 768) {
                     model.type = e_model::MODEL_137M;
@@ -3275,6 +3321,8 @@ static void llm_load_hparams(
     if (hparams.f_max_alibi_bias > 0.0f) {
         hparams.need_kq_pos = true;
     }
+
+    hparams.rope_type = llama_rope_type(&model);
 }
 
 // TODO: This should probably be in llama.h
@@ -3577,6 +3625,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
     LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
+    LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
+    LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
     LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
@@ -4598,12 +4648,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
 
 using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
 
-enum llm_rope_type {
-    LLM_ROPE,
-    LLM_ROPE_NEOX,
-    LLM_ROPE_GLM,
-};
-
 enum llm_ffn_op_type {
     LLM_FFN_SILU,
     LLM_FFN_GELU,
@@ -4649,55 +4693,6 @@ static struct ggml_tensor * llm_build_inp_embd(
     return inpL;
 }
 
-// Persimmon: n_rot = n_embd_head_k/2
-// Other:     n_rot = n_embd_head_k
-static void llm_build_k_shift(
-      struct ggml_context * ctx,
-      const llama_hparams & hparams,
-      const llama_cparams & cparams,
-     const llama_kv_cache & kv,
-       struct ggml_cgraph * graph,
-       struct ggml_tensor * K_shift,
-            llm_rope_type   type,
-                  int64_t   n_ctx,
-                  float     freq_base,
-                  float     freq_scale,
-       const llm_build_cb & cb) {
-    const int64_t n_layer       = hparams.n_layer;
-    const int64_t n_head_kv     = hparams.n_head_kv;
-    const int64_t n_embd_head_k = hparams.n_embd_head_k;
-    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-    const int32_t n_rot         = hparams.n_rot;
-    const int32_t n_orig_ctx    = cparams.n_yarn_orig_ctx;
-    const float   ext_factor    = cparams.yarn_ext_factor;
-    const float   attn_factor   = cparams.yarn_attn_factor;
-    const float   beta_fast     = cparams.yarn_beta_fast;
-    const float   beta_slow     = cparams.yarn_beta_slow;
-
-    int rope_type = 0;
-
-    switch (type) {
-        case LLM_ROPE:      rope_type = 0; break;
-        case LLM_ROPE_NEOX: rope_type = 2; break;
-        case LLM_ROPE_GLM:  rope_type = 4; break;
-    }
-
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * tmp =
-            // we rotate only the first n_rot dimensions
-            ggml_rope_custom_inplace(ctx,
-                    ggml_view_3d(ctx, kv.k_l[il],
-                        n_embd_head_k, n_head_kv, n_ctx,
-                        ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
-                        ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
-                        0),
-                    K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow);
-        cb(tmp, "K_shifted", il);
-        ggml_build_forward_expand(graph, tmp);
-    }
-}
-
 static void llm_build_kv_store(
         struct ggml_context * ctx,
         const llama_hparams & hparams,
@@ -5001,6 +4996,7 @@ struct llm_build_context {
 
     const int64_t n_embd;
     const int64_t n_layer;
+    const int64_t n_rot;
     const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
     const int64_t n_head;
     const int64_t n_head_kv;
@@ -5025,8 +5021,8 @@ struct llm_build_context {
     const int32_t kv_head;  // index of where we store new KV data in the cache
     const int32_t n_orig_ctx;
 
-    const bool do_rope_shift;
-    const uint32_t pooling_type;
+    const enum llama_pooling_type pooling_type;
+    const enum llama_rope_type    rope_type;
 
     const llm_build_cb & cb;
 
@@ -5048,6 +5044,7 @@ struct llm_build_context {
         kv_self          (lctx.kv_self),
         n_embd           (hparams.n_embd),
         n_layer          (hparams.n_layer),
+        n_rot            (hparams.n_rot),
         n_ctx            (cparams.n_ctx),
         n_head           (hparams.n_head),
         n_head_kv        (hparams.n_head_kv),
@@ -5069,8 +5066,8 @@ struct llm_build_context {
         n_kv             (worst_case ? n_ctx            : kv_self.n),
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
-        do_rope_shift    (worst_case || kv_self.has_shift),
-        pooling_type     (cparams.do_pooling ? hparams.pooling_type : (uint32_t)LLAMA_POOLING_TYPE_NONE),
+        pooling_type     (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
+        rope_type        (hparams.rope_type),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
             // all initializations should be done in init()
@@ -5093,6 +5090,74 @@ struct llm_build_context {
         }
     }
 
+    struct ggml_cgraph * build_k_shift() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * tmp =
+                // we rotate only the first n_rot dimensions
+                ggml_rope_custom_inplace(ctx0,
+                        ggml_view_3d(ctx0, kv_self.k_l[il],
+                            n_embd_head_k, n_head_kv, n_ctx,
+                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                            0),
+                        lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+            cb(tmp, "K_shifted", il);
+            ggml_build_forward_expand(gf, tmp);
+        }
+
+        return gf;
+    }
+
+    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int i = 0; i < n_kv; ++i) {
+            const int id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
+            }
+
+            int nm = 1;
+
+            while (i + nm < n_kv && (int) ids[i + nm] == id + nm) {
+                nm++;
+            }
+
+            for (int il = 0; il < n_layer; ++il) {
+                ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
+                        n_embd_k_gqa, nm,
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
+
+                ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
+                        n_embd_k_gqa, nm,
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
+
+                ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                        ggml_row_size(kv_self.v_l[il]->type, i));
+
+                ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                        ggml_row_size(kv_self.v_l[il]->type, id));
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+            }
+
+            i += nm - 1;
+        }
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -5114,11 +5179,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -5154,14 +5214,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
-                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -5302,11 +5362,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
         cb(KQ_pos, "KQ_pos", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -5330,12 +5385,12 @@ struct llm_build_context {
                     case MODEL_7B:
                         Qcur = ggml_rope_custom(
                             ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
-                            hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         Kcur = ggml_rope_custom(
                             ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                            hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
@@ -5420,11 +5475,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * attn_norm;
 
@@ -5463,13 +5513,13 @@ struct llm_build_context {
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -5639,10 +5689,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * residual = inpL;
 
@@ -5700,7 +5746,7 @@ struct llm_build_context {
 
                 // RoPE the first n_rot of q/k, pass the other half, and concat.
                 struct ggml_tensor * qrot = ggml_view_3d(
-                        ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
+                        ctx0, tmpq, n_rot, n_head, n_tokens,
                         ggml_element_size(tmpq) * n_embd_head,
                         ggml_element_size(tmpq) * n_embd_head * n_head,
                         0
@@ -5708,7 +5754,7 @@ struct llm_build_context {
                 cb(qrot, "qrot", il);
 
                 struct ggml_tensor * krot = ggml_view_3d(
-                        ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
+                        ctx0, tmpk, n_rot, n_head, n_tokens,
                         ggml_element_size(tmpk) * n_embd_head,
                         ggml_element_size(tmpk) * n_embd_head * n_head,
                         0
@@ -5717,29 +5763,29 @@ struct llm_build_context {
 
                 // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
                 struct ggml_tensor * qpass = ggml_view_3d(
-                        ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
+                        ctx0, tmpq, n_rot, n_head, n_tokens,
                         ggml_element_size(tmpq) * n_embd_head,
                         ggml_element_size(tmpq) * n_embd_head * n_head,
-                        ggml_element_size(tmpq) * hparams.n_rot
+                        ggml_element_size(tmpq) * n_rot
                         );
                 cb(qpass, "qpass", il);
 
                 struct ggml_tensor * kpass = ggml_view_3d(
-                        ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
+                        ctx0, tmpk, n_rot, n_head, n_tokens,
                         ggml_element_size(tmpk) * n_embd_head,
                         ggml_element_size(tmpk) * n_embd_head * n_head,
-                        ggml_element_size(tmpk) * hparams.n_rot
+                        ggml_element_size(tmpk) * n_rot
                         );
                 cb(kpass, "kpass", il);
 
                 struct ggml_tensor * qrotated = ggml_rope_custom(
-                    ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    ctx0, qrot, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(qrotated, "qrotated", il);
 
                 struct ggml_tensor * krotated = ggml_rope_custom(
-                    ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    ctx0, krot, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(krotated, "krotated", il);
@@ -5991,14 +6037,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -6287,11 +6333,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -6328,14 +6369,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -6410,11 +6451,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -6444,13 +6480,13 @@ struct llm_build_context {
 
                 // using mode = 2 for neox mode
                 Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -6524,11 +6560,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -6564,14 +6595,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -6645,11 +6676,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm,
@@ -6687,7 +6713,7 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
@@ -6698,7 +6724,7 @@ struct llm_build_context {
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
+                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -6767,11 +6793,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
 
             // norm
@@ -6795,14 +6816,14 @@ struct llm_build_context {
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_custom(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, hparams.n_rot, n_head,    n_tokens), inp_pos,
-                        n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos,
+                        n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, hparams.n_rot, n_head_kv, n_tokens), inp_pos,
-                        n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos,
+                        n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
@@ -6972,11 +6993,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm,
@@ -7002,14 +7018,14 @@ struct llm_build_context {
 
                 struct ggml_tensor * Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 struct ggml_tensor * Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -7080,11 +7096,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -7120,14 +7131,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -7199,11 +7210,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -7239,14 +7245,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
-                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -7331,11 +7337,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -7371,14 +7372,14 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
-                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                    hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -7467,11 +7468,6 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
-        // shift the entire K-cache if needed
-        if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
-        }
-
         for (int il = 0; il < n_layer; ++il) {
 
             // norm
@@ -7494,7 +7490,7 @@ struct llm_build_context {
 
                 Qcur = ggml_rope_custom(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos,
-                        n_embd_head_k, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -7503,7 +7499,7 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_custom(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos,
-                        n_embd_head_k, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                        n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
@@ -7556,6 +7552,40 @@ struct llm_build_context {
     }
 };
 
+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_defrag(ids);
+
+    llm.free();
+
+    return result;
+}
+
+static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_k_shift();
+
+    llm.free();
+
+    return result;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch,
@@ -7675,6 +7705,20 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }
 
+static void llama_set_k_shift(llama_context & lctx) {
+    const auto & cparams = lctx.cparams;
+
+    const int64_t n_ctx = cparams.n_ctx;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+
+    for (int i = 0; i < n_ctx; ++i) {
+        data[i] = lctx.kv_self.cells[i].delta;
+    }
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
@@ -7742,18 +7786,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (kv_self.has_shift) {
-        const int64_t n_ctx = cparams.n_ctx;
-
-        assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-
-        int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
-        for (int i = 0; i < n_ctx; ++i) {
-            data[i] = lctx.kv_self.cells[i].delta;
-        }
-    }
-
     if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
@@ -7798,6 +7830,34 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 }
 
+static void llama_graph_compute(
+        llama_context & lctx,
+          ggml_cgraph * gf,
+                  int   n_threads) {
+#ifdef GGML_USE_MPI
+    const int64_t n_layer = lctx.model.hparams.n_layer;
+    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
+#endif
+
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(lctx.backend_metal)) {
+        ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
+    }
+#endif
+
+    if (lctx.backend_cpu != nullptr) {
+        ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
+    }
+
+    ggml_backend_sched_graph_compute(lctx.sched, gf);
+
+    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+
+#ifdef GGML_USE_MPI
+    ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
+#endif
+}
+
 // decode a batch of tokens by evaluating the transformer
 //
 //   - lctx:      llama context
@@ -7893,14 +7953,17 @@ static int llama_decode_internal(
 
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
+    llama_kv_cache_update(&lctx);
+
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
     ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
 
     // the output is always the last tensor in the graph
-    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * res        = gf->nodes[gf->n_nodes - 1];
     struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+
     if (strcmp(res->name, "result_output") == 0) {
         // the embeddings could be the second to last tensor, or the third to last tensor
         if (strcmp(embeddings->name, "result_norm") != 0) {
@@ -7927,40 +7990,12 @@ static int llama_decode_internal(
         n_threads = std::min(4, n_threads);
     }
 
-#ifdef GGML_USE_MPI
-    const int64_t n_layer = hparams.n_layer;
-    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
-#endif
-
-#ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(lctx.backend_metal)) {
-        ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
-    }
-#endif
-
-    if (lctx.backend_cpu != nullptr) {
-        ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
-    }
-
     llama_set_inputs(lctx, batch);
 
-    ggml_backend_sched_graph_compute(lctx.sched, gf);
-
-    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
-
-#ifdef GGML_USE_MPI
-    ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
-#endif
+    llama_graph_compute(lctx, gf, n_threads);
 
     // update the kv ring buffer
     {
-        if (kv_self.has_shift) {
-            kv_self.has_shift = false;
-            for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = 0;
-            }
-        }
-
         kv_self.head += n_tokens;
 
         // Ensure kv cache head points to a valid index.
@@ -8056,6 +8091,221 @@ static int llama_decode_internal(
     return 0;
 }
 
+// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
+static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+    auto & kv_self = lctx.kv_self;
+
+    const uint32_t n_kv   = llama_kv_cache_cell_max(kv_self);
+    const uint32_t n_used = kv_self.used;
+
+    assert(n_used <= n_kv);
+
+    const int64_t t_start = ggml_time_us();
+
+    // number of cells moved
+    uint32_t n_moves = 0;
+
+    // determine which KV cells to move where
+    //
+    //  cell i moves to ids[i]
+    //
+    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
+    //
+    std::vector<uint32_t> ids(n_kv, n_kv);
+
+    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
+        const auto & cell0 = kv_self.cells[i0];
+
+        if (!cell0.is_empty()) {
+            ids[i0] = i0;
+
+            continue;
+        }
+
+        // found a hole - fill it with data from the end of the cache
+
+        // determine the size of the hole
+        uint32_t nh = 1;
+        while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
+            nh++;
+        }
+
+        // starting from the end, find nh non-empty cells
+        uint32_t nf = 0;
+        uint32_t is = n_kv - 1;
+        for (; is > i0; --is) {
+            const auto & cell1 = kv_self.cells[is];
+
+            if (cell1.is_empty() || ids[is] != n_kv) {
+                continue;
+            }
+
+            // non-empty cell which is not yet moved
+            nf++;
+
+            if (nf == nh) {
+                break;
+            }
+        }
+
+        // this can only happen if `n_used` is not accurate, which would be a bug
+        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
+
+        nf = 0;
+
+        // go back and move the nf cells to the hole
+        for (uint32_t i1 = is; i1 < n_kv; ++i1) {
+            const auto & cell1 = kv_self.cells[i1];
+
+            if (cell1.is_empty() || ids[i1] != n_kv) {
+                continue;
+            }
+
+            // this cell goes to (i0 + nf)
+            ids[i1] = i0 + nf;
+
+            // move the cell meta data
+            kv_self.cells[i0 + nf] = cell1;
+
+            n_moves++;
+            nf++;
+        }
+
+        LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
+
+        i0 += nh - 1;
+    }
+
+    if (n_moves == 0) {
+        return;
+    }
+
+    LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
+
+    kv_self.head = n_used;
+    kv_self.used = n_used;
+
+    // zero the rest of the cells
+    for (uint32_t i = n_used; i < n_kv; ++i) {
+        kv_self.cells[i] = llama_kv_cell();
+    }
+
+#if 0
+    // CPU defrag
+    //
+    // TODO: optimizations are possible:
+    //       - multiple threads
+    //       - avoid copying to the host memory when already there
+    //
+    // likely not worth the effort, as we have ggml_graph based defrag
+    //
+
+    const auto & hparams = lctx.model.hparams;
+
+    const uint32_t n_layer      = hparams.n_layer;
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+    const uint32_t kv_size = kv_self.size;
+
+    std::vector<uint8_t> buf_k;
+    std::vector<uint8_t> buf_v;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size     = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_size);
+
+        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+        const size_t v_size    = ggml_row_size (kv_self.v_l[il]->type, n_embd_v_gqa*kv_size);
+
+        buf_k.resize(k_size);
+        buf_v.resize(v_size);
+
+        ggml_backend_tensor_get(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_get(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
+
+        // batch move [i, i+nm) to [id, id+nm)
+        // note: cells can move only to a lower index
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
+            }
+
+            uint32_t nm = 1;
+
+            while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                nm++;
+            }
+
+            // move keys
+            {
+                const int64_t os =  i*k_size_row;
+                const int64_t od = id*k_size_row;
+
+                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+            }
+
+            // move values (note: they are transposed)
+            {
+                const int64_t os =  i;
+                const int64_t od = id;
+
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
+                }
+            }
+
+            i += nm - 1;
+        }
+
+        ggml_backend_tensor_set(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_set(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
+    }
+#else
+    // ggml_graph defrag
+
+    ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
+
+    llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+#endif
+
+    const int64_t t_end = ggml_time_us();
+
+    LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
+}
+
+static void llama_kv_cache_update_internal(struct llama_context & lctx) {
+    // apply K-shift if needed
+    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
+        llama_set_k_shift(lctx);
+
+        {
+            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
+
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+        }
+
+        {
+            auto & kv_self = lctx.kv_self;
+
+            kv_self.has_shift = false;
+
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (lctx.kv_self.do_defrag) {
+        llama_kv_cache_defrag_internal(lctx);
+
+        lctx.kv_self.do_defrag = false;
+    }
+}
+
 //
 // tokenizer
 //
@@ -11671,8 +11921,7 @@ struct llama_context * llama_new_context_with_model(
         }
         ctx->backends.push_back(ctx->backend_cpu);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v,
-                cparams.n_ctx, cparams.offload_kqv)) {
+        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
@@ -11820,6 +12069,49 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }
 
+enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+    switch (model->arch) {
+        // these models do not use RoPE
+        case LLM_ARCH_GPT2:
+        case LLM_ARCH_GPTJ:
+        case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_MPT:
+        case LLM_ARCH_REFACT:
+        case LLM_ARCH_BLOOM:
+            return LLAMA_ROPE_TYPE_NONE;
+
+        // use what we call a normal RoPE, operating on pairs of consecutive head values
+        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_BAICHUAN:
+        case LLM_ARCH_STARCODER:
+        case LLM_ARCH_PLAMO:
+        case LLM_ARCH_CODESHELL:
+        case LLM_ARCH_ORION:
+        case LLM_ARCH_INTERNLM2:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_GEMMA:
+            return LLAMA_ROPE_TYPE_NORM;
+
+        // the pairs of head values are offset by n_rot/2
+        case LLM_ARCH_FALCON:
+        case LLM_ARCH_PERSIMMON:
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_STABLELM:
+        case LLM_ARCH_QWEN:
+        case LLM_ARCH_QWEN2:
+        case LLM_ARCH_PHI2:
+            return LLAMA_ROPE_TYPE_NEOX;
+
+        // all model arches should be listed explicitly here
+        case LLM_ARCH_UNKNOWN:
+            GGML_ASSERT(false && "unknown architecture");
+            break;
+    }
+
+    return LLAMA_ROPE_TYPE_NONE;
+}
+
 int32_t llama_n_vocab(const struct llama_model * model) {
     return model->vocab.id_to_token.size();
 }
@@ -12062,12 +12354,12 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
     llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
 }
 
-void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
     if (delta == 0) {
         return;
     }
 
-    llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
+    llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
 }
 
 void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
@@ -12078,6 +12370,19 @@ void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, lla
     llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
 }
 
+llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
+    return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
+}
+
+void llama_kv_cache_defrag(struct llama_context * ctx) {
+    llama_kv_cache_defrag(ctx->kv_self);
+}
+
+void llama_kv_cache_update(struct llama_context * ctx) {
+    llama_kv_cache_update_internal(*ctx);
+}
+
+
 // Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
@@ -12204,10 +12509,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;
 
-        const auto   n_layer      = hparams.n_layer;
-        const auto   n_embd_k_gqa = hparams.n_embd_k_gqa();
-        const auto   n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const auto   n_ctx        = cparams.n_ctx;
+        const uint32_t n_layer      = hparams.n_layer;
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+        const uint32_t n_ctx        = cparams.n_ctx;
 
         const size_t   kv_buf_size = kv_self.total_size();
         const uint32_t kv_head     = kv_self.head;
@@ -12222,14 +12527,16 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         if (kv_buf_size) {
             std::vector<uint8_t> tmp_buf;
             for (int il = 0; il < (int) n_layer; ++il) {
-                size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+
                 tmp_buf.resize(k_size);
                 ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
                 data_ctx->write(tmp_buf.data(), tmp_buf.size());
 
                 // v is not contiguous, copy row by row
-                size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+                const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
+                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+
                 tmp_buf.resize(v_row_size);
                 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
                     ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), ir*v_row_stride, tmp_buf.size());
@@ -12316,10 +12623,10 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         const auto & hparams = ctx->model.hparams;
         const auto & cparams = ctx->cparams;
 
-        const int    n_layer      = hparams.n_layer;
-        const int    n_embd_k_gqa = hparams.n_embd_k_gqa();
-        const int    n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const int    n_ctx        = cparams.n_ctx;
+        const uint32_t n_layer      = hparams.n_layer;
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+        const uint32_t n_ctx        = cparams.n_ctx;
 
         size_t   kv_buf_size;
         uint32_t kv_head;
@@ -12335,13 +12642,15 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
             GGML_ASSERT(kv_self.total_size() == kv_buf_size);
 
             for (int il = 0; il < (int) n_layer; ++il) {
-                size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
+
                 ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
                 inp += k_size;
 
                 // v is not contiguous, copy row by row
-                size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+                const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
+                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx);
+
                 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
                     ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
                     inp += v_row_size;
diff --git a/llama.h b/llama.h
index 947284ea2f5355b4a6726bf576d7374ba5a6eb17..ff131996d9a38e438f600fb69c503861090e2eff 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -64,6 +64,15 @@ extern "C" {
         LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
     };
 
+    // note: these values should be synchronized with ggml_rope
+    // TODO: maybe move this enum to ggml.h (ggml_rope_type)
+    enum llama_rope_type {
+        LLAMA_ROPE_TYPE_NONE = -1,
+        LLAMA_ROPE_TYPE_NORM =  0,
+        LLAMA_ROPE_TYPE_NEOX =  2,
+        LLAMA_ROPE_TYPE_GLM  =  4,
+    };
+
     enum llama_token_type {
         LLAMA_TOKEN_TYPE_UNDEFINED    = 0,
         LLAMA_TOKEN_TYPE_NORMAL       = 1,
@@ -360,6 +369,7 @@ extern "C" {
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
+    LLAMA_API enum llama_rope_type  llama_rope_type (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -514,10 +524,12 @@ extern "C" {
                     llama_seq_id   seq_id);
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
-    // If the KV cache is RoPEd, the KV data is updated accordingly
+    // If the KV cache is RoPEd, the KV data is updated accordingly:
+    //   - lazily on next llama_decode()
+    //   - explicitly with llama_kv_cache_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_shift(
+    LLAMA_API void llama_kv_cache_seq_add(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -525,7 +537,9 @@ extern "C" {
                        llama_pos   delta);
 
     // Integer division of the positions by factor of `d > 1`
-    // If the KV cache is RoPEd, the KV data is updated accordingly
+    // If the KV cache is RoPEd, the KV data is updated accordingly:
+    //   - lazily on next llama_decode()
+    //   - explicitly with llama_kv_cache_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_div(
@@ -535,6 +549,20 @@ extern "C" {
                        llama_pos   p1,
                              int   d);
 
+    // Returns the largest position present in the KV cache for the specified sequence
+    LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
+
+    // Defragment the KV cache
+    // This will be applied:
+    //   - lazily on next llama_decode()
+    //   - explicitly with llama_kv_cache_update()
+    LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
+
+    // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+    LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
+
     //
     // State / sessions
     //