git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : refactor llama_context, llama_kv_cache, llm_build_context (#12181)
author Georgi Gerganov <redacted>
Thu, 13 Mar 2025 10:35:44 +0000 (12:35 +0200)
committer GitHub <redacted>
Thu, 13 Mar 2025 10:35:44 +0000 (12:35 +0200)
* llama : refactor llama_context, llama_kv_cache, llm_build_context

ggml-ci

* graph : don't mutate the KV cache during defrag

ggml-ci

* context : reduce virtuals + remove test function

ggml-ci

* context : move interface implementation to source file + factory

ggml-ci

* graph : move KV cache build functions to llama_context impl

ggml-ci

* graph : remove model reference from build_pooling

ggml-ci

* graph : remove llama_model reference

ggml-ci

* kv_cache : provide rope factors

ggml-ci

* graph : rework inputs to use only unique_ptr, remove attn input abstraction

ggml-ci

* context : remove llama_context_i abstraction

ggml-ci

* context : clean-up

ggml-ci

* graph : clean-up

ggml-ci

* llama : remove redundant keywords (struct, enum)

ggml-ci

* model : adapt gemma3

ggml-ci

* graph : restore same attention ops as on master

ggml-ci

* llama : remove TODO + fix indent

ggml-ci

46 files changed:
common/common.cpp
common/speculative.cpp
examples/batched-bench/batched-bench.cpp
examples/batched.swift/Sources/main.swift
examples/cvector-generator/cvector-generator.cpp
examples/embedding/embedding.cpp
examples/gritlm/gritlm.cpp
examples/imatrix/imatrix.cpp
examples/infill/infill.cpp
examples/llama-bench/llama-bench.cpp
examples/llama.android/llama/src/main/cpp/llama-android.cpp
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
examples/llava/gemma3-cli.cpp
examples/lookahead/lookahead.cpp
examples/lookup/lookup.cpp
examples/main/main.cpp
examples/parallel/parallel.cpp
examples/passkey/passkey.cpp
examples/perplexity/perplexity.cpp
examples/quantize-stats/quantize-stats.cpp
examples/retrieval/retrieval.cpp
examples/run/run.cpp
examples/save-load-state/save-load-state.cpp
examples/server/server.cpp
examples/server/tests/utils.py
examples/simple-chat/simple-chat.cpp
examples/speculative-simple/speculative-simple.cpp
examples/speculative/speculative.cpp
include/llama.h
src/CMakeLists.txt
src/llama-adapter.cpp
src/llama-adapter.h
src/llama-batch.h
src/llama-context.cpp
src/llama-context.h
src/llama-graph.cpp [new file with mode: 0644]
src/llama-graph.h [new file with mode: 0644]
src/llama-io.cpp [new file with mode: 0644]
src/llama-io.h [new file with mode: 0644]
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-memory.cpp [new file with mode: 0644]
src/llama-memory.h [new file with mode: 0644]
src/llama-model.cpp
src/llama-model.h
src/llama.cpp

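The core API change in this commit is the rename of the per-context KV cache functions from llama_kv_cache_* to llama_kv_self_* (the old names are kept as deprecated wrappers, see the include/llama.h hunk below). A minimal migration sketch in C++, mirroring the context-shift pattern used by several of the examples touched here; the helper names and bookkeeping variables are illustrative and not part of the commit:

    // Illustrative only: a context-shift helper migrated to the llama_kv_self_* names.
    // Assumes a valid llama_context * ctx and the usual n_keep / n_discard / n_past bookkeeping.
    #include "llama.h"

    static void shift_context(llama_context * ctx, int n_keep, int n_discard, int & n_past) {
        // previously: llama_kv_cache_seq_rm / llama_kv_cache_seq_add
        llama_kv_self_seq_rm (ctx, 0, n_keep,             n_keep + n_discard);
        llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_past, -n_discard);

        n_past -= n_discard;
    }

    static void reset_context(llama_context * ctx) {
        // previously: llama_kv_cache_clear
        llama_kv_self_clear(ctx);
    }
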
index 6448b7b03d6d2614b3f7da8b5e7b0cf1a6ab0879..8487e3834bccb474c2961da607ed35101a6c245e 100644 (file)
@@ -955,8 +955,8 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }
 
-    if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
-        LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
+    if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
+        LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
         params.ctx_shift = false;
     }
 
@@ -1060,7 +1060,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         if (llama_model_has_decoder(model)) {
             llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
         }
-        llama_kv_cache_clear(lctx);
+        llama_kv_self_clear(lctx);
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
     }
index 1bac3a1ce101e452d8fefca3f890eee13ed54398..ccad70fa9ed85728ec6a7633c41f04e5e5999807 100644 (file)
@@ -173,7 +173,7 @@ llama_tokens common_speculative_gen_draft(
     result.reserve(params.n_draft);
 
     if (reuse_n == 0) {
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         prompt.clear();
     } else {
@@ -192,14 +192,14 @@ llama_tokens common_speculative_gen_draft(
         }
 
         if (reuse_i > 0) {
-            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
-            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+            llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
 
             prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
         }
 
         if (reuse_n < (int) prompt.size()) {
-            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+            llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
 
             prompt.erase(prompt.begin() + reuse_n, prompt.end());
         }
index 0659ab6f119a70921b4c77d2073b94b641e04fa9..430e8be51265304717ce701e60dbb42193cac3cc 100644 (file)
@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
 
                 const auto t_pp_start = ggml_time_us();
 
-                llama_kv_cache_clear(ctx);
+                llama_kv_self_clear(ctx);
 
                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
                     LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +141,7 @@ int main(int argc, char ** argv) {
 
                 if (is_pp_shared) {
                     for (int32_t i = 1; i < pl; ++i) {
-                        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                        llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
                     }
                 }
 
index 55c31166ca278f1e049f74db071259c1e7c12c85..514989e340e2c11905668241d117a6d411563c67 100644 (file)
@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
 }
 
 for i in 1 ..< n_parallel {
-    llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
+    llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
 }
 
 if n_parallel > 1 {
index c72528dac3ff009c50c2dedbf39140f3fe87b28c..2a907155010cbd545be8a836b53177fa6fcc0786 100644 (file)
@@ -342,7 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 }
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;
index 3dd9f2b07d177e44afe6a5afbb607c5c2af7f0bc..6f08904159fd50a1045eadb78f9048624119096a 100644 (file)
@@ -38,7 +38,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
index 72eb46257429e63a8cfdf4e1be5bf7b4285d0ded..f7db7861c1ad58f7c077f2a808b230a0cf78e437 100644 (file)
@@ -45,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         }
 
         // clear previous kv_cache values (irrelevant for embeddings)
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
         llama_set_embeddings(ctx, true);
         llama_set_causal_attn(ctx, false);
 
@@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
 
     llama_token eos_token = llama_vocab_eos(vocab);
 
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 
index 91649c45065f470755ba2833becd240dda7b41f5..31b675e8f90b9851d00f1a20e0d479cb3a18fe9c 100644 (file)
@@ -495,7 +495,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
index 489a208b66b344c13e449d14b8e34c8dfa038c9e..4e2f7b727000331dc51fc3121770b79ba8454c8a 100644 (file)
@@ -332,8 +332,8 @@ int main(int argc, char ** argv) {
                 LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                     n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
                 n_past -= n_discard;
 
index f518d02d3868969d7c4422f43770c26659a35f37..cbcbfcee861ee924df440a19bcd79dcf48ecdb3e 100644 (file)
@@ -1578,7 +1578,7 @@ int main(int argc, char ** argv) {
 
         test t(inst, lmodel, ctx);
 
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         // cool off before the test
         if (params.delay) {
@@ -1618,7 +1618,7 @@ int main(int argc, char ** argv) {
         }
 
         for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_clear(ctx);
+            llama_kv_self_clear(ctx);
 
             uint64_t t_start = get_time_ns();
 
index 0de61ce77c4fa56d6178fef0d35cb149c96928e9..9654cd53cf8d5cecb7fef4774a1f023910dd0e98 100644 (file)
@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         }
 
         batch->logits[batch->n_tokens - 1] = true;
-        llama_kv_cache_clear(context);
+        llama_kv_self_clear(context);
 
         const auto t_pp_start = ggml_time_us();
         if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
         LOGi("Benchmark text generation (tg)");
 
-        llama_kv_cache_clear(context);
+        llama_kv_self_clear(context);
         const auto t_tg_start = ggml_time_us();
         for (i = 0; i < tg; i++) {
 
@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
         const auto t_tg_end = ggml_time_us();
 
-        llama_kv_cache_clear(context);
+        llama_kv_self_clear(context);
 
         const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
         const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
-    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
+    llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
 }
index ee7141a66322437a0ea95d2fa6273ab48f781cab..f6e31abc93c0925b56d32b0bb9fd51385f55e323 100644 (file)
@@ -210,7 +210,7 @@ actor LlamaContext {
             }
             batch.logits[Int(batch.n_tokens) - 1] = 1 // true
 
-            llama_kv_cache_clear(context)
+            llama_kv_self_clear(context)
 
             let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
@@ -223,7 +223,7 @@ actor LlamaContext {
 
             // bench text generation
 
-            llama_kv_cache_clear(context)
+            llama_kv_self_clear(context)
 
             let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
@@ -242,7 +242,7 @@ actor LlamaContext {
 
             let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
 
-            llama_kv_cache_clear(context)
+            llama_kv_self_clear(context)
 
             let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
             let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -292,7 +292,7 @@ actor LlamaContext {
     func clear() {
         tokens_list.removeAll()
         temporary_invalid_cchars.removeAll()
-        llama_kv_cache_clear(context)
+        llama_kv_self_clear(context)
     }
 
     private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
index a07864d4e59f6e1fa99ccd31340de90d9be3f5f1..c36bb2eda0c700732e117150d9831f32b9f2f560 100644 (file)
@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
             }
             if (line == "/clear") {
                 ctx.n_past = 0;
-                llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
+                llama_kv_self_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
                 LOG("Chat history cleared\n\n");
                 continue;
             }
index b9e8de694c0e8df8901ce7058d74b16a95bcac78..7df20aee170466d4d474846b1a0ed20580a42564 100644 (file)
@@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
     llama_decode(ctx, llama_batch_get_one(&inp.back(),           1));
 
     for (int s = 1; s < W + G + 1; ++s) {
-        llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+        llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
     }
 
     const auto t_enc_end = ggml_time_us();
@@ -438,17 +438,17 @@ int main(int argc, char ** argv) {
 
         // KV cache management
         // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
-        llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
+        llama_kv_self_seq_rm(ctx, -1, n_past, -1);
 
         if (seq_id_best != 0) {
             // if a verification token matched, we keep the best sequence and remove the rest
             // this leads to some KV cache fragmentation
-            llama_kv_cache_seq_keep(ctx, seq_id_best);
-            llama_kv_cache_seq_cp  (ctx, seq_id_best, 0, -1, -1);
-            llama_kv_cache_seq_rm  (ctx, seq_id_best,    -1, -1);
+            llama_kv_self_seq_keep(ctx, seq_id_best);
+            llama_kv_self_seq_cp  (ctx, seq_id_best, 0, -1, -1);
+            llama_kv_self_seq_rm  (ctx, seq_id_best,    -1, -1);
 
             for (int s = 1; s < W + G + 1; ++s) {
-                llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+                llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
             }
         }
     }
index dbd0444ec874292516dbbd932dc0ddee435e2230..4ae93b2a5ed15c08625f66626b38b71113fce383 100644 (file)
@@ -192,7 +192,7 @@ int main(int argc, char ** argv){
 
         // KV cache management
         // clean the cache of draft tokens that weren't accepted
-        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+        llama_kv_self_seq_rm(ctx, 0, n_past, -1);
 
         common_batch_clear(batch_tgt);
         common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
index 4e0c69473badb5cd72dbc4df769e3f722176b42e..fd7410a646c69730a8f1bdb963bd47803b869e13 100644 (file)
@@ -354,7 +354,7 @@ int main(int argc, char ** argv) {
         }
 
         // remove any "future" tokens that we might have inherited from the previous session
-        llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
+        llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
     }
 
     LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -602,8 +602,8 @@ int main(int argc, char ** argv) {
                     LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                             n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_kv_self_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
 
@@ -626,9 +626,9 @@ int main(int argc, char ** argv) {
                     LOG_DBG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
                     LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-                    llama_kv_cache_seq_add(ctx, 0, ga_i,                n_past,              ib*bd);
-                    llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
-                    llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
+                    llama_kv_self_seq_add(ctx, 0, ga_i,                n_past,              ib*bd);
+                    llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                    llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
 
                     n_past -= bd;
 
index be18909ed4f97e9529143a35947c9bd091ef9a6a..588632f0432b29a0ff49d379e5c59ed33544929f 100644 (file)
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
 
         // assign the system KV cache to all parallel sequences
         for (int32_t i = 1; i <= n_clients; ++i) {
-            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+            llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
         }
 
         LOG_INF("\n");
@@ -234,9 +234,9 @@ int main(int argc, char ** argv) {
         if (batch.n_tokens == 0) {
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
-                llama_kv_cache_seq_rm(ctx, i, -1, -1);
+                llama_kv_self_seq_rm(ctx, i, -1, -1);
                 // but keep the system prompt
-                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
             }
 
             LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -372,8 +372,8 @@ int main(int argc, char ** argv) {
                     }
 
                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                    llama_kv_cache_seq_rm(ctx,    client.id + 1, -1, -1);
-                    llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
+                    llama_kv_self_seq_rm(ctx,    client.id + 1, -1, -1);
+                    llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
                     const auto t_main_end = ggml_time_us();
 
index fa85190518ef573ee28bd2df0935134aeaae3876..ea3a6c1fca3eec46c2abd2d95aaa98a8f12d0ae2 100644 (file)
@@ -133,11 +133,11 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);
 
-            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
-            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
-            llama_kv_cache_update  (ctx);
+            llama_kv_self_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_self_update  (ctx);
 
-            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+            n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
 
         common_batch_clear(batch);
@@ -167,12 +167,12 @@ int main(int argc, char ** argv) {
 
         LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
 
-        llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
-        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-      //llama_kv_cache_defrag (ctx);
-        llama_kv_cache_update (ctx);
+        llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+        llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+      //llama_kv_self_defrag (ctx);
+        llama_kv_self_update (ctx);
 
-        n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+        n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
 
         common_batch_clear(batch);
 
@@ -198,12 +198,12 @@ int main(int argc, char ** argv) {
         if (n_discard > 0) {
             LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
 
-            llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-          //llama_kv_cache_defrag (ctx);
-            llama_kv_cache_update (ctx);
+            llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+            llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+          //llama_kv_self_defrag (ctx);
+            llama_kv_self_update (ctx);
 
-            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+            n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
     }
 
index 5d07421e827d18e95805a00b9328786757ad8097..8c413f7d66e6d1c3d7c224fb1c17ecc88646aebb 100644 (file)
@@ -361,7 +361,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
@@ -547,7 +547,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -924,7 +924,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             return;
         }
 
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1203,7 +1203,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
             return;
         }
 
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1575,7 +1575,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             return;
         }
 
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
         if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1765,7 +1765,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         }
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
 
         llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
index bd2f734670de8cc9338c4f8b08e946610ba2e839..dd07ab9b374560107cbf3dbd88228b4e65c61237 100644 (file)
@@ -1,6 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-context.h"
+#include "llama-model.h"
 #include "common.h"
 
 #include <algorithm>
@@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    const auto & tensors = llama_internal_get_tensor_map(ctx);
+    const auto & tensors = llama_internal_get_tensor_map(model);
 
     // check layer tensors
     int included_layers = 0;
index 2439022a229b7d3dd86c8fd22c5f50633f9fff78..0efe20d4b3f5d0bd2e63a005984c25c60e4cbd74 100644 (file)
@@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
 
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
index 38407d5190923c193e4657b857ccae4238da4dac..437f2533e5777116029a6cf630d87006c4cfa298 100644 (file)
@@ -891,7 +891,7 @@ static int apply_chat_template(const struct common_chat_templates * tmpls, Llama
 // Function to tokenize the prompt
 static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
                            std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
-    const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
+    const bool is_first = llama_kv_self_used_cells(llama_data.context.get()) == 0;
 
     const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
     prompt_tokens.resize(n_prompt_tokens);
@@ -907,7 +907,7 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
 // Check if we have enough space in the context to evaluate this batch
 static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
     const int n_ctx      = llama_n_ctx(ctx.get());
-    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
+    const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
     if (n_ctx_used + batch.n_tokens > n_ctx) {
         printf(LOG_COL_DEFAULT "\n");
         printe("context size exceeded\n");
index cf7cbd8159cf85d3117792bfd8c6c737a5ccf521..760ebbbf08788d350f547a9a185188bb487393f1 100644 (file)
@@ -15,7 +15,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    print_build_info();
+    common_init();
 
     if (params.n_predict < 0) {
         params.n_predict = 16;
@@ -196,7 +196,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
 
         // erase whole kv
-        llama_kv_cache_clear(ctx3);
+        llama_kv_self_clear(ctx3);
         fprintf(stderr, "%s : kv cache cleared\n", __func__);
 
         // restore kv into seq 1
index ce0195475d65865f4e6d76786201d353cbfc1dd7..71e053b202cd2fcdd8af5d77a526aaf65a9852e1 100644 (file)
@@ -2113,7 +2113,7 @@ struct server_context {
         SRV_DBG("%s", "clearing KV cache\n");
 
         // clear the entire KV cache
-        llama_kv_cache_clear(ctx);
+        llama_kv_self_clear(ctx);
         clean_kv_cache = false;
     }
 
@@ -2655,8 +2655,8 @@ struct server_context {
                     res->n_tasks_deferred    = queue_tasks.queue_tasks_deferred.size();
                     res->t_start             = metrics.t_start;
 
-                    res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
-                    res->kv_cache_used_cells   = llama_get_kv_cache_used_cells(ctx);
+                    res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
+                    res->kv_cache_used_cells   = llama_kv_self_used_cells(ctx);
 
                     res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
                     res->t_prompt_processing_total       = metrics.t_prompt_processing_total;
@@ -2772,7 +2772,7 @@ struct server_context {
 
                     // Erase token cache
                     const size_t n_erased = slot->cache_tokens.size();
-                    llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
+                    llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
                     slot->cache_tokens.clear();
 
                     auto res = std::make_unique<server_task_result_slot_erase>();
@@ -2840,8 +2840,8 @@ struct server_context {
 
                 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
+                llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
+                llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
 
                 if (slot.params.cache_prompt) {
                     for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -3032,8 +3032,8 @@ struct server_context {
 
                                             const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
 
-                                            llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
-                                            llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
+                                            llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c);
+                                            llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
 
                                             for (size_t i = 0; i < n_match; i++) {
                                                 slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@@ -3071,9 +3071,9 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
+                    if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
-                        llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
+                        llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
 
                         // there is no common part left
                         slot.n_past = 0;
@@ -3313,7 +3313,7 @@ struct server_context {
                 slot.cache_tokens.push_back(id);
                 slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
 
-                llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+                llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
 
                 for (size_t i = 0; i < ids.size(); ++i) {
                     completion_token_output result;
index ec2d8ec55853c36526d9746e8aa41c92934420d8..30aa8660950a14c70ef9bfa7e698abb7d0588d36 100644 (file)
@@ -302,7 +302,7 @@ class ServerPreset:
         server.model_hf_repo = "ggml-org/models"
         server.model_hf_file = "tinyllamas/stories260K.gguf"
         server.model_alias = "tinyllama-2"
-        server.n_ctx = 256
+        server.n_ctx = 512
         server.n_batch = 32
         server.n_slots = 2
         server.n_predict = 64
index c5534cc13e4b462c329f435d602948966c4e7bd4..84f41597372603957295b08093bdd997528478f4 100644 (file)
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
     auto generate = [&](const std::string & prompt) {
         std::string response;
 
-        const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
+        const bool is_first = llama_kv_self_used_cells(ctx) == 0;
 
         // tokenize the prompt
         const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
         while (true) {
             // check if we have enough space in the context to evaluate this batch
             int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+            int n_ctx_used = llama_kv_self_used_cells(ctx);
             if (n_ctx_used + batch.n_tokens > n_ctx) {
                 printf("\033[0m\n");
                 fprintf(stderr, "context size exceeded\n");
index 403ba2dd21914951ad1a4a604273cd8256a97c59..a5d2bc9d09de717ef9b67d51ba2cb379c53a9312 100644 (file)
@@ -217,7 +217,7 @@ int main(int argc, char ** argv) {
         {
             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
 
-            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1);
         }
 
         if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
index c7ccea50dbbd47c8bffcb824d35f975139048741..bfddc67e034fbe42d5cf6da83bf8056422e73de1 100644 (file)
@@ -420,14 +420,14 @@ int main(int argc, char ** argv) {
             {
                 LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
-                llama_kv_cache_seq_keep(ctx_dft, s_keep);
-                llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
-                llama_kv_cache_seq_keep(ctx_dft, 0);
-
-                llama_kv_cache_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
-                llama_kv_cache_seq_keep(ctx_tgt, s_keep);
-                llama_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
-                llama_kv_cache_seq_keep(ctx_tgt, 0);
+                llama_kv_self_seq_keep(ctx_dft, s_keep);
+                llama_kv_self_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
+                llama_kv_self_seq_keep(ctx_dft, 0);
+
+                llama_kv_self_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
+                llama_kv_self_seq_keep(ctx_tgt, s_keep);
+                llama_kv_self_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
+                llama_kv_self_seq_keep(ctx_tgt, 0);
             }
 
             for (int s = 0; s < n_seq_dft; ++s) {
@@ -444,7 +444,7 @@ int main(int argc, char ** argv) {
             common_batch_clear(batch_dft);
             common_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true);
 
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
+            llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
             // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
             llama_decode(ctx_dft, batch_dft);
 
@@ -503,8 +503,8 @@ int main(int argc, char ** argv) {
                     if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
                         LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
 
-                        llama_kv_cache_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
-                        llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
+                        llama_kv_self_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
+                        llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
 
                         // all previous tokens from this branch are now also part of the new branch
                         for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@@ -585,9 +585,9 @@ int main(int argc, char ** argv) {
 
         // evaluate the target model on the drafted tokens
         {
-            llama_kv_cache_seq_keep(ctx_tgt, 0);
+            llama_kv_self_seq_keep(ctx_tgt, 0);
             for (int s = 1; s < n_seq_dft; ++s) {
-                llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
+                llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
             }
 
             // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
index d62792c0a67609ec39e2f29e3e165411c3e9dd8a..e5286f06162ab5a90118b11a6a9ef783326bf632 100644 (file)
@@ -60,6 +60,7 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
+    struct llama_kv_cache;
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -469,7 +470,8 @@ extern "C" {
     DEPRECATED(LLAMA_API int32_t llama_n_vocab    (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
 
     LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
-    LLAMA_API enum llama_pooling_type    llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API    struct llama_kv_cache * llama_get_kv_self (      struct llama_context * ctx);
+    LLAMA_API  enum llama_pooling_type   llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
@@ -586,7 +588,7 @@ extern "C" {
     // KV cache
     //
 
-    // TODO: remove llama_kv_cache_view_* API
+    // TODO: start using struct llama_kv_cache
 
     // Information associated with an individual cell in the KV cache view.
     struct llama_kv_cache_view_cell {
@@ -641,13 +643,19 @@ extern "C" {
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
+    LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
+            "use llama_kv_self_n_tokens instead");
 
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
+    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
+            "use llama_kv_self_used_cells instead");
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
-    LLAMA_API void llama_kv_cache_clear(
+    LLAMA_API void llama_kv_self_clear(
             struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
@@ -655,7 +663,7 @@ extern "C" {
     // seq_id < 0 : match any sequence
     // p0 < 0     : [0,  p1]
     // p1 < 0     : [p0, inf)
-    LLAMA_API bool llama_kv_cache_seq_rm(
+    LLAMA_API bool llama_kv_self_seq_rm(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -665,7 +673,7 @@ extern "C" {
     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_cp(
+    LLAMA_API void llama_kv_self_seq_cp(
             struct llama_context * ctx,
                     llama_seq_id   seq_id_src,
                     llama_seq_id   seq_id_dst,
@@ -673,17 +681,17 @@ extern "C" {
                        llama_pos   p1);
 
     // Removes all tokens that do not belong to the specified sequence
-    LLAMA_API void llama_kv_cache_seq_keep(
+    LLAMA_API void llama_kv_self_seq_keep(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
+    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_add(
+    LLAMA_API void llama_kv_self_seq_add(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -693,10 +701,10 @@ extern "C" {
     // Integer division of the positions by factor of `d > 1`
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
+    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_div(
+    LLAMA_API void llama_kv_self_seq_div(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -704,24 +712,76 @@ extern "C" {
                              int   d);
 
     // Returns the largest position present in the KV cache for the specified sequence
-    LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+    LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
-                    llama_seq_id   seq_id);
-
-    // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
-    //       how to avoid this?
+                     llama_seq_id   seq_id);
 
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_cache_update()
-    LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
+    //   - explicitly with llama_kv_self_update()
+    LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
+
+    // Check if the context supports KV cache shifting
+    LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
+    LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_clear(
+            struct llama_context * ctx),
+            "use llama_kv_self_clear instead");
+
+    DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+                       llama_pos   p0,
+                       llama_pos   p1),
+            "use llama_kv_self_seq_rm instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id_src,
+                    llama_seq_id   seq_id_dst,
+                       llama_pos   p0,
+                       llama_pos   p1),
+            "use llama_kv_self_seq_cp instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id),
+            "use llama_kv_self_seq_keep instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+                       llama_pos   p0,
+                       llama_pos   p1,
+                       llama_pos   delta),
+            "use llama_kv_self_seq_add instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+                       llama_pos   p0,
+                       llama_pos   p1,
+                             int   d),
+            "use llama_kv_self_seq_div instead");
+
+    DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id),
+            "use llama_kv_self_seq_pos_max instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx),
+            "use llama_kv_self_defrag instead");
+
+    DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx),
+            "use llama_kv_self_can_shift instead");
+
+    DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx),
+            "use llama_kv_self_update instead");
 
-    // Check if the context supports KV cache shifting
-    LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
 
     //
     // State / sessions
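
The deprecated declarations above keep source compatibility for existing callers. Their definitions are not part of this excerpt, but based on the deprecation notes they presumably just forward to the new llama_kv_self_* entry points, along the lines of this hypothetical shim:

    // Hypothetical forwarding shims, for illustration only; the real definitions
    // live in the library sources and are not shown in this excerpt.
    void llama_kv_cache_clear(llama_context * ctx) {
        llama_kv_self_clear(ctx);
    }

    bool llama_kv_cache_seq_rm(llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
        return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
    }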
index e1b02e4c08f07d233b071ccfeffb382cca130ffd..b340dae5b28cdcafc3191a105c56aa3cb23f31d9 100644 (file)
@@ -15,18 +15,21 @@ add_library(llama
             llama-chat.cpp
             llama-context.cpp
             llama-grammar.cpp
+            llama-graph.cpp
             llama-hparams.cpp
             llama-impl.cpp
+            llama-io.cpp
             llama-kv-cache.cpp
+            llama-memory.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model.cpp
             llama-quant.cpp
             llama-sampling.cpp
             llama-vocab.cpp
-            unicode.h
-            unicode.cpp
             unicode-data.cpp
+            unicode.cpp
+            unicode.h
             )
 
 target_include_directories(llama PUBLIC . ../include ../common)
index 8a0800463137ee13d0c90c6a1c1e8918eab209f2..b448614e471d64df3c926d047c77a431aadd60ce 100644 (file)
@@ -4,14 +4,13 @@
 #include "llama-mmap.h"
 #include "llama-model.h"
 
-#include <algorithm>
 #include <map>
 #include <cassert>
 #include <stdexcept>
 
 // vec
 
-struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
+ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
         return nullptr;
     }
@@ -19,7 +18,7 @@ struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
     return tensors[il];
 }
 
-struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int  il) const {
+ggml_tensor * llama_adapter_cvec::apply_to(ggml_context * ctx, ggml_tensor * cur, int  il) const {
     ggml_tensor * layer_dir = tensor_for(il);
     if (layer_dir != nullptr) {
         cur = ggml_add(ctx, cur, layer_dir);
@@ -40,7 +39,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size   =*/ hparams.n_layer*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
@@ -91,7 +90,7 @@ bool llama_adapter_cvec::init(const llama_model & model) {
     return true;
 }
 
-int32_t llama_adapter_cvec::apply(
+bool llama_adapter_cvec::apply(
         const llama_model & model,
         const float * data,
         size_t len,
@@ -104,17 +103,17 @@ int32_t llama_adapter_cvec::apply(
         // disable the current control vector (but leave allocated for later)
         layer_start = -1;
         layer_end   = -1;
-        return 0;
+        return true;
     }
 
     if (n_embd != (int) hparams.n_embd) {
         LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return 1;
+        return false;
     }
 
     if (tensors.empty()) {
         if (!init(model)) {
-            return 1;
+            return false;
         }
     }
 
@@ -130,12 +129,12 @@ int32_t llama_adapter_cvec::apply(
         }
     }
 
-    return 0;
+    return true;
 }
 
 // lora
 
-llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
+llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
     const std::string name(w->name);
 
     const auto pos = ab_map.find(name);
@@ -146,11 +145,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor *
     return nullptr;
 }
 
-static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
+static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
     ggml_context * ctx_init;
-    struct gguf_init_params meta_gguf_params = {
+    gguf_init_params meta_gguf_params = {
         /* .no_alloc = */ true,
         /* .ctx      = */ &ctx_init,
     };
@@ -201,7 +200,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             // add a new context
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
@@ -264,7 +263,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
             throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
         }
 
-        struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
+        ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
         // validate tensor shape
         if (is_token_embd) {
             // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
@@ -281,8 +280,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
         }
 
         // save tensor to adapter
-        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
-        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
+        ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
+        ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
         ggml_set_name(tensor_a, w.a->name);
         ggml_set_name(tensor_b, w.b->name);
         adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
@@ -308,7 +307,7 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
     {
         llama_file gguf_file(path_lora, "rb");
         std::vector<uint8_t> read_buf;
-        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
+        auto set_tensor = [&](ggml_tensor * orig, ggml_tensor * dev) {
             size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name));
             size_t size = ggml_nbytes(orig);
             read_buf.resize(size);
@@ -327,8 +326,8 @@ static void llama_adapter_lora_init_impl(struct llama_model & model, const char
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
 
-struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
-    struct llama_adapter_lora * adapter = new llama_adapter_lora();
+llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
+    llama_adapter_lora * adapter = new llama_adapter_lora();
 
     try {
         llama_adapter_lora_init_impl(*model, path_lora, *adapter);
@@ -342,6 +341,6 @@ struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model,
     return nullptr;
 }
 
-void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
+void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
index 603fa08f6d1867058ed6b788a98f6e2d8dd95aaa..65824e972765bd5b8b845e0c41ec46dc81657a4f 100644 (file)
 //
 
 struct llama_adapter_cvec {
-    struct ggml_tensor * tensor_for(int il) const;
+    ggml_tensor * tensor_for(int il) const;
 
-    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int  il) const;
+    ggml_tensor * apply_to(ggml_context * ctx, ggml_tensor * cur, int  il) const;
 
-    int32_t apply(
+    bool apply(
             const llama_model & model,
             const float * data,
             size_t len,
@@ -36,7 +36,7 @@ private:
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    std::vector<struct ggml_tensor *> tensors; // per layer
+    std::vector<ggml_tensor *> tensors; // per layer
 };
 
 //
@@ -44,8 +44,8 @@ private:
 //
 
 struct llama_adapter_lora_weight {
-    struct ggml_tensor * a = nullptr;
-    struct ggml_tensor * b = nullptr;
+    ggml_tensor * a = nullptr;
+    ggml_tensor * b = nullptr;
 
     // get actual scale based on rank and alpha
     float get_scale(float alpha, float adapter_scale) const {
@@ -55,12 +55,12 @@ struct llama_adapter_lora_weight {
     }
 
     llama_adapter_lora_weight() = default;
-    llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
+    llama_adapter_lora_weight(ggml_tensor * a, ggml_tensor * b) : a(a), b(b) {}
 };
 
 struct llama_adapter_lora {
     // map tensor name to lora_a_b
-    std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
+    std::unordered_map<std::string, llama_adapter_lora_weight> ab_map;
 
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
@@ -70,5 +70,7 @@ struct llama_adapter_lora {
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
 
-    llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
+    llama_adapter_lora_weight * get_weight(ggml_tensor * w);
 };
+
+using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
index 773c3808b770fed6d4d72e38b457deebacab07d7..f1df40d27086e7a47c56dd43bfabfdaf3fa0688b 100644 (file)
@@ -42,9 +42,9 @@ struct llama_sbatch {
     bool logits_all; // TODO: remove once lctx.logits_all is removed too
 
     // sorted indices into the batch
-    std::vector<size_t> ids;
+    std::vector<int64_t> ids;
     // batch indices of the output
-    std::vector<size_t> out_ids;
+    std::vector<int64_t> out_ids;
     std::vector<llama_sbatch_seq> seq;
 
     const llama_batch * batch = nullptr;
index 671d2a81adabf77ea2a46ebee009191959530c82..0a43a3af8e0035a73bee0972a99d5d24e4b7c314 100644 (file)
 #include "llama-context.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-mmap.h"
+#include "llama-model.h"
+#include "llama-kv-cache.h"
 
 #include <cassert>
-#include <cmath>
 #include <cstring>
 #include <stdexcept>
+#include <cinttypes>
 
-void llama_set_k_shift(struct llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
+//
+// llama_context
+//
 
-    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+llama_context::llama_context(
+        const llama_model & model,
+              llama_context_params params) :
+    model(model) {
+    LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
-    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+    t_start_us = model.t_start_us;
+    t_load_us  = model.t_load_us;
 
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].delta;
-    }
-}
+    const auto & hparams = model.hparams;
 
-void llama_set_s_copy(struct llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
+    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
+    cparams.n_threads        = params.n_threads;
+    cparams.n_threads_batch  = params.n_threads_batch;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.defrag_thold     = params.defrag_thold;
+    cparams.embeddings       = params.embeddings;
+    cparams.offload_kqv      = params.offload_kqv;
+    cparams.flash_attn       = params.flash_attn;
+    cparams.no_perf          = params.no_perf;
+    cparams.pooling_type     = params.pooling_type;
 
-    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
+    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
+    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
-    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
+                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
+                                                              hparams.n_ctx_train;
 
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].src;
-    }
-}
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
 
-// llama input
+    auto rope_scaling_type = params.rope_scaling_type;
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
+        rope_scaling_type = hparams.rope_scaling_type_train;
+    }
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
+        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
+    }
 
-    if (bidirectional) {
-        n_buckets >>= 1;
+    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
+        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
-    const int64_t max_exact = n_buckets >> 1;
+    cparams.yarn_attn_factor *= hparams.rope_attn_factor;
 
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        } else {
+            cparams.pooling_type = hparams.pooling_type;
+        }
+    }
+
+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
     } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-    return relative_bucket;
-}
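
For reference, the removed helper maps a (key position, query position) pair to a T5-style bucket: small distances get their own bucket, larger distances share logarithmically spaced buckets capped at n_buckets - 1. A minimal standalone sketch that mirrors its behavior (illustrative only; the function name and test values here are assumptions, not part of the patch):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// T5-style relative position bucket: exact buckets for small distances,
// log-spaced buckets for larger ones, capped at n_buckets - 1
static int32_t relative_position_bucket(int32_t x, int32_t y, int64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128; // same constant as the removed helper

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket   = 0;

    if (bidirectional) {
        relative_bucket  += (relative_position > 0) * n_buckets;
        relative_position = std::abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    if (relative_position < max_exact) {
        return relative_bucket + relative_position;
    }

    const int32_t large = (int32_t) floorf(max_exact + logf(1.0f*relative_position/max_exact)*(n_buckets - max_exact)/logf(1.0f*max_distance/max_exact));

    return relative_bucket + std::min<int32_t>(large, (int32_t) n_buckets - 1);
}

int main() {
    for (int32_t d : {0, 1, 15, 16, 64, 200}) {
        printf("distance %4d -> bucket %d\n", d, relative_position_bucket(0, d, 32, false));
    }
}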
 
-void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
-    //
-    // set input data
-    //
+    // with causal attention, the batch size is limited by the context size
+    cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
 
-    const auto & hparams = lctx.model.hparams;
-    const auto & cparams = lctx.cparams;
-    const auto & kv_self = lctx.kv_self;
+    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_context_kv_self
+    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+        cparams.n_batch = GGML_KQ_MASK_PAD;
+    }
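
The clamp above enforces the invariant stated in the comment: the KQ mask is padded to a multiple of GGML_KQ_MASK_PAD, so the batch must be at least that large. A small sketch of the round-up involved (the pad value and helper below are stand-ins for the ggml macros, not their real definitions):

#include <algorithm>
#include <cstdint>
#include <cstdio>

static constexpr uint32_t KQ_MASK_PAD = 32; // illustrative stand-in for GGML_KQ_MASK_PAD

// round x up to the next multiple of n, as GGML_PAD does
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1)/n)*n;
}

int main() {
    uint32_t n_batch = 7;                      // a requested batch smaller than the pad
    n_batch = std::max(n_batch, KQ_MASK_PAD);  // the clamp performed by the constructor

    printf("n_batch = %u, padded mask rows = %u\n", n_batch, pad_to(n_batch, KQ_MASK_PAD));
}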
 
-    if (ubatch.token) {
-        const int64_t n_tokens = ubatch.n_tokens;
+    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-        ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
-    }
+    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
-    if (ubatch.embd) {
-        const int64_t n_embd   = hparams.n_embd;
-        const int64_t n_tokens = ubatch.n_tokens;
+    LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
+    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
+    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
+    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
-        ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+    if (n_ctx_per_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    if (ubatch.pos && lctx.inp_pos) {
-        const int64_t n_tokens = ubatch.n_tokens;
-        auto n_pos = lctx.n_pos_per_token;
-        ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos));
+    if (n_ctx_per_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        //GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
-
-        if (!lctx.inp_out_ids) {
-            LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__);
-        } else {
-            const int64_t n_tokens = ubatch.n_tokens;
+    logits_all = params.logits_all;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
-            int32_t * data = (int32_t *) lctx.inp_out_ids->data;
+    if (!hparams.vocab_only) {
+        // GPU backends
+        for (auto * dev : model.devices) {
+            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+            if (backend == nullptr) {
+                throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
+            }
+            backends.emplace_back(backend);
+        }
 
-            if (lctx.n_outputs == n_tokens) {
-                for (int i = 0; i < n_tokens; ++i) {
-                    data[i] = i;
+        // add ACCEL backends (such as BLAS)
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+                if (backend == nullptr) {
+                    throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev)));
                 }
-            } else if (ubatch.output) {
-                int32_t n_outputs = 0;
-                for (int i = 0; i < n_tokens; ++i) {
-                    if (ubatch.output[i]) {
-                        data[n_outputs++] = i;
-                    }
+                backends.emplace_back(backend);
+            }
+        }
+
+        // add CPU backend
+        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        if (backend_cpu == nullptr) {
+            throw std::runtime_error("failed to initialize CPU backend");
+        }
+        backends.emplace_back(backend_cpu);
+
+        // create a list of the set_n_threads functions in the backends
+        for (auto & backend : backends) {
+            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
+            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+                if (ggml_backend_set_n_threads_fn) {
+                    set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
                 }
-                // the graph needs to have been passed the correct number of outputs
-                GGML_ASSERT(lctx.n_outputs == n_outputs);
-            } else if (lctx.n_outputs == 1) {
-                // only keep last output
-                data[0] = n_tokens - 1;
-            } else {
-                GGML_ASSERT(lctx.n_outputs == 0);
             }
         }
-    }
 
-    GGML_ASSERT(
-        // (!a || b) is a logical implication (a -> b)
-        // !hparams.causal_attn -> !cparams.causal_attn
-        (hparams.causal_attn || !cparams.causal_attn) &&
-        "causal attention is not supported by this model"
-    );
+        llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data);
 
-    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn && !lctx.is_encoding) {
-            const int64_t n_kv         = kv_self.n;
-            const int64_t n_tokens     = ubatch.n_tokens;
-            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-            const int64_t n_seqs       = ubatch.n_seqs;
+        // graph outputs buffer
+        {
+            // resized during inference when a batch uses more outputs
+            if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
+                throw std::runtime_error("failed to reserve initial output buffer");
+            }
 
+            LLAMA_LOG_INFO("%s: %10s  output buffer size = %8.2f MiB\n", __func__,
+                    ggml_backend_buffer_name    (buf_output.get()),
+                    ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0);
+        }
+    }
 
-            float * data     = nullptr;
-            float * data_swa = nullptr;
+    // init the memory module
+    // TODO: for now, always create a unified KV cache
+    if (!hparams.vocab_only) {
+        kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
 
-            if (lctx.inp_KQ_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-                data = (float *) lctx.inp_KQ_mask->data;
-            }
+        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
 
-            if (lctx.inp_KQ_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
-                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
-            }
+        cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
 
-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the ubatch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch.seq_id[s][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
-
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
-                                f = -INFINITY;
-                            } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self.cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
-                            }
+        LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
+        uint32_t kv_size = cparams.n_ctx;
+        ggml_type type_k = params.type_k;
+        ggml_type type_v = params.type_v;
 
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
-                                if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
-                        }
-                    }
-                }
+        if (llama_model_is_recurrent(&model)) {
+            // Mamba needs at least as many KV cells as there are sequences kept at any time
+            kv_size = std::max((uint32_t) 1, params.n_seq_max);
+            // it's probably best to keep as much precision as possible for the states
+            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+        }
 
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
-                    }
-                }
+        GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
+        GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
 
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens     = ubatch.n_tokens;
-            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-            const int64_t n_seqs       = ubatch.n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-
-            float * data = (float *) lctx.inp_KQ_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch.seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
-                                    if (ubatch.seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
+        if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
+            throw std::runtime_error("failed to initialize self-attention cache");
+        }
 
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
-                    }
-                }
-            }
+        {
+            const size_t memory_size_k = kv_self->size_k_bytes();
+            const size_t memory_size_v = kv_self->size_v_bytes();
+
+            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                    (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                    ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                    ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
     }
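
The removed loop above fills the additive attention mask consumed by the KQ product: 0.0f where a query token may attend to a cache cell of the same sequence at an earlier or equal position, -INFINITY elsewhere (with an ALiBi bias or sliding-window cut-off when those are enabled). A simplified standalone sketch of that layout, restricted to one sequence and no ALiBi/SWA (illustrative values only):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_kv     = 8; // cells currently in use in the KV cache
    const int n_tokens = 3; // tokens in the current ubatch

    const std::vector<int> kv_pos  = {0, 1, 2, 3, 4, 5, 6, 7}; // positions stored in the cache
    const std::vector<int> tok_pos = {5, 6, 7};                // positions of the new tokens

    std::vector<float> mask(n_tokens*n_kv, -INFINITY);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            if (kv_pos[i] <= tok_pos[j]) {
                mask[j*n_kv + i] = 0.0f; // attend only to past (and current) positions
            }
        }
    }

    // print "." for attendable cells and "x" for masked ones
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            printf("%s ", mask[j*n_kv + i] == 0.0f ? "." : "x");
        }
        printf("\n");
    }
}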
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-        const int64_t n_tokens     = ubatch.n_tokens;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_seqs       = ubatch.n_seqs;
+    // init backends
+    if (!hparams.vocab_only) {
+        LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__);
 
-        GGML_ASSERT(lctx.inp_mean);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
+        backend_buft.clear();
+        backend_ptrs.clear();
 
-        float * data = (float *) lctx.inp_mean->data;
-        memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
+        for (auto & backend : backends) {
+            auto * buft = ggml_backend_get_default_buffer_type(backend.get());
+            auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
 
-        std::vector<uint64_t> sum(n_tokens, 0);
+            if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) {
+                // for the CPU backend, use the host buffer of the first device for faster transfer of the intermediate state

+                auto * dev = model.devices[0];
+                auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+                if (host_buft) {
+                    buft = host_buft;
+                }
+            }
 
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+            backend_buft.push_back(buft);
+            backend_ptrs.push_back(backend.get());
+        }
 
-            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+        LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
 
-            sum[seq_id] += ubatch.n_seq_tokens;
-        }
+        const size_t max_nodes = this->graph_max_nodes();
 
-        std::vector<float> div(n_tokens, 0.0f);
-        for (int i = 0; i < n_tokens; ++i) {
-            const uint64_t s = sum[i];
-            if (s > 0) {
-                div[i] = 1.0f/float(s);
+        LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
+
+        // buffer used to store the computation graph and the tensor meta data
+        buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+        // TODO: move these checks to ggml_backend_sched
+        // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
+        bool pipeline_parallel =
+            model.n_devices() > 1 &&
+            model.params.n_gpu_layers > (int) model.hparams.n_layer &&
+            model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
+            cparams.offload_kqv;
+
+        // pipeline parallelism requires support for async compute and events in all devices
+        if (pipeline_parallel) {
+            for (auto & backend : backends) {
+                auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+                if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
+                    // ignore CPU backend
+                    continue;
+                }
+                auto * dev = ggml_backend_get_device(backend.get());
+                ggml_backend_dev_props props;
+                ggml_backend_dev_get_props(dev, &props);
+                if (!props.caps.async || !props.caps.events) {
+                    // device does not support async compute or events
+                    pipeline_parallel = false;
+                    break;
+                }
             }
         }
 
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+        sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
 
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
-            }
+        if (pipeline_parallel) {
+            LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
         }
     }
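
For mean pooling, the removed code above fills inp_mean with an n_tokens x n_tokens matrix whose row for each sequence holds 1/len at the columns of that sequence's tokens, so that a single matrix multiplication can average the hidden states per sequence. A standalone sketch of that matrix (shapes and names are illustrative, not part of the patch):

#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 4;
    const std::vector<int> seq_id = {0, 0, 1, 1}; // sequence of each token in the batch

    // count the tokens belonging to each sequence
    std::vector<int> count(n_tokens, 0);
    for (int i = 0; i < n_tokens; ++i) {
        count[seq_id[i]]++;
    }

    // row seq_id selects-and-averages that sequence's tokens: pooled = mean * embd
    std::vector<float> mean(n_tokens*n_tokens, 0.0f);
    for (int i = 0; i < n_tokens; ++i) {
        mean[seq_id[i]*n_tokens + i] = 1.0f/count[seq_id[i]];
    }

    for (int r = 0; r < n_tokens; ++r) {
        for (int c = 0; c < n_tokens; ++c) {
            printf("%.2f ", mean[r*n_tokens + c]);
        }
        printf("\n");
    }
}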
 
-    if (cparams.embeddings && (
-                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
-        const int64_t n_tokens     = ubatch.n_tokens;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_seqs       = ubatch.n_seqs;
+    // reserve worst-case graph
+    if (!hparams.vocab_only) {
+        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        GGML_ASSERT(lctx.inp_cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between the token and embedding input graphs
 
-        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+        // max number of outputs
+        n_outputs = n_tokens;
 
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+        LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
-            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+        int n_splits_pp = -1;
+        int n_nodes_pp  = -1;
 
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
+        int n_splits_tg = -1;
+        int n_nodes_tg  = -1;
 
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
-                }
+        // simulate full KV cache
+        kv_self->n = kv_self->size;
+
+        cross.v_embd.clear();
+
+        // reserve pp graph first so that buffers are only allocated once
+        {
+            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            auto * gf = graph_init();
+            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+                throw std::runtime_error("failed to allocate compute pp buffers");
             }
+
+            n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
+            n_nodes_pp  = ggml_graph_n_nodes(gf);
+        }
+
+        // reserve with tg graph to get the number of splits and nodes
+        {
+            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            auto * gf = graph_init();
+            graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
+            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+                throw std::runtime_error("failed to allocate compute tg buffers");
+            }
+            n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
+            n_nodes_tg  = ggml_graph_n_nodes(gf);
         }
-    }
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens     = ubatch.n_tokens;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_seqs       = ubatch.n_seqs;
+        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
+        {
+            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            auto * gf = graph_init();
+            graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
+            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+                throw std::runtime_error("failed to allocate compute pp buffers");
+            }
+        }
 
-        GGML_ASSERT(lctx.inp_cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+            ggml_backend_t             backend = backend_ptrs[i];
+            ggml_backend_buffer_type_t buft    = backend_buft[i];
+            size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+            if (size > 1) {
+                LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                        ggml_backend_buft_name(buft),
+                        size / 1024.0 / 1024.0);
+            }
+        }
 
-        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+        if (n_nodes_pp == n_nodes_tg) {
+            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
+        } else {
+            LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
+        }
 
-        std::vector<int> last_pos(n_tokens, -1);
-        std::vector<int> last_row(n_tokens, -1);
+        if (n_splits_pp == n_splits_tg) {
+            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
+        } else {
+            LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
+        }
+    }
+}
 
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+llama_context::~llama_context() = default;
 
-            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+void llama_context::synchronize() {
+    ggml_backend_sched_synchronize(sched.get());
 
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
+    // FIXME: if multiple single tokens are evaluated without a synchronization,
+    // the stats will be added to the prompt evaluation stats
+    // this should only happen when using batch size 1 to evaluate a batch
 
-                if (pos >= last_pos[seq_id]) {
-                    last_pos[seq_id] = pos;
-                    last_row[seq_id] = s*n_seq_tokens + i;
-                }
-            }
+    // add the evaluation to the stats
+    if (n_queued_tokens == 1) {
+        if (!cparams.no_perf) {
+            t_eval_us += ggml_time_us() - t_compute_start_us;
         }
-
-        for (int i = 0; i < n_tokens; ++i) {
-            if (last_row[i] >= 0) {
-                data[i] = last_row[i];
-            }
+        n_eval++;
+    } else if (n_queued_tokens > 1) {
+        if (!cparams.no_perf) {
+            t_p_eval_us += ggml_time_us() - t_compute_start_us;
         }
+        n_p_eval += n_queued_tokens;
     }
 
-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
-
-        if (lctx.inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
-            float * data = (float *) lctx.inp_s_mask->data;
+    // get a more accurate load time, upon first eval
+    if (n_queued_tokens > 0 && !has_evaluated_once) {
+        t_load_us = ggml_time_us() - t_start_us;
+        has_evaluated_once = true;
+    }
 
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
+    n_queued_tokens = 0;
+    t_compute_start_us = 0;
+}
 
-                data[i] = (float) (kv_cell.src >= 0);
+const llama_model & llama_context::get_model() const {
+    return model;
+}
 
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
+uint32_t llama_context::n_ctx() const {
+    return cparams.n_ctx;
+}
 
-        if (lctx.inp_s_copy) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-            int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+uint32_t llama_context::n_ctx_per_seq() const {
+    return cparams.n_ctx / cparams.n_seq_max;
+}
 
-            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
+uint32_t llama_context::n_batch() const {
+    return cparams.n_batch;
+}
 
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-                    kv_cell.src = cell_id;
-                }
+uint32_t llama_context::n_ubatch() const {
+    return cparams.n_ubatch;
+}
 
-                data[i] = kv_cell.src;
+uint32_t llama_context::n_seq_max() const {
+    return cparams.n_seq_max;
+}
 
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-    }
+uint32_t llama_context::n_threads() const {
+    return cparams.n_threads;
+}
 
-    if (lctx.inp_pos_bucket) {
-        const int64_t n_tokens = ubatch.n_tokens;
+uint32_t llama_context::n_threads_batch() const {
+    return cparams.n_threads_batch;
+}
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
-        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+llama_kv_cache * llama_context::get_kv_self() {
+    return kv_self.get();
+}
 
-        int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
+const llama_kv_cache * llama_context::get_kv_self() const {
+    return kv_self.get();
+}
 
-        if (!lctx.is_encoding) {
-            const int64_t n_kv = kv_self.n;
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
-                    }
-                }
-            }
-        } else {
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
-                    }
+ggml_tensor * llama_context::build_rope_shift(
+        ggml_context * ctx0,
+        ggml_tensor * cur,
+        ggml_tensor * shift,
+        ggml_tensor * factors,
+        ggml_backend_buffer * bbuf) const {
+    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
+    const auto & freq_base  = cparams.rope_freq_base;
+    const auto & freq_scale = cparams.rope_freq_scale;
+
+    const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
+    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
+    const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
+    const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
+
+    const auto & hparams = model.hparams;
+
+    const auto & n_rot     = hparams.n_rot;
+    const auto & rope_type = hparams.rope_type;
+
+    ggml_tensor * tmp;
+
+    if (ggml_is_quantized(cur->type)) {
+        // dequantize to f32 -> RoPE -> quantize back
+        tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
+
+        if (bbuf) {
+            for (const auto & backend : backends) {
+                // figure out which backend the KV cache buffer belongs to
+                if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
+                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
+                    break;
                 }
             }
         }
-    }
 
-    if (!lctx.is_encoding && lctx.inp_embd_enc) {
-        assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
-        assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
+        tmp = ggml_rope_ext_inplace(ctx0, tmp,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
-        ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
+        tmp = ggml_cpy(ctx0, tmp, cur);
+    } else {
+        // we rotate only the first n_rot dimensions
+        tmp = ggml_rope_ext_inplace(ctx0, cur,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
     }
 
-    if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
-        const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
-        const int64_t n_tokens = ubatch.n_tokens;
+    return tmp;
+}
 
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
-        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
+class llm_graph_input_k_shift : public llm_graph_input_i {
+public:
+    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_k_shift() = default;
 
-        float * data = (float *) lctx.inp_KQ_mask_cross->data;
+    void set_input(const llama_ubatch * ubatch) override;
 
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_output_enc; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = ubatch.seq_id[j][s];
-                        if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
-                            f = 0.0f;
-                        }
-                    }
-                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
-                }
-            }
+    ggml_tensor * k_shift; // I32 [kv_size]
 
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
-                }
-            }
+    const llama_kv_cache_unified * kv_self;
+};
+
+void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    if (k_shift) {
+        assert(ggml_backend_buffer_is_host(k_shift->buffer));
+
+        int32_t * data = (int32_t *) k_shift->data;
+
+        for (uint32_t i = 0; i < kv_self->size; ++i) {
+            data[i] = kv_self->cells[i].delta;
         }
     }
 }
 
-// llama output
+llm_graph_result_ptr llama_context::build_kv_self_shift(
+        ggml_context * ctx0,
+        ggml_cgraph * gf) const {
+    auto res = std::make_unique<llm_graph_result>();
 
-size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
-    const auto & cparams = lctx.cparams;
-    const auto & hparams = lctx.model.hparams;
-    const auto & vocab   = lctx.model.vocab;
+    const auto & hparams = model.hparams;
 
-    const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+    const auto & n_layer = hparams.n_layer;
 
-    const auto n_batch = cparams.n_batch;
-    const auto n_vocab = vocab.n_tokens();
-    const auto n_embd  = hparams.n_embd;
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+  //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
-    const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    //GGML_ASSERT(kv_self->size == n_ctx);
 
-    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
+    auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
 
-    if (lctx.output_ids.empty()) {
-        // init, never resized afterwards
-        lctx.output_ids.resize(n_batch);
-    }
+    inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
+    ggml_set_input(inp->k_shift);
 
-    const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0;
-    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const int64_t n_head_kv    = hparams.n_head_kv(il);
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
-    // alloc only when more than the current capacity is required
-    // TODO: also consider shrinking the buffer
-    if (!lctx.buf_output || prev_size < new_size) {
-        if (lctx.buf_output) {
-#ifndef NDEBUG
-            // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
-            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
-#endif
-            lctx.buf_output = nullptr;
-            lctx.logits = nullptr;
-            lctx.embd = nullptr;
-        }
+        ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
 
-        auto * buft = ggml_backend_cpu_buffer_type();
-        // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
-        auto * output_dev = lctx.model.dev_output();
-        auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
-        if (output_dev_host_buft) {
-            buft = output_dev_host_buft;
-        }
-        lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
-        if (lctx.buf_output == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
-            return 0;
-        }
+        ggml_tensor * k =
+            ggml_view_3d(ctx0, kv_self->k_l[il],
+                n_embd_head_k, n_head_kv, kv_self->size,
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                0);
+
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+
+        ggml_build_forward_expand(gf, cur);
     }
 
-    float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get());
+    res->add_input(std::move(inp));
+
+    return res;
+}
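
The shift graph above re-ropes the cached K tensors by the per-cell deltas accumulated in kv_self->cells[i].delta. A standalone sketch of the position bookkeeping this corresponds to, positions only (the example values are assumptions for illustration; the actual rotation of the K data happens in the graph):

#include <cstdio>
#include <vector>

int main() {
    // positions currently stored in the cache cells, e.g. after discarding the first 4 tokens
    std::vector<int> pos   = {4, 5, 6, 7};
    // per-cell deltas accumulated while sequences were shifted (kv_self->cells[i].delta)
    std::vector<int> delta = {-4, -4, -4, -4};

    for (size_t i = 0; i < pos.size(); ++i) {
        // the graph rotates the cached K rows by the same amount,
        // so the stored data matches the new positions
        pos[i] += delta[i];
    }

    for (size_t i = 0; i < pos.size(); ++i) {
        printf("cell %zu -> pos %d\n", i, pos[i]);
    }
}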
+
+llm_graph_result_ptr llama_context::build_kv_self_defrag(
+        ggml_context * ctx0,
+        ggml_cgraph * gf) const {
+    auto res = std::make_unique<llm_graph_result>();
 
-    lctx.logits = has_logits ? output_base               : nullptr;
-    lctx.embd   = has_embd   ? output_base + logits_size : nullptr;
+    const auto & hparams = model.hparams;
 
-    lctx.output_size = n_outputs_max;
-    lctx.logits_size = logits_size;
-    lctx.embd_size   = embd_size;
+    const auto & ids = kv_self->defrag_info.ids;
 
-    // set all ids as invalid (negative)
-    std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
+#if 0
+    // CPU defrag
+    //
+    // TODO: optimizations are possible:
+    //       - multiple threads
+    //       - avoid copying to the host memory when already there
+    //
+    // likely not worth the effort, as we have ggml_graph based defrag
+    //
 
-    ggml_backend_buffer_clear(lctx.buf_output.get(), 0);
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
 
-    lctx.n_outputs = 0;
+    const uint32_t kv_size = size;
 
-    return n_outputs_max;
-}
+    std::vector<uint8_t> buf_k;
+    std::vector<uint8_t> buf_v;
 
-void llama_output_reorder(struct llama_context & ctx) {
-    std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
-    if (!out_ids.empty()) {
-        const uint32_t n_vocab = ctx.model.vocab.n_tokens();
-        const uint32_t n_embd  = ctx.model.hparams.n_embd;
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
 
-        const int32_t n_outputs = ctx.n_outputs;
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+        const size_t v_size_el = ggml_type_size(v_l[il]->type);
+        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
 
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
+        buf_k.resize(k_size);
+        buf_v.resize(v_size);
+
+        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
+
+        // batch move [i, i+nm) to [id, id+nm)
+        // note: cells can move only to a lower index
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
             }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (ctx.logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]);
-                }
+
+            uint32_t nm = 1;
+
+            while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                nm++;
             }
-            if (ctx.embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]);
+
+            // move keys
+            {
+                const int64_t os =  i*k_size_row;
+                const int64_t od = id*k_size_row;
+
+                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+            }
+
+            // move values (note: they are transposed)
+            {
+                const int64_t os =  i;
+                const int64_t od = id;
+
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
                 }
             }
+
+            i += nm - 1;
         }
-        std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            ctx.output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
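
The removed llama_output_reorder maps output rows back to batch order: out_ids (from the sbatch) records the original batch position of each output row, and a selection sort is used so the payload rows are swapped as few times as possible. A standalone sketch with the logits reduced to one value per output row (illustrative only):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<size_t> out_ids = {2, 0, 1};          // original batch position of each output row
    std::vector<float>  logits  = {0.2f, 0.0f, 0.1f}; // one "row" of payload per output

    const int n = (int) out_ids.size();

    // selection sort on out_ids, swapping the payload rows along with the keys,
    // so rows end up ordered by their original batch position
    for (int i = 0; i < n - 1; ++i) {
        int j_min = i;
        for (int j = i + 1; j < n; ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) {
            continue;
        }
        std::swap(out_ids[i], out_ids[j_min]);
        std::swap(logits[i],  logits[j_min]);
    }

    for (int i = 0; i < n; ++i) {
        printf("batch pos %zu -> row %d (logit %.1f)\n", out_ids[i], i, logits[i]);
    }
}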
 
-//
-// interface implementation
-//
+        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+    }
+#else
+    for (uint32_t i = 0; i < ids.size(); ++i) {
+        const uint32_t id = ids[i];
 
-void llama_free(struct llama_context * ctx) {
-    delete ctx;
-}
+        if (i == id || id == ids.size()) {
+            continue;
+        }
 
-uint32_t llama_n_ctx(const struct llama_context * ctx) {
-    return ctx->cparams.n_ctx;
-}
+        uint32_t nm = 1;
 
-uint32_t llama_n_batch(const struct llama_context * ctx) {
-    return ctx->cparams.n_batch;
-}
+        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+            nm++;
+        }
 
-uint32_t llama_n_ubatch(const struct llama_context * ctx) {
-    return ctx->cparams.n_ubatch;
-}
+        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
+
+            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
+
+            ggml_tensor * view_v_src;
+            ggml_tensor * view_v_dst;
+
+            if (cparams.flash_attn) {
+                // NOTE: the V cache is not transposed when using flash attention
+                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
+
+                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
+            } else {
+                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+                        ggml_row_size(kv_self->v_l[il]->type, i));
+
+                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+                        ggml_row_size(kv_self->v_l[il]->type, id));
+            }
 
-uint32_t llama_n_seq_max(const struct llama_context * ctx) {
-    return ctx->kv_self.size;
-}
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+        }
 
-const struct llama_model * llama_get_model(const struct llama_context * ctx) {
-    return &ctx->model;
-}
+        i += nm - 1;
+    }
 
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
-    return ctx->cparams.pooling_type;
-}
+    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+#endif
 
-void llama_attach_threadpool(
-             struct llama_context * ctx,
-        ggml_threadpool_t   threadpool,
-        ggml_threadpool_t   threadpool_batch) {
-    ctx->threadpool       = threadpool;
-    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
+    return res;
 }
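
Both the disabled CPU path and the graph path above batch contiguous moves: since cells move only to lower indices, a run of cells whose destinations are also consecutive can be copied with a single view/copy. A standalone sketch of the run detection (the ids vector below is an assumed example of what defrag_prepare produces, not real output):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // ids[i] = destination cell of cell i; entries equal to i or to ids.size() are skipped (nothing to move)
    const std::vector<uint32_t> ids = {0, 1, 2, 3, 8, 4, 5, 6};
    const uint32_t n = (uint32_t) ids.size();

    for (uint32_t i = 0; i < n; ++i) {
        const uint32_t id = ids[i];

        if (i == id || id == n) {
            continue; // cell stays put or is not moved
        }

        // extend the run while consecutive sources map to consecutive destinations
        uint32_t nm = 1;
        while (i + nm < n && ids[i + nm] == id + nm) {
            nm++;
        }

        printf("move cells [%u, %u) -> [%u, %u)\n", i, i + nm, id, id + nm);

        i += nm - 1;
    }
}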
 
-void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool       = nullptr;
-    ctx->threadpool_batch = nullptr;
-}
+void llama_context::kv_self_update() {
+    auto & kv = kv_self;
 
-void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
-    ctx->cparams.n_threads       = n_threads;
-    ctx->cparams.n_threads_batch = n_threads_batch;
-}
+    bool need_reserve = false;
 
-int32_t llama_n_threads(struct llama_context * ctx) {
-    return ctx->cparams.n_threads;
-}
+    if (kv->has_shift) {
+        if (!kv->get_can_shift()) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
 
-int32_t llama_n_threads_batch(struct llama_context * ctx) {
-    return ctx->cparams.n_threads_batch;
-}
+        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
 
-void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
-    ctx->abort_callback      = abort_callback;
-    ctx->abort_callback_data = abort_callback_data;
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched.get());
 
-    for (auto & backend : ctx->backends) {
-        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
-        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
-        if (set_abort_callback_fn) {
-            set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
-        }
-    }
-}
+            auto * gf = graph_init();
 
-void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
-    ctx->cparams.embeddings = embeddings;
-}
+            auto res = build_kv_self_shift(ctx_compute.get(), gf);
 
-void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
-    ctx->cparams.causal_attn = causal_attn;
-}
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
 
-void llama_synchronize(struct llama_context * ctx) {
-    ggml_backend_sched_synchronize(ctx->sched.get());
+            res->set_inputs(nullptr);
 
-    // FIXME: if multiple single tokens are evaluated without a synchronization,
-    // the stats will be added to the prompt evaluation stats
-    // this should only happen when using batch size 1 to evaluate a batch
+            graph_compute(gf, false);
 
-    // add the evaluation to the stats
-    if (ctx->n_queued_tokens == 1) {
-        if (!ctx->cparams.no_perf) {
-            ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+            need_reserve = true;
         }
-        ctx->n_eval++;
-    } else if (ctx->n_queued_tokens > 1) {
-        if (!ctx->cparams.no_perf) {
-            ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+
+        {
+            kv->has_shift = false;
+
+            for (uint32_t i = 0; i < kv->size; ++i) {
+                kv->cells[i].delta = 0;
+            }
         }
-        ctx->n_p_eval += ctx->n_queued_tokens;
     }
 
-    // get a more accurate load time, upon first eval
-    if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
-        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
-        ctx->has_evaluated_once = true;
-    }
+    // defragment the KV cache if needed
+    if (kv->do_defrag) {
+        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-    ctx->n_queued_tokens = 0;
-    ctx->t_compute_start_us = 0;
-}
+        if (kv->defrag_prepare(graph_max_nodes())) {
+            ggml_backend_sched_reset(sched.get());
 
-float * llama_get_logits(struct llama_context * ctx) {
-    llama_synchronize(ctx);
+            auto * gf = graph_init();
 
-    // reorder logits for backward compatibility
-    // TODO: maybe deprecate this
-    llama_output_reorder(*ctx);
+            auto res = build_kv_self_defrag(ctx_compute.get(), gf);
 
-    return ctx->logits;
-}
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
 
-float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
-    int32_t j = -1;
+            res->set_inputs(nullptr);
 
-    llama_synchronize(ctx);
+            graph_compute(gf, false);
 
-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
+            need_reserve = true;
         }
 
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
+        kv->do_defrag = false;
+    }
+
+    // reserve a worst case graph if needed
+    if (need_reserve) {
+        LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+        // build worst-case graph
+        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        // simulate full KV cache
+        kv_self->n = kv_self->size;
+
+        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between the token and embedding input graphs
+        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+        auto * gf = graph_init();
+        graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+        // initialize scheduler with the worst-case graph
+        ggml_backend_sched_reset(sched.get());
+        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+        }
+    }
+}
+
+enum llama_pooling_type llama_context::pooling_type() const {
+    return cparams.pooling_type;
+}
+
+float * llama_context::get_logits() {
+    // reorder logits for backward compatibility
+    output_reorder();
+
+    return logits;
+}
+
+float * llama_context::get_logits_ith(int32_t i) {
+    int32_t j = -1;
+
+    try {
+        if (logits == nullptr) {
+            throw std::runtime_error("no logits");
+        }
+
+        if (i < 0) {
+            j = n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
+            }
+        } else if ((size_t) i >= output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
+        } else {
+            j = output_ids[i];
         }
 
         if (j < 0) {
             throw std::runtime_error(format("batch.logits[%d] != true", i));
         }
-        if (j >= ctx->n_outputs) {
+        if (j >= n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
         }
 
-        return ctx->logits + j*ctx->model.vocab.n_tokens();
+        return logits + j*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -737,46 +837,41 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     }
 }
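
The lookup above (and the identical one in get_embeddings_ith below) resolves a caller-facing batch index into an output row: negative indices count from the end of the outputs, non-negative indices go through output_ids, and a -1 entry means that batch position was not marked for output. A simplified standalone sketch of the same resolution, without the exceptions (names and values are illustrative):

#include <cstdio>
#include <vector>

// map a caller-facing batch index (negative = from the end of the outputs) to an output row
static int resolve_output_row(int i, int n_outputs, const std::vector<int> & output_ids) {
    if (i < 0) {
        const int j = n_outputs + i; // e.g. -1 -> last output row
        return j >= 0 ? j : -1;
    }
    if (i >= (int) output_ids.size()) {
        return -1; // index past the end of the batch
    }
    return output_ids[i]; // -1 if this batch position produced no output
}

int main() {
    // a batch of 4 tokens where only positions 1 and 3 requested outputs
    const std::vector<int> output_ids = {-1, 0, -1, 1};
    const int n_outputs = 2;

    for (int i : {-1, 0, 1, 3, 7}) {
        printf("i = %2d -> row %d\n", i, resolve_output_row(i, n_outputs, output_ids));
    }
}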
 
-float * llama_get_embeddings(struct llama_context * ctx) {
-    llama_synchronize(ctx);
-
+float * llama_context::get_embeddings() {
     // reorder embeddings for backward compatibility
-    // TODO: maybe deprecate this
-    llama_output_reorder(*ctx);
+    output_reorder();
 
-    return ctx->embd;
+    return embd;
 }
 
-float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+float * llama_context::get_embeddings_ith(int32_t i) {
     int32_t j = -1;
 
-    llama_synchronize(ctx);
-
     try {
-        if (ctx->embd == nullptr) {
+        if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
         }
 
         if (i < 0) {
-            j = ctx->n_outputs + i;
+            j = n_outputs + i;
             if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+                throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
             }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
+        } else if ((size_t) i >= output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
         } else {
-            j = ctx->output_ids[i];
+            j = output_ids[i];
         }
 
         if (j < 0) {
             throw std::runtime_error(format("batch.logits[%d] != true", i));
         }
-        if (j >= ctx->n_outputs) {
+        if (j >= n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
         }
 
-        return ctx->embd + j*ctx->model.hparams.n_embd;
+        return embd + j*model.hparams.n_embd;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -787,696 +882,925 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     }
 }
 
-float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    auto it = ctx->embd_seq.find(seq_id);
-    if (it == ctx->embd_seq.end()) {
+float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
+    auto it = embd_seq.find(seq_id);
+    if (it == embd_seq.end()) {
         return nullptr;
     }
 
     return it->second.data();
 }
 
-// llama state API
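+// note: if no batch threadpool is provided, the regular threadpool is reused for batch processing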
+void llama_context::attach_threadpool(
+           ggml_threadpool_t threadpool,
+           ggml_threadpool_t threadpool_batch) {
+    LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
-// deprecated
-size_t llama_get_state_size(struct llama_context * ctx) {
-    return llama_state_get_size(ctx);
+    this->threadpool       = threadpool;
+    this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }
 
-// deprecated
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    return llama_state_get_data(ctx, dst, -1);
-}
+void llama_context::detach_threadpool() {
+    LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
-// deprecated
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    return llama_state_set_data(ctx, src, -1);
+    this->threadpool       = nullptr;
+    this->threadpool_batch = nullptr;
 }
 
-// deprecated
-bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-}
+void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) {
+    LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch);
 
-// deprecated
-bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
+    cparams.n_threads       = n_threads;
+    cparams.n_threads_batch = n_threads_batch;
 }
 
-// TODO: replace all non-fatal assertions with returned errors or exceptions
-struct llama_data_write {
-    virtual void write(const void * src, size_t size) = 0;
-    virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
-    virtual size_t get_size_written() = 0;
-    virtual ~llama_data_write() = default;
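+// note: the abort callback is forwarded to every backend that exposes "ggml_backend_set_abort_callback" through its registry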
+void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) {
+    LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
-    void write_string(const std::string & str) {
-        uint32_t str_size = str.size();
+    this->abort_callback      = abort_callback;
+    this->abort_callback_data = abort_callback_data;
 
-        write(&str_size,  sizeof(str_size));
-        write(str.data(), str_size);
+    for (auto & backend : backends) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
+        auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
+        if (set_abort_callback_fn) {
+            set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data);
+        }
     }
+}
 
-    void write_model_info(const struct llama_context * ctx) {
-        const std::string arch_str = llm_arch_name(ctx->model.arch);
-        write_string(arch_str);
-        // TODO: add more model-specific info which should prevent loading the session file if not identical
-    }
+void llama_context::set_embeddings(bool value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
-    //void write_rng(const std::mt19937 & rng) {
-    //    std::ostringstream rng_ss;
-    //    rng_ss << rng;
+    cparams.embeddings = value;
+}
 
-    //    const std::string & rng_str = rng_ss.str();
+void llama_context::set_causal_attn(bool value) {
+    LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
-    //    write_string(rng_str);
-    //}
+    cparams.causal_attn = value;
+}
 
-    void write_output_ids(struct llama_context * ctx) {
-        llama_output_reorder(*ctx);
+void llama_context::set_adapter_lora(
+            llama_adapter_lora * adapter,
+            float scale) {
+    LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
 
-        const uint32_t n_outputs = ctx->n_outputs;
+    loras[adapter] = scale;
+}
 
-        std::vector<int32_t> output_pos;
+bool llama_context::rm_adapter_lora(
+            llama_adapter_lora * adapter) {
+    LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
 
-        const size_t    n_batch = ctx->cparams.n_batch;
-        const auto & output_ids = ctx->output_ids;
+    auto pos = loras.find(adapter);
+    if (pos != loras.end()) {
+        loras.erase(pos);
+        return true;
+    }
 
-        GGML_ASSERT(n_outputs <= ctx->output_size);
+    return false;
+}
 
-        output_pos.resize(n_outputs);
+void llama_context::clear_adapter_lora() {
+    LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
-        // build a more compact representation of the output ids
-        for (size_t i = 0; i < n_batch; ++i) {
-            // map an output id to a position in the batch
-            int32_t pos = output_ids[i];
-            if (pos >= 0) {
-                GGML_ASSERT((uint32_t) pos < n_outputs);
-                output_pos[pos] = i;
-            }
-        }
+    loras.clear();
+}
 
-        write(&n_outputs, sizeof(n_outputs));
+bool llama_context::apply_adapter_cvec(
+            const float * data,
+                 size_t   len,
+                int32_t   n_embd,
+                int32_t   il_start,
+                int32_t   il_end) {
+    LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
 
-        if (n_outputs) {
-            write(output_pos.data(), n_outputs * sizeof(int32_t));
-        }
+    return cvec.apply(model, data, len, n_embd, il_start, il_end);
+}
+
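+// return values: 0 on success, 2 if aborted via the abort callback,
+// -1 for invalid input, -2 if the output buffer could not be allocated, -3 if the graph computation failed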
+int llama_context::encode(llama_batch & inp_batch) {
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
     }
 
-    void write_logits(const struct llama_context * ctx) {
-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
+    // temporarily allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
 
-        write(&logits_size, sizeof(logits_size));
+    const llama_batch & batch = batch_allocr.batch;
+    const int32_t n_tokens = batch.n_tokens;
 
-        if (logits_size) {
-            write(ctx->logits, logits_size * sizeof(float));
+    const auto & hparams = model.hparams;
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (int32_t i = 0; i < n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
         }
     }
 
-    void write_embeddings(const struct llama_context * ctx) {
-        const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
+    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
+    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
 
-        write(&embeddings_size, sizeof(embeddings_size));
-
-        if (embeddings_size) {
-            write(ctx->embd, embeddings_size * sizeof(float));
-        }
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
     }
 
-    void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
-        for (const auto & range : cell_ranges) {
-            for (uint32_t i = range.first; i < range.second; ++i) {
-                const auto & cell = kv_self.cells[i];
-                const llama_pos pos      = cell.pos;
-                const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+    n_queued_tokens += n_tokens;
 
-                write(&pos,      sizeof(pos));
-                write(&n_seq_id, sizeof(n_seq_id));
+    const int64_t n_embd = hparams.n_embd;
 
-                if (n_seq_id) {
-                    for (auto seq_id : cell.seq_id) {
-                        write(&seq_id, sizeof(seq_id));
-                    }
-                }
-            }
-        }
-    }
+    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
-    void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
-        const struct llama_kv_cache & kv_self = ctx->kv_self;
-        const struct llama_hparams & hparams = ctx->model.hparams;
+    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
-        const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
-        const uint32_t n_layer = hparams.n_layer;
+    // reserve output buffer
+    if (output_reserve(n_tokens) < n_tokens) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
+        return -2;
+    }
 
-        write(&v_trans, sizeof(v_trans));
-        write(&n_layer, sizeof(n_layer));
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        output_ids[i] = i;
+    }
 
-        std::vector<uint8_t> tmp_buf;
+    n_outputs = n_tokens;
 
-        // Iterate and write all the keys first, each row is a cell
-        // Get whole range at a time
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+    //batch_manager->prepare(ubatch);
 
-            // Write key type
-            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-            write(&k_type_i, sizeof(k_type_i));
+    ggml_backend_sched_reset(sched.get());
+    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-            // Write row size of key
-            const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-            write(&k_size_row, sizeof(k_size_row));
+    auto * gf = graph_init();
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER);
 
-            // Read each range of cells of k_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
-                const size_t range_size = range.second - range.first;
-                const size_t buf_size = range_size * k_size_row;
-                write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
-            }
-        }
+    ggml_backend_sched_alloc_graph(sched.get(), gf);
 
-        if (!kv_self.v_trans) {
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+    res->set_inputs(&ubatch);
 
-                // Write value type
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                write(&v_type_i, sizeof(v_type_i));
+    const auto compute_status = graph_compute(gf, n_tokens > 1);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
 
-                // Write row size of value
-                const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-                write(&v_size_row, sizeof(v_size_row));
+    auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
-                // Read each range of cells of v_size length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t buf_size = range_size * v_size_row;
-                    write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
-                }
-            }
-        } else {
-            // When v is transposed, we also need the element size and get the element ranges from each row
-            const uint32_t kv_size = kv_self.size;
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+    // extract embeddings
+    if (t_embd) {
+        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+        GGML_ASSERT(backend_embd != nullptr);
 
-                // Write value type
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                write(&v_type_i, sizeof(v_type_i));
+        GGML_ASSERT(embd != nullptr);
 
-                // Write element size
-                const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-                write(&v_size_el, sizeof(v_size_el));
+        switch (cparams.pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    // extract token embeddings
+                    GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
+                } break;
+            case LLAMA_POOLING_TYPE_MEAN:
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_LAST:
+                {
+                    // extract sequence embeddings
+                    auto & embd_seq_out = embd_seq;
+                    embd_seq_out.clear();
 
-                // Write GQA embedding size
-                write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+                    GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
-                // For each row, we get the element values of each cell
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    // Read each range of cells of v_size_el length each into tmp_buf and write out
-                    for (const auto & range : cell_ranges) {
-                        const size_t range_size = range.second - range.first;
-                        const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                        const size_t buf_size = range_size * v_size_el;
-                        write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
+                    for (int32_t i = 0; i < n_tokens; i++) {
+                        const llama_seq_id seq_id = ubatch.seq_id[i][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(n_embd);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                     }
+                } break;
+            case LLAMA_POOLING_TYPE_RANK:
+                {
+                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
+                    //       wait for an encoder model that requires this pooling type in order to test it
+                    //       https://github.com/ggerganov/llama.cpp/pull/9510
+                    GGML_ABORT("RANK pooling not implemented yet");
+                }
+            case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                {
+                    GGML_ABORT("unknown pooling type");
                 }
-            }
         }
     }
 
-    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        const struct llama_kv_cache & kv_self = ctx->kv_self;
-        std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
-        uint32_t cell_count = 0;
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(sched.get());
 
-        // Count the number of cells with the specified seq_id
-        // Find all the ranges of cells with this seq id (or all, when -1)
-        uint32_t cell_range_begin = kv_self.size;
-        for (uint32_t i = 0; i < kv_self.size; ++i) {
-            const auto & cell = kv_self.cells[i];
-            if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
-                ++cell_count;
-                if (cell_range_begin == kv_self.size) {
-                    cell_range_begin = i;
-                }
-            } else {
-                if (cell_range_begin != kv_self.size) {
-                    cell_ranges.emplace_back(cell_range_begin, i);
-                    cell_range_begin = kv_self.size;
-                }
-            }
-        }
-        if (cell_range_begin != kv_self.size) {
-            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
-        }
+    // TODO: hacky solution
+    if (model.arch == LLM_ARCH_T5 && t_embd) {
+        //cross.t_embd = t_embd;
 
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cell_ranges) {
-            cell_count_check += range.second - range.first;
+        cross.n_embd = t_embd->ne[0];
+        cross.n_enc  = t_embd->ne[1];
+        cross.v_embd.resize(cross.n_embd*cross.n_enc);
+        memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
+
+        // remember the sequence ids used during the encoding - needed for cross attention later
+        cross.seq_ids_enc.resize(n_tokens);
+        for (int32_t i = 0; i < n_tokens; i++) {
+            for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
+                llama_seq_id seq_id = ubatch.seq_id[i][s];
+                cross.seq_ids_enc[i].insert(seq_id);
+            }
         }
-        GGML_ASSERT(cell_count == cell_count_check);
+    }
 
-        write(&cell_count, sizeof(cell_count));
+    return 0;
+}
 
-        write_kv_cache_meta(kv_self, cell_ranges, seq_id);
-        write_kv_cache_data(ctx, cell_ranges);
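+// return values: 0 on success, 2 if aborted via the abort callback,
+// -1 for an empty batch, -2 if the output buffer could not be allocated, -3 if preparing the ubatch or computing the graph failed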
+int llama_context::decode(llama_batch & inp_batch) {
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
     }
-};
 
-struct llama_data_read {
-    virtual const uint8_t * read(size_t size) = 0;
-    virtual void read_to(void * dst, size_t size) = 0;
-    virtual size_t get_size_read() = 0;
-    virtual ~llama_data_read() = default;
+    // temporarily allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
 
-    void read_string(std::string & str) {
-        uint32_t str_size;
-        read_to(&str_size, sizeof(str_size));
+    const llama_batch & batch = batch_allocr.batch;
 
-        str.assign((const char *) read(str_size), str_size);
-    }
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
 
-    // validate model information
-    void read_model_info(const struct llama_context * ctx) {
-        const std::string cur_arch_str = llm_arch_name(ctx->model.arch);
+    const int32_t n_vocab = vocab.n_tokens();
 
-        std::string arch_str;
-        read_string(arch_str);
-        if (cur_arch_str != arch_str) {
-            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+    const int64_t n_tokens_all = batch.n_tokens;
+    const int64_t n_embd       = hparams.n_embd;
+
+    // TODO: remove this stuff
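+    // RAII helper: restores the previously saved KV cache slot on scope exit unless done() is called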
+    class batch_guard {
+    public:
+        batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
         }
-        // TODO: add more info which needs to be identical but which is not verified otherwise
-    }
 
-    //void read_rng(std::mt19937 & rng) {
-    //    std::string rng_str;
-    //    read_string(rng_str);
+        ~batch_guard() {
+            if (!is_done) {
+                kv_slot_restorer.restore();
+            }
+        }
 
-    //    std::istringstream rng_ss(rng_str);
-    //    rng_ss >> rng;
+        void done() {
+            is_done = true;
+        }
 
-    //    if (rng_ss.fail()) {
-    //        throw std::runtime_error("failed to load RNG state");
-    //    }
-    //}
+        void save(const llama_kv_cache_slot_info & slot_info) {
+            kv_slot_restorer.save(slot_info);
+        }
 
-    void read_output_ids(struct llama_context * ctx) {
-        std::vector<int32_t> output_pos;
+    private:
+        bool is_done = false;
 
-        uint32_t n_outputs;
-        read_to(&n_outputs, sizeof(n_outputs));
+        llama_kv_slot_restorer kv_slot_restorer;
+    };
 
-        if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
-            throw std::runtime_error("could not reserve outputs");
-        }
+    batch_guard bg(*kv_self);
 
-        if (n_outputs) {
-            output_pos.resize(n_outputs);
-            read_to(output_pos.data(), n_outputs * sizeof(int32_t));
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
-                int32_t id = output_pos[i];
-                if ((uint32_t) id >= ctx->cparams.n_batch) {
-                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
-                }
-                ctx->output_ids[id] = i;
+    if (batch.token) {
+        for (int64_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                throw std::runtime_error("invalid token");
             }
-
-            ctx->n_outputs = n_outputs;
         }
     }
 
-    void read_logits(struct llama_context * ctx) {
-        uint64_t logits_size;
-        read_to(&logits_size, sizeof(logits_size));
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
-        if (ctx->logits_size < logits_size) {
-            throw std::runtime_error("logits buffer too small");
-        }
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-        if (logits_size) {
-            read_to(ctx->logits, logits_size * sizeof(float));
-        }
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
     }
+    n_queued_tokens += n_tokens_all;
 
-    void read_embeddings(struct llama_context * ctx) {
-        uint64_t embeddings_size;
-        read_to(&embeddings_size, sizeof(embeddings_size));
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
-        if (ctx->embd_size < embeddings_size) {
-            throw std::runtime_error("embeddings buffer too small");
-        }
+    embd_seq.clear();
 
-        if (embeddings_size) {
-            read_to(ctx->embd, embeddings_size * sizeof(float));
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    if (batch.logits && !embd_pooled) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            n_outputs_all += batch.logits[i] != 0;
         }
+    } else if (logits_all || embd_pooled) {
+        n_outputs_all = n_tokens_all;
+    } else {
+        // keep last output only
+        n_outputs_all = 1;
     }
 
-    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
-        struct llama_kv_cache & kv_self = ctx->kv_self;
+    const bool logits_all = n_outputs_all == n_tokens_all;
 
-        if (dest_seq_id != -1) {
-            // single sequence
+    sbatch.from_batch(batch, n_embd,
+            /* simple_split */ !kv_self->recurrent,
+            /* logits_all   */ logits_all);
 
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+    // reserve output buffer
+    if (output_reserve(n_outputs_all) < n_outputs_all) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+        return -2;
+    }
 
-            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-            batch.n_tokens = cell_count;
-            batch.n_seq_tokens = cell_count;
-            batch.n_seqs = 1;
+    int64_t n_outputs_prev = 0;
 
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                llama_pos pos;
-                uint32_t n_seq_id;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch = llama_ubatch();
 
-                read_to(&pos, sizeof(pos));
-                read_to(&n_seq_id, sizeof(n_seq_id));
-
-                if (n_seq_id != 0) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
-                    return false;
-                }
+        const auto & n_ubatch = cparams.n_ubatch;
 
-                batch.pos[i] = pos;
-            }
-            batch.n_seq_id[0] = 1;
-            batch.seq_id[0] = &dest_seq_id;
-            if (!llama_kv_cache_find_slot(kv_self, batch)) {
-                LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-                return false;
+        if (kv_self->recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = sbatch.split_seq(cparams.n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = sbatch.split_equal(cparams.n_ubatch);
             }
-
-            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
-            // Assume that this is one contiguous block of cells
-            GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
         } else {
-            // whole KV cache restore
-
-            if (cell_count > kv_self.size) {
-                LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
-                return false;
-            }
-
-            llama_kv_cache_clear(kv_self);
-
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                llama_kv_cell & cell = kv_self.cells[i];
-
-                llama_pos pos;
-                uint32_t  n_seq_id;
+            ubatch = sbatch.split_simple(n_ubatch);
+        }
 
-                read_to(&pos,      sizeof(pos));
-                read_to(&n_seq_id, sizeof(n_seq_id));
+        // count the outputs in this ubatch
+        {
+            int32_t n_outputs_new = 0;
 
-                cell.pos = pos;
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
 
-                for (uint32_t j = 0; j < n_seq_id; ++j) {
-                    llama_seq_id seq_id;
-                    read_to(&seq_id, sizeof(seq_id));
+            // needs to happen before the graph is built
+            n_outputs = n_outputs_new;
+        }
 
-                    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-                        return false;
-                    }
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            kv_self_update();
 
-                    cell.seq_id.insert(seq_id);
+            // if we have enough unused cells before the current head ->
+            //   better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
+                kv_self->head = 0;
+            }
 
-                    if (kv_self.recurrent) {
-                        int32_t & tail = kv_self.cells[seq_id].tail;
-                        if (tail != -1) {
-                            LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
-                            return false;
-                        }
-                        tail = i;
-                    }
-                }
+            const auto slot_info = kv_self->find_slot(ubatch);
+            if (!slot_info) {
+                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+                return -3;
             }
 
-            kv_self.head = 0;
-            kv_self.used = cell_count;
-        }
+            bg.save(slot_info);
 
-        if (kv_self.recurrent) {
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                uint32_t cell_id = kv_self.head + i;
-                // make sure the recurrent states will keep their restored state
-                kv_self.cells[cell_id].src = cell_id;
+            if (!kv_self->recurrent) {
+                // a heuristic, to avoid attending the full cache if it is not yet utilized
+                // after enough generations, the benefit from this heuristic disappears
+                // if we start defragmenting the cache, the benefit from this will be more important
+                const uint32_t pad = kv_self->get_padding(cparams);
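+                // e.g. with pad = 32 and cell_max() = 70, this attends to 96 cells (rounded up to the padding, capped at kv_self->size)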
+                kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
             }
         }
 
-        return true;
-    }
+        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
 
-    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
-        const struct llama_hparams & hparams = ctx->model.hparams;
-        struct llama_kv_cache & kv_self = ctx->kv_self;
-        uint32_t v_trans;
-        uint32_t n_layer;
-        read_to(&v_trans, sizeof(v_trans));
-        read_to(&n_layer, sizeof(n_layer));
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-        if (n_layer != hparams.n_layer) {
-            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
-            return false;
-        }
-        if (cell_count > kv_self.size) {
-            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
-            return false;
-        }
-        if (kv_self.v_trans != (bool) v_trans) {
-            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
-            return false;
-        }
+        auto * gf = graph_init();
+        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER);
 
-        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-            // Read type of key
-            int32_t k_type_i_ref;
-            read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-            if (k_type_i != k_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-                return false;
-            }
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
 
-            // Read row size of key
-            uint64_t k_size_row_ref;
-            read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-            const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-            if (k_size_row != k_size_row_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
-                return false;
+        res->set_inputs(&ubatch);
+
+        const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
             }
+        }
 
-            if (cell_count) {
-                // Read and set the keys for the whole cell range
-                ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
+        // update the kv ring buffer
+        {
+            kv_self->head += ubatch.n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self->head >= kv_self->size) {
+                kv_self->head = 0;
             }
         }
 
-        if (!kv_self.v_trans) {
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        // plot the computation graph in dot format (for debugging purposes)
+        //if (n_past%100 == 0) {
+        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
+        //}
 
-                // Read type of value
-                int32_t v_type_i_ref;
-                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                if (v_type_i != v_type_i_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                    return false;
-                }
+        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
+        auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
-                // Read row size of value
-                uint64_t v_size_row_ref;
-                read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-                const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-                if (v_size_row != v_size_row_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
-                    return false;
-                }
+        if (t_embd && res->get_embd_pooled()) {
+            t_embd = res->get_embd_pooled();
+        }
 
-                if (cell_count) {
-                    // Read and set the values for the whole cell range
-                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
-                }
-            }
-        } else {
-            // For each layer, read the values for each cell (transposed)
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Read type of value
-                int32_t v_type_i_ref;
-                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                if (v_type_i != v_type_i_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                    return false;
-                }
+        // extract logits
+        if (t_logits && n_outputs > 0) {
+            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+            GGML_ASSERT(backend_res != nullptr);
+            GGML_ASSERT(logits != nullptr);
 
-                // Read element size of value
-                uint32_t v_size_el_ref;
-                read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-                const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-                if (v_size_el != v_size_el_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
-                    return false;
-                }
+            float * logits_out = logits + n_outputs_prev*n_vocab;
 
-                // Read GQA embedding size
-                uint32_t n_embd_v_gqa_ref;
-                read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-                if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
-                    return false;
-                }
+            if (n_outputs) {
+                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
+            }
+        }
 
-                if (cell_count) {
-                    // For each row in the transposed matrix, read the values for the whole cell range
-                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
-                        ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+        // extract embeddings
+        if (t_embd && n_outputs > 0) {
+            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+            GGML_ASSERT(backend_embd != nullptr);
+
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(embd != nullptr);
+                        float * embd_out = embd + n_outputs_prev*n_embd;
+
+                        if (n_outputs) {
+                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings (cleared before processing each batch)
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rerank score - a single float per sequence
+                        auto & embd_seq_out = embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
                     }
-                }
             }
         }
-        return true;
+
+        n_outputs_prev += n_outputs;
     }
 
-    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        uint32_t cell_count;
-        read_to(&cell_count, sizeof(cell_count));
+    // finalize the batch processing
+    bg.done();
 
-        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
+    // set output mappings
+    {
+        bool sorted_output = true;
 
-        if (!res) {
-            if (seq_id == -1) {
-                llama_kv_cache_clear(ctx);
-            } else {
-                llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
+        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+
+        for (int64_t i = 0; i < n_outputs_all; ++i) {
+            int64_t out_id = sbatch.out_ids[i];
+            output_ids[out_id] = i;
+            if (out_id != i) {
+                sorted_output = false;
             }
-            throw std::runtime_error("failed to restore kv cache");
+        }
+
+        if (sorted_output) {
+            sbatch.out_ids.clear();
         }
     }
-};
 
-struct llama_data_write_dummy : llama_data_write {
-    size_t size_written = 0;
+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    n_outputs = n_outputs_all;
 
-    llama_data_write_dummy() {}
+    // wait for the computation to finish (automatically done when obtaining the model output)
+    //synchronize();
 
-    void write(const void * /* src */, size_t size) override {
-        size_written += size;
-    }
+    // decide if we need to defrag the kv cache
+    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        // - do not defrag small contexts (i.e. < 2048 tokens)
+        // - count the padding towards the number of used tokens
+        const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
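+        // e.g. n = 4096 with used + padding = 1024 gives a fragmentation estimate of 0.75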
 
-    void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
-        size_written += size;
-    }
+        // queue defragmentation for next llama_kv_cache_update
+        if (fragmentation > cparams.defrag_thold) {
+            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
 
-    size_t get_size_written() override {
-        return size_written;
+            kv_self->defrag();
+        }
     }
-};
 
-struct llama_data_write_buffer : llama_data_write {
-    uint8_t * ptr;
-    size_t buf_size = 0;
-    size_t size_written = 0;
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(sched.get());
 
-    llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+    return 0;
+}
 
-    void write(const void * src, size_t size) override {
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        memcpy(ptr, src, size);
-        ptr += size;
-        size_written += size;
-        buf_size -= size;
-    }
+//
+// output
+//
 
-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        ggml_backend_tensor_get(tensor, ptr, offset, size);
-        ptr += size;
-        size_written += size;
-        buf_size -= size;
-    }
+int32_t llama_context::output_reserve(int32_t n_outputs) {
+    const auto & hparams = model.hparams;
+    const auto & vocab   = model.vocab;
 
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
+    const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
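+    // never reserve fewer output slots than the maximum number of sequences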
 
-struct llama_data_read_buffer : llama_data_read {
-    const uint8_t * ptr;
-    size_t buf_size = 0;
-    size_t size_read = 0;
+    const auto n_batch = cparams.n_batch;
+    const auto n_vocab = vocab.n_tokens();
+    const auto n_embd  = hparams.n_embd;
 
-    llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+    // TODO: use a per-batch flag for logits presence instead
+    bool has_logits = !cparams.embeddings;
+    bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
-    const uint8_t * read(size_t size) override {
-        const uint8_t * base_ptr = ptr;
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        ptr += size;
-        size_read += size;
-        buf_size -= size;
-        return base_ptr;
+    // TODO: hacky enc-dec support
+    if (model.arch == LLM_ARCH_T5) {
+        has_logits = true;
+        has_embd   = true;
     }
 
-    void read_to(void * dst, size_t size) override {
-        memcpy(dst, read(size), size);
-    }
+    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+    embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
 
-    size_t get_size_read() override {
-        return size_read;
+    if (output_ids.empty()) {
+        // init, never resized afterwards
+        output_ids.resize(n_batch);
     }
-};
 
-struct llama_data_write_file : llama_data_write {
-    llama_file * file;
-    size_t size_written = 0;
-    std::vector<uint8_t> temp_buffer;
+    const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
+    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
 
-    llama_data_write_file(llama_file * f) : file(f) {}
+    // alloc only when more than the current capacity is required
+    // TODO: also consider shrinking the buffer
+    if (!buf_output || prev_size < new_size) {
+        if (buf_output) {
+#ifndef NDEBUG
+            // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
+            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+            buf_output = nullptr;
+            logits = nullptr;
+            embd = nullptr;
+        }
+
+        auto * buft = ggml_backend_cpu_buffer_type();
+        // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
+        auto * output_dev = model.dev_output();
+        auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
+        if (output_dev_host_buft) {
+            buft = output_dev_host_buft;
+        }
+        buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size));
+        if (buf_output == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
+            return 0;
+        }
+    }
+
+    float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
+
+    logits = has_logits ? output_base               : nullptr;
+    embd   = has_embd   ? output_base + logits_size : nullptr;
+
+    // set all ids as invalid (negative)
+    std::fill(output_ids.begin(), output_ids.end(), -1);
+
+    ggml_backend_buffer_clear(buf_output.get(), 0);
+
+    this->n_outputs     = 0;
+    this->n_outputs_max = n_outputs_max;
+
+    return n_outputs_max;
+}
+
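+// restore the outputs to the order of the original batch (sbatch splitting may have permuted them)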
+void llama_context::output_reorder() {
+    auto & out_ids = sbatch.out_ids;
+    if (!out_ids.empty()) {
+        const uint32_t n_vocab = model.vocab.n_tokens();
+        const uint32_t n_embd  = model.hparams.n_embd;
+
+        GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+        // TODO: is there something more efficient which also minimizes swaps?
+        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+        for (int32_t i = 0; i < n_outputs - 1; ++i) {
+            int32_t j_min = i;
+            for (int32_t j = i + 1; j < n_outputs; ++j) {
+                if (out_ids[j] < out_ids[j_min]) {
+                    j_min = j;
+                }
+            }
+            if (j_min == i) { continue; }
+            std::swap(out_ids[i], out_ids[j_min]);
+            if (logits_size > 0) {
+                for (uint32_t k = 0; k < n_vocab; k++) {
+                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+                }
+            }
+            if (embd_size > 0) {
+                for (uint32_t k = 0; k < n_embd; k++) {
+                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+                }
+            }
+        }
+        std::fill(output_ids.begin(), output_ids.end(), -1);
+        for (int32_t i = 0; i < n_outputs; ++i) {
+            output_ids[out_ids[i]] = i;
+        }
+        out_ids.clear();
+    }
+}
+
+//
+// graph
+//
+
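+// upper bound on the number of nodes in a compute graph, scaled with the model size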
+int32_t llama_context::graph_max_nodes() const {
+    return std::max<int32_t>(8192, 5*model.n_tensors());
+}
+
+ggml_cgraph * llama_context::graph_init() {
+    ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute_meta.size(),
+        /*.mem_buffer =*/ buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ctx_compute.reset(ggml_init(params));
+
+    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
+}
+
+llm_graph_result_ptr llama_context::graph_build(
+            ggml_context * ctx,
+             ggml_cgraph * gf,
+      const llama_ubatch & ubatch,
+            llm_graph_type gtype) {
+    return model.build_graph(
+            {
+                /*.ctx         =*/ ctx,
+                /*.arch        =*/ model.arch,
+                /*.hparams     =*/ model.hparams,
+                /*.cparams     =*/ cparams,
+                /*.ubatch      =*/ ubatch,
+                /*.sched       =*/ sched.get(),
+                /*.backend_cpu =*/ backend_cpu,
+                /*.cvec        =*/ &cvec,
+                /*.loras       =*/ &loras,
+                /*.memory      =*/ kv_self.get(),
+                /*.cross       =*/ &cross,
+                /*.n_outputs   =*/ n_outputs,
+                /*.cb          =*/ graph_get_cb(),
+            }, gf, gtype);
+}
+
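+// schedule the graph for computation; 'batched' selects the batch thread count and threadpool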
+ggml_status llama_context::graph_compute(
+            ggml_cgraph * gf,
+                   bool   batched) {
+    int n_threads        = batched ? cparams.n_threads_batch : cparams.n_threads;
+    ggml_threadpool_t tp = batched ? threadpool_batch        : threadpool;
+
+    if (backend_cpu != nullptr) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
+        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
+        set_threadpool_fn(backend_cpu, tp);
+    }
+
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
+    }
+
+    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
+
+    return status;
+}
+
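+// callback invoked for each node during graph build: sets tensor names and applies manual backend placement when needed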
+llm_graph_cb llama_context::graph_get_cb() const {
+    return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
+        if (il >= 0) {
+            ggml_format_name(cur, "%s-%d", name, il);
+        } else {
+            ggml_set_name(cur, name);
+        }
+
+        if (!cparams.offload_kqv) {
+            if (strcmp(name, "kqv_merged_cont") == 0) {
+                // all nodes between the KV store and the attention output are run on the CPU
+                ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
+            }
+        }
+
+        // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
+        // FIXME: fix in ggml_backend_sched
+        const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
+        if (ubatch.n_tokens < 32 || full_offload) {
+            if (il != -1 && strcmp(name, "norm") == 0) {
+                const auto & dev_layer = model.dev_layer(il);
+                for (const auto & backend : backends) {
+                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
+                        if (ggml_backend_supports_op(backend.get(), cur)) {
+                            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
+                        }
+                    }
+                }
+            }
+        }
+    };
+}
+
+//
+// state save/load
+//
+
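+// a writer that only counts bytes - used to compute the state size without copying any data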
+class llama_io_write_dummy : public llama_io_write_i {
+public:
+    llama_io_write_dummy() = default;
+
+    void write(const void * /* src */, size_t size) override {
+        size_written += size;
+    }
+
+    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
+        size_written += size;
+    }
+
+    size_t n_bytes() override {
+        return size_written;
+    }
+
+private:
+    size_t size_written = 0;
+};
+
+class llama_io_write_buffer : public llama_io_write_i {
+public:
+    llama_io_write_buffer(
+            uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+    void write(const void * src, size_t size) override {
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        memcpy(ptr, src, size);
+        ptr += size;
+        size_written += size;
+        buf_size -= size;
+    }
+
+    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        ggml_backend_tensor_get(tensor, ptr, offset, size);
+        ptr += size;
+        size_written += size;
+        buf_size -= size;
+    }
+
+    size_t n_bytes() override {
+        return size_written;
+    }
+
+private:
+    uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_written = 0;
+};
+
+class llama_io_read_buffer : public llama_io_read_i {
+public:
+    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+    const uint8_t * read(size_t size) override {
+        const uint8_t * base_ptr = ptr;
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        ptr += size;
+        size_read += size;
+        buf_size -= size;
+        return base_ptr;
+    }
+
+    void read_to(void * dst, size_t size) override {
+        memcpy(dst, read(size), size);
+    }
+
+    size_t n_bytes() override {
+        return size_read;
+    }
+
+private:
+    const uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_read = 0;
+};
+
+class llama_io_write_file : public llama_io_write_i {
+public:
+    llama_io_write_file(llama_file * f) : file(f) {}
 
     void write(const void * src, size_t size) override {
         file->write_raw(src, size);
         size_written += size;
     }
 
-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
+    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
         temp_buffer.resize(size);
         ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
         write(temp_buffer.data(), temp_buffer.size());
     }
 
-    size_t get_size_written() override {
+    size_t n_bytes() override {
         return size_written;
     }
-};
 
-struct llama_data_read_file : llama_data_read {
+private:
     llama_file * file;
-    size_t size_read = 0;
+    size_t size_written = 0;
     std::vector<uint8_t> temp_buffer;
+};
 
-    llama_data_read_file(llama_file * f) : file(f) {}
+class llama_io_read_file : public llama_io_read_i {
+public:
+    llama_io_read_file(llama_file * f) : file(f) {}
 
     void read_to(void * dst, size_t size) override {
         file->read_raw(dst, size);
@@ -1489,89 +1813,78 @@ struct llama_data_read_file : llama_data_read {
         return temp_buffer.data();
     }
 
-    size_t get_size_read() override {
+    size_t n_bytes() override {
         return size_read;
     }
-};
-
-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_write_file data_ctx(&file);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- * buffer context:
- * std::vector<uint8_t> buf(max_size, 0);
- * llama_data_write_buffer data_ctx(buf.data(), max_size);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
-*/
-static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
-    llama_synchronize(ctx);
-
-    data_ctx.write_model_info(ctx);
-
-    // copy outputs
-    data_ctx.write_output_ids(ctx);
-    data_ctx.write_logits(ctx);
-    data_ctx.write_embeddings(ctx);
 
-    data_ctx.write_kv_cache(ctx);
+private:
+    llama_file * file;
+    size_t size_read = 0;
+    std::vector<uint8_t> temp_buffer;
+};
 
-    return data_ctx.get_size_written();
+size_t llama_context::state_get_size() {
+    llama_io_write_dummy io;
+    try {
+        return state_write_data(io);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
 }
 
-size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
-    llama_data_write_buffer data_ctx(dst, size);
+size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
+    llama_io_write_buffer io(dst, size);
     try {
-        return llama_state_get_data_internal(ctx, data_ctx);
+        return state_write_data(io);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
         return 0;
     }
 }
 
-// Returns the *actual* size of the state.
-// Intended to be used when saving to state to a buffer.
-size_t llama_state_get_size(struct llama_context * ctx) {
-    llama_data_write_dummy data_ctx;
+size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
+    llama_io_read_buffer io(src, size);
     try {
-        return llama_state_get_data_internal(ctx, data_ctx);
+        return state_read_data(io);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
     }
 }
 
-static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
-    llama_synchronize(ctx);
-
-    data_ctx.read_model_info(ctx);
-
-    // set outputs
-    data_ctx.read_output_ids(ctx);
-    data_ctx.read_logits(ctx);
-    data_ctx.read_embeddings(ctx);
-
-    data_ctx.read_kv_cache(ctx);
+size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
+    llama_io_write_dummy io;
+    try {
+        return state_seq_write_data(io, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
+}
 
-    return data_ctx.get_size_read();
+size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
+    llama_io_write_buffer io(dst, size);
+    try {
+        return state_seq_write_data(io, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
+    }
 }
 
-// Sets the state reading from the specified source address
-size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
-    llama_data_read_buffer data_ctx(src, size);
+size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
+    llama_io_read_buffer io(src, size);
     try {
-        return llama_state_set_data_internal(ctx, data_ctx);
+        return state_seq_read_data(io, seq_id);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
     }
 }
 
-static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(path_session, "rb");
+bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");
 
     // sanity checks
     {
@@ -1601,28 +1914,20 @@ static bool llama_state_load_file_internal(struct llama_context * ctx, const cha
     {
         const size_t n_state_size_cur = file.size() - file.tell();
 
-        llama_data_read_file data_ctx(&file);
-        const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
+        llama_io_read_file io(&file);
+        const size_t n_read = state_read_data(io);
 
         if (n_read != n_state_size_cur) {
             LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
             return false;
         }
     }
-    return true;
-}
 
-bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
-        return false;
-    }
+    return true;
 }
 
-static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(path_session, "wb");
+bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");
 
     file.write_u32(LLAMA_SESSION_MAGIC);
     file.write_u32(LLAMA_SESSION_VERSION);
@@ -1632,82 +1937,13 @@ static bool llama_state_save_file_internal(struct llama_context * ctx, const cha
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);
 
     // save the context state using stream saving
-    llama_data_write_file data_ctx(&file);
-    llama_state_get_data_internal(ctx, data_ctx);
+    llama_io_write_file io(&file);
+    state_write_data(io);
 
     return true;
 }
 
-bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
-        return false;
-    }
-}
-
-static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    data_ctx.write_kv_cache(ctx, seq_id);
-
-    return data_ctx.get_size_written();
-}
-
-size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_data_write_dummy data_ctx;
-    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-}
-
-size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
-    llama_data_write_buffer data_ctx(dst, size);
-    try {
-        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
-    llama_synchronize(ctx);
-
-    data_ctx.read_kv_cache(ctx, dest_seq_id);
-
-    return data_ctx.get_size_read();
-}
-
-size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
-    llama_data_read_buffer data_ctx(src, size);
-    try {
-        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(filepath, "wb");
-
-    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
-    file.write_u32(LLAMA_STATE_SEQ_VERSION);
-
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_data_write_file data_ctx(&file);
-    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-
-    const size_t res = file.tell();
-    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
-    return res;
-}
-
-static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(filepath, "rb");
 
     // version checks
@@ -1737,8 +1973,8 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
     // restore the context state
     {
         const size_t state_size = file.size() - file.tell();
-        llama_data_read_file data_ctx(&file);
-        const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
+        llama_io_read_file io(&file);
+        const size_t nread = state_seq_read_data(io, seq_id);
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;
@@ -1750,26 +1986,785 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
     return file.tell();
 }
 
-size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
-        return 0;
-    }
-}
+size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");
 
-size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
-        return 0;
-    }
+    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
+    file.write_u32(LLAMA_STATE_SEQ_VERSION);
+
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_io_write_file io(&file);
+    state_seq_write_data(io, seq_id);
+
+    const size_t res = file.tell();
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
+
+    return res;
+}
+
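+// serialization order: model arch string, output ids, logits, embeddings, KV cache state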
+size_t llama_context::state_write_data(llama_io_write_i & io) {
+    LLAMA_LOG_DEBUG("%s: writing state\n", __func__);
+
+    // write model info
+    {
+        LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__);
+
+        const std::string arch_str = llm_arch_name(model.arch);
+        io.write_string(arch_str);
+        // TODO: add more model-specific info which should prevent loading the session file if not identical
+    }
+
+    // write output ids
+    {
+        LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
+
+        output_reorder();
+
+        const auto n_outputs    = this->n_outputs;
+        const auto & output_ids = this->output_ids;
+
+        std::vector<int32_t> w_output_pos;
+
+        GGML_ASSERT(n_outputs <= n_outputs_max);
+
+        w_output_pos.resize(n_outputs);
+
+        // build a more compact representation of the output ids
+        for (size_t i = 0; i < n_batch(); ++i) {
+            // map an output id to a position in the batch
+            int32_t pos = output_ids[i];
+            if (pos >= 0) {
+                GGML_ASSERT(pos < n_outputs);
+                w_output_pos[pos] = i;
+            }
+        }
+
+        io.write(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs) {
+            io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
+        }
+    }
+
+    // write logits
+    {
+        LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
+
+        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+
+        io.write(&logits_size, sizeof(logits_size));
+
+        if (logits_size) {
+            io.write(logits, logits_size * sizeof(float));
+        }
+    }
+
+    // write embeddings
+    {
+        LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
+
+        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
+
+        io.write(&embd_size, sizeof(embd_size));
+
+        if (embd_size) {
+            io.write(embd, embd_size * sizeof(float));
+        }
+    }
+
+    LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+    kv_self->state_write(io);
+
+    return io.n_bytes();
+}
+
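+// reads the fields in the same order they were written by state_write_data and
+// validates them against the current context (model arch, buffer sizes)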
+size_t llama_context::state_read_data(llama_io_read_i & io) {
+    LLAMA_LOG_DEBUG("%s: reading state\n", __func__);
+
+    // read model info
+    {
+        LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__);
+
+        const std::string cur_arch_str = llm_arch_name(model.arch);
+
+        std::string arch_str;
+        io.read_string(arch_str);
+        if (cur_arch_str != arch_str) {
+            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+        }
+        // TODO: add more info which needs to be identical but which is not verified otherwise
+    }
+
+    // read output ids
+    {
+        LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
+
+        auto n_outputs = this->n_outputs;
+        io.read_to(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs > output_reserve(n_outputs)) {
+            throw std::runtime_error("could not reserve outputs");
+        }
+
+        std::vector<int32_t> output_pos;
+
+        if (n_outputs) {
+            output_pos.resize(n_outputs);
+            io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
+
+            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+                int32_t id = output_pos[i];
+                if ((uint32_t) id >= n_batch()) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
+                }
+                this->output_ids[id] = i;
+            }
+
+            this->n_outputs = n_outputs;
+        }
+    }
+
+    // read logits
+    {
+        LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
+
+        uint64_t logits_size;
+        io.read_to(&logits_size, sizeof(logits_size));
+
+        if (this->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }
+
+        if (logits_size) {
+            io.read_to(this->logits, logits_size * sizeof(float));
+        }
+    }
+
+    // read embeddings
+    {
+        LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
+
+        uint64_t embd_size;
+        io.read_to(&embd_size, sizeof(embd_size));
+
+        if (this->embd_size < embd_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }
+
+        if (embd_size) {
+            io.read_to(this->embd, embd_size * sizeof(float));
+        }
+    }
+
+    LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+    kv_self->state_read(io);
+
+    return io.n_bytes();
+}
+
+size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
+    GGML_UNUSED(seq_id);
+
+    kv_self->state_write(io, seq_id);
+
+    return io.n_bytes();
+}
+
+size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
+    GGML_UNUSED(seq_id);
+
+    kv_self->state_read(io, seq_id);
+
+    return io.n_bytes();
+}
+
+//
+// perf
+//
+
+llama_perf_context_data llama_context::perf_get_data() const {
+    llama_perf_context_data data = {};
+
+    data.t_start_ms  = 1e-3 * t_start_us;
+    data.t_load_ms   = 1e-3 * t_load_us;
+    data.t_p_eval_ms = 1e-3 * t_p_eval_us;
+    data.t_eval_ms   = 1e-3 * t_eval_us;
+    data.n_p_eval    = std::max(1, n_p_eval);
+    data.n_eval      = std::max(1, n_eval);
+
+    return data;
+}
+
+void llama_context::perf_reset() {
+    t_start_us  = ggml_time_us();
+    t_eval_us   = n_eval = 0;
+    t_p_eval_us = n_p_eval = 0;
+}
+
+//
+// interface implementation
+//
+
+llama_context_params llama_context_default_params() {
+    llama_context_params result = {
+        /*.n_ctx                       =*/ 512,
+        /*.n_batch                     =*/ 2048,
+        /*.n_ubatch                    =*/ 512,
+        /*.n_seq_max                   =*/ 1,
+        /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+        /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
+        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+        /*.rope_freq_base              =*/ 0.0f,
+        /*.rope_freq_scale             =*/ 0.0f,
+        /*.yarn_ext_factor             =*/ -1.0f,
+        /*.yarn_attn_factor            =*/ 1.0f,
+        /*.yarn_beta_fast              =*/ 32.0f,
+        /*.yarn_beta_slow              =*/ 1.0f,
+        /*.yarn_orig_ctx               =*/ 0,
+        /*.defrag_thold                =*/ -1.0f,
+        /*.cb_eval                     =*/ nullptr,
+        /*.cb_eval_user_data           =*/ nullptr,
+        /*.type_k                      =*/ GGML_TYPE_F16,
+        /*.type_v                      =*/ GGML_TYPE_F16,
+        /*.logits_all                  =*/ false,
+        /*.embeddings                  =*/ false,
+        /*.offload_kqv                 =*/ true,
+        /*.flash_attn                  =*/ false,
+        /*.no_perf                     =*/ true,
+        /*.abort_callback              =*/ nullptr,
+        /*.abort_callback_data         =*/ nullptr,
+    };
+
+    return result;
+}
+
+llama_context * llama_init_from_model(
+                 llama_model * model,
+        llama_context_params   params) {
+    if (!model) {
+        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
+        return nullptr;
+    }
+
+    if (params.n_batch == 0 && params.n_ubatch == 0) {
+        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
+        return nullptr;
+    }
+
+    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
+        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
+        return nullptr;
+    }
+
+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
+        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
+        return nullptr;
+    }
+
+    try {
+        auto * ctx = new llama_context(*model, params);
+        return ctx;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what());
+    }
+
+    return nullptr;
+}
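+
+// minimal usage sketch (values are illustrative):
+//
+//   llama_context_params cparams = llama_context_default_params();
+//   cparams.n_ctx = 4096;
+//
+//   llama_context * ctx = llama_init_from_model(model, cparams);
+//   ...
+//   llama_free(ctx);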
+
+// deprecated
+llama_context * llama_new_context_with_model(
+                 llama_model * model,
+        llama_context_params   params) {
+    return llama_init_from_model(model, params);
+}
+
+void llama_free(llama_context * ctx) {
+    delete ctx;
+}
+
+uint32_t llama_n_ctx(const llama_context * ctx) {
+    return ctx->n_ctx();
+}
+
+uint32_t llama_n_batch(const llama_context * ctx) {
+    return ctx->n_batch();
+}
+
+uint32_t llama_n_ubatch(const llama_context * ctx) {
+    return ctx->n_ubatch();
+}
+
+uint32_t llama_n_seq_max(const llama_context * ctx) {
+    return ctx->n_seq_max();
+}
+
+const llama_model * llama_get_model(const llama_context * ctx) {
+    return &ctx->get_model();
+}
+
+llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
+    return ctx->get_kv_self();
+}
+
+void llama_kv_self_update(llama_context * ctx) {
+    ctx->kv_self_update();
+}
+
+enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
+    return ctx->pooling_type();
+}
+
+void llama_attach_threadpool(
+            llama_context * ctx,
+        ggml_threadpool_t   threadpool,
+        ggml_threadpool_t   threadpool_batch) {
+    ctx->attach_threadpool(threadpool, threadpool_batch);
+}
+
+void llama_detach_threadpool(llama_context * ctx) {
+    ctx->detach_threadpool();
+}
+
+void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
+    ctx->set_n_threads(n_threads, n_threads_batch);
+}
+
+int32_t llama_n_threads(llama_context * ctx) {
+    return ctx->n_threads();
+}
+
+int32_t llama_n_threads_batch(llama_context * ctx) {
+    return ctx->n_threads_batch();
+}
+
+void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
+    ctx->set_abort_callback(abort_callback, abort_callback_data);
+}
+
+void llama_set_embeddings(llama_context * ctx, bool embeddings) {
+    ctx->set_embeddings(embeddings);
+}
+
+void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
+    ctx->set_causal_attn(causal_attn);
+}
+
+void llama_synchronize(llama_context * ctx) {
+    ctx->synchronize();
+}
+
+float * llama_get_logits(llama_context * ctx) {
+    ctx->synchronize();
+
+    return ctx->get_logits();
+}
+
+float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_logits_ith(i);
+}
+
+float * llama_get_embeddings(llama_context * ctx) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings();
+}
+
+float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_ith(i);
+}
+
+float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_seq(seq_id);
+}
+
+// llama adapter API
+
+int32_t llama_set_adapter_lora(
+            llama_context * ctx,
+            llama_adapter_lora * adapter,
+            float scale) {
+    ctx->set_adapter_lora(adapter, scale);
+
+    return 0;
+}
+
+int32_t llama_rm_adapter_lora(
+            llama_context * ctx,
+            llama_adapter_lora * adapter) {
+    bool res = ctx->rm_adapter_lora(adapter);
+
+    return res ? 0 : -1;
+}
+
+void llama_clear_adapter_lora(llama_context * ctx) {
+    ctx->clear_adapter_lora();
+}
+
+int32_t llama_apply_adapter_cvec(
+        llama_context * ctx,
+                 const float * data,
+                      size_t   len,
+                     int32_t   n_embd,
+                     int32_t   il_start,
+                     int32_t   il_end) {
+    bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
+
+    return res ? 0 : -1;
+}
+
+//
+// kv cache view
+//
+
+llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) {
+    const auto * kv = ctx->get_kv_self();
+    if (kv == nullptr) {
+        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
+        return {};
+    }
+
+    return llama_kv_cache_view_init(*kv, n_seq_max);
+}
+
+void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) {
+    const auto * kv = ctx->get_kv_self();
+    if (kv == nullptr) {
+        LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__);
+        return;
+    }
+
+    llama_kv_cache_view_update(view, kv);
+}
+
+//
+// kv cache
+//
+
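+// the llama_kv_cache_* functions below are kept as deprecated thin wrappers around the
+// llama_kv_self_* API, which forwards to the context's KV cache
+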
+// deprecated
+int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
+    return llama_kv_self_n_tokens(ctx);
+}
+
+int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
+    return llama_kv_cache_n_tokens(ctx->get_kv_self());
+}
+
+// deprecated
+int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
+    return llama_kv_self_used_cells(ctx);
+}
+
+int32_t llama_kv_self_used_cells(const llama_context * ctx) {
+    return llama_kv_cache_used_cells(ctx->get_kv_self());
+}
+
+// deprecated
+void llama_kv_cache_clear(llama_context * ctx) {
+    llama_kv_self_clear(ctx);
+}
+
+void llama_kv_self_clear(llama_context * ctx) {
+    llama_kv_cache_clear(ctx->get_kv_self());
+}
+
+// deprecated
+bool llama_kv_cache_seq_rm(
+        llama_context * ctx,
+         llama_seq_id   seq_id,
+            llama_pos   p0,
+            llama_pos   p1) {
+    return llama_kv_self_seq_rm(ctx, seq_id, p0, p1);
+}
+
+bool llama_kv_self_seq_rm(
+        llama_context * ctx,
+         llama_seq_id   seq_id,
+            llama_pos   p0,
+            llama_pos   p1) {
+    return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
+}
+
+// deprecated
+void llama_kv_cache_seq_cp(
+        llama_context * ctx,
+         llama_seq_id   seq_id_src,
+         llama_seq_id   seq_id_dst,
+            llama_pos   p0,
+            llama_pos   p1) {
+    return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_self_seq_cp(
+        llama_context * ctx,
+         llama_seq_id   seq_id_src,
+         llama_seq_id   seq_id_dst,
+            llama_pos   p0,
+            llama_pos   p1) {
+    return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
+}
+
+// deprecated
+void llama_kv_cache_seq_keep(
+        llama_context * ctx,
+         llama_seq_id   seq_id) {
+    return llama_kv_self_seq_keep(ctx, seq_id);
+}
+
+void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
+    return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
+}
+
+// deprecated
+void llama_kv_cache_seq_add(
+        llama_context * ctx,
+         llama_seq_id   seq_id,
+            llama_pos   p0,
+            llama_pos   p1,
+            llama_pos   delta) {
+    return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
+}
+
+void llama_kv_self_seq_add(
+        llama_context * ctx,
+         llama_seq_id   seq_id,
+            llama_pos   p0,
+            llama_pos   p1,
+            llama_pos   delta) {
+    return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
+}
+
+// deprecated
+void llama_kv_cache_seq_div(
+        llama_context * ctx,
+         llama_seq_id   seq_id,
+            llama_pos   p0,
+            llama_pos   p1,
+                  int   d) {
+    return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
+}
+
+void llama_kv_self_seq_div(
+        llama_context * ctx,
+         llama_seq_id   seq_id,
+            llama_pos   p0,
+            llama_pos   p1,
+                  int   d) {
+    return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
+}
+
+// deprecated
+llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
+    return llama_kv_self_seq_pos_max(ctx, seq_id);
+}
+
+llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
+    return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
+}
+
+// deprecated
+void llama_kv_cache_defrag(llama_context * ctx) {
+    return llama_kv_self_defrag(ctx);
+}
+
+void llama_kv_self_defrag(llama_context * ctx) {
+    llama_kv_cache_defrag(ctx->get_kv_self());
+}
+
+// deprecated
+bool llama_kv_cache_can_shift(const llama_context * ctx) {
+    return llama_kv_self_can_shift(ctx);
+}
+
+bool llama_kv_self_can_shift(const llama_context * ctx) {
+    return llama_kv_cache_can_shift(ctx->get_kv_self());
+}
+
+// deprecated
+void llama_kv_cache_update(llama_context * ctx) {
+    llama_kv_self_update(ctx);
+}
+
+// llama state API
+
+// deprecated
+size_t llama_get_state_size(llama_context * ctx) {
+    return llama_state_get_size(ctx);
+}
+
+// deprecated
+size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) {
+    return llama_state_get_data(ctx, dst, -1);
+}
+
+// deprecated
+size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) {
+    return llama_state_set_data(ctx, src, -1);
+}
+
+// deprecated
+bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
+}
+
+// deprecated
+bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
+}
+
+// Returns the *actual* size of the state.
+// Intended to be used when saving the state to a buffer.
+size_t llama_state_get_size(llama_context * ctx) {
+    return ctx->state_get_size();
+}
+
+size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) {
+    ctx->synchronize();
+
+    return ctx->state_get_data(dst, size);
+}
+
+// Sets the state reading from the specified source address
+size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) {
+    ctx->synchronize();
+
+    return ctx->state_set_data(src, size);
+}
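+
+// typical buffer round-trip (sketch):
+//
+//   std::vector<uint8_t> buf(llama_state_get_size(ctx));
+//   const size_t n_copied = llama_state_get_data(ctx, buf.data(), buf.size());
+//   ...
+//   llama_state_set_data(ctx, buf.data(), n_copied);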
+
+bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    ctx->synchronize();
+
+    try {
+        return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
+        return false;
+    }
+}
+
+bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    ctx->synchronize();
+
+    try {
+        return ctx->state_save_file(path_session, tokens, n_token_count);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
+        return false;
+    }
+}
+
+size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) {
+    return ctx->state_seq_get_size(seq_id);
+}
+
+size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    ctx->synchronize();
+
+    return ctx->state_seq_get_data(seq_id, dst, size);
+}
+
+size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+    ctx->synchronize();
+
+    return ctx->state_seq_set_data(seq_id, src, size);
+}
+
+size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+    ctx->synchronize();
+
+    try {
+        return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    ctx->synchronize();
+
+    try {
+        return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+///
+
+int32_t llama_encode(
+        llama_context * ctx,
+          llama_batch   batch) {
+    const int ret = ctx->encode(batch);
+    if (ret != 0) {
+        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
+int32_t llama_decode(
+        llama_context * ctx,
+          llama_batch   batch) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0) {
+        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
+//
+// perf
+//
+
+llama_perf_context_data llama_perf_context(const llama_context * ctx) {
+    llama_perf_context_data data = {};
+
+    if (ctx == nullptr) {
+        return data;
+    }
+
+    data = ctx->perf_get_data();
+
+    return data;
+}
+
+void llama_perf_context_print(const llama_context * ctx) {
+    const auto data = llama_perf_context(ctx);
+
+    const double t_end_ms = 1e-3 * ggml_time_us();
+
+    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
+    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
+    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
+    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
 }
 
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-) {
-    return ctx->model.tensors_by_name;
+void llama_perf_context_reset(llama_context * ctx) {
+    ctx->perf_reset();
 }
index a9268b292090893a01875b06fc43a8b1907fb783..71d702e8baeeb0b14584a1f61c0dcd9ab4ff1a2e 100644 (file)
 #include "llama.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
-#include "llama-model.h"
-#include "llama-kv-cache.h"
+#include "llama-graph.h"
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
 
 #include <map>
-#include <unordered_map>
 #include <vector>
-#include <set>
+
+struct llama_model;
+struct llama_kv_cache;
+
+class llama_io_read_i;
+class llama_io_write_i;
 
 struct llama_context {
-    llama_context(const llama_model & model)
-        : model(model)
-        , t_start_us(model.t_start_us)
-        , t_load_us(model.t_load_us) {}
+    // init scheduler and compute buffers, reserve worst-case graphs
+    llama_context(
+            const llama_model & model,
+                  llama_context_params params);
 
-    const struct llama_model & model;
+    ~llama_context();
 
-    struct llama_cparams      cparams;
-    struct llama_sbatch       sbatch;  // TODO: revisit if needed
-    struct llama_kv_cache     kv_self;
-    struct llama_adapter_cvec cvec;
+    void synchronize();
 
-    std::unordered_map<struct llama_adapter_lora *, float> lora;
+    const llama_model & get_model() const;
 
-    std::vector<ggml_backend_ptr> backends;
-    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
+    uint32_t n_ctx()         const;
+    uint32_t n_ctx_per_seq() const;
+    uint32_t n_batch()       const;
+    uint32_t n_ubatch()      const;
+    uint32_t n_seq_max()     const;
 
-    ggml_backend_t backend_cpu = nullptr;
+    uint32_t n_threads()       const;
+    uint32_t n_threads_batch() const;
 
-    ggml_threadpool_t threadpool       = nullptr;
-    ggml_threadpool_t threadpool_batch = nullptr;
+          llama_kv_cache * get_kv_self();
+    const llama_kv_cache * get_kv_self() const;
 
-    bool has_evaluated_once = false;
+    void kv_self_update();
 
-    mutable int64_t t_start_us;
-    mutable int64_t t_load_us;
-    mutable int64_t t_p_eval_us = 0;
-    mutable int64_t t_eval_us   = 0;
+    enum llama_pooling_type pooling_type() const;
 
-    mutable int64_t t_compute_start_us = 0;
-    mutable int64_t n_queued_tokens = 0;
+    float * get_logits();
+    float * get_logits_ith(int32_t i);
 
-    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-    mutable int32_t n_eval   = 0; // number of eval calls
+    float * get_embeddings();
+    float * get_embeddings_ith(int32_t i);
+    float * get_embeddings_seq(llama_seq_id seq_id);
 
-    // host buffer for the model output (logits and embeddings)
-    ggml_backend_buffer_ptr buf_output;
+    void attach_threadpool(
+            ggml_threadpool_t threadpool,
+            ggml_threadpool_t threadpool_batch);
 
-    // decode output (2-dimensional array: [n_outputs][n_vocab])
-    size_t  logits_size = 0; // capacity (of floats) for logits
-    float * logits      = nullptr;
+    void detach_threadpool();
 
-    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
-    size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch
+    void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
+
+    void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
+
+    void set_embeddings (bool value);
+    void set_causal_attn(bool value);
+
+    void set_adapter_lora(
+            llama_adapter_lora * adapter,
+            float scale);
+
+    bool rm_adapter_lora(
+            llama_adapter_lora * adapter);
+
+    void clear_adapter_lora();
+
+    bool apply_adapter_cvec(
+            const float * data,
+                 size_t   len,
+                int32_t   n_embd,
+                int32_t   il_start,
+                int32_t   il_end);
+
+    int encode(llama_batch & inp_batch);
+    int decode(llama_batch & inp_batch);
+
+    //
+    // state save/load
+    //
+
+    size_t state_get_size();
+    size_t state_get_data(      uint8_t * dst, size_t size);
+    size_t state_set_data(const uint8_t * src, size_t size);
+
+    size_t state_seq_get_size(llama_seq_id seq_id);
+    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size);
+    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
+
+    bool state_load_file(
+            const char * filepath,
+           llama_token * tokens_out,
+                size_t   n_token_capacity,
+                size_t * n_token_count_out);
+
+    bool state_save_file(
+            const char * filepath,
+     const llama_token * tokens,
+                size_t   n_token_count);
+
+    size_t state_seq_load_file(
+          llama_seq_id   seq_id,
+            const char * filepath,
+           llama_token * tokens_out,
+                size_t   n_token_capacity,
+                size_t * n_token_count_out);
+
+    size_t state_seq_save_file(
+          llama_seq_id   seq_id,
+            const char * filepath,
+     const llama_token * tokens,
+                size_t   n_token_count);
+
+    //
+    // perf
+    //
+
+    llama_perf_context_data perf_get_data() const;
+    void perf_reset();
+
+private:
+    //
+    // output
+    //
+
+    // Make sure enough space is available for outputs.
+    // Returns max number of outputs for which space was reserved.
+    int32_t output_reserve(int32_t n_outputs);
+
+    // make the outputs have the same order they had in the user-provided batch
+    // TODO: maybe remove this
+    void output_reorder();
 
+    //
+    // graph
+    //
+
+    int32_t graph_max_nodes() const;
+
+    // zero-out inputs and create the ctx_compute for the compute graph
+    ggml_cgraph * graph_init();
+
+    llm_graph_result_ptr graph_build(
+            ggml_context * ctx,
+             ggml_cgraph * gf,
+      const llama_ubatch & ubatch,
+          llm_graph_type   gtype);
+
+    // returns the result of ggml_backend_sched_graph_compute_async execution
+    ggml_status graph_compute(
+            ggml_cgraph * gf,
+                   bool   batched);
+
+    llm_graph_cb graph_get_cb() const;
+
+    // used by kv_self_update()
+    ggml_tensor * build_rope_shift(
+        ggml_context * ctx0,
+        ggml_tensor * cur,
+        ggml_tensor * shift,
+        ggml_tensor * factors,
+        ggml_backend_buffer * bbuf) const;
+
+    llm_graph_result_ptr build_kv_self_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * gf) const;
+
+    llm_graph_result_ptr build_kv_self_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * gf) const;
+
+    // TODO: read/write lora adapters and cvec
+    size_t state_write_data(llama_io_write_i & io);
+    size_t state_read_data (llama_io_read_i  & io);
+
+    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id);
+    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id);
+
+    //
+    // members
+    //
+
+    const llama_model & model;
+
+    llama_cparams       cparams;
+    llama_adapter_cvec  cvec;
+    llama_adapter_loras loras;
+    llama_sbatch        sbatch;
+
+    llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
+
+    std::unique_ptr<llama_kv_cache_unified> kv_self;
+
+    // TODO: remove
     bool logits_all = false;
 
+    // decode output (2-dimensional array: [n_outputs][n_vocab])
+    size_t  logits_size = 0; // capacity (of floats) for logits
+    float * logits      = nullptr;
+
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
     size_t  embd_size = 0; // capacity (of floats) for embeddings
@@ -72,57 +216,47 @@ struct llama_context {
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
-    // whether we are computing encoder output or decoder output
-    bool is_encoding = false;
+    int32_t n_outputs     = 0; // number of actually-used outputs in the current ubatch or last logical batch
+    int32_t n_outputs_max = 0; // capacity (of token positions) for the output buffers
 
-    // TODO: find a better way to accommodate mutli-dimension position encoding methods
-    // number of position id each token get, 1 for each token in most cases.
-    // when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate.
-    int n_pos_per_token = 1;
-
-    // output of the encoder part of the encoder-decoder models
-    std::vector<float> embd_enc;
-    std::vector<std::set<llama_seq_id>> seq_ids_enc;
+    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
-    // memory buffers used to evaluate the model
-    std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_ptr sched;
 
+    ggml_backend_t backend_cpu = nullptr;
+    std::vector<ggml_backend_ptr> backends;
+
+    ggml_context_ptr ctx_compute;
+
+    ggml_threadpool_t threadpool       = nullptr;
+    ggml_threadpool_t threadpool_batch = nullptr;
+
     ggml_abort_callback abort_callback      = nullptr;
     void *              abort_callback_data = nullptr;
 
-    // input tensors
-    struct ggml_tensor * inp_tokens;        // I32 [n_batch]
-    struct ggml_tensor * inp_embd;          // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;           // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids;       // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask;       // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_KQ_mask_swa;   // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift;       // I32 [kv_size]
-    struct ggml_tensor * inp_mean;          // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;           // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;        // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;        // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;         // I32 [n_kv, n_batch]
-    struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
-    struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
-    struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
-};
+    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
 
-// TODO: make these methods of llama_context
-void llama_set_k_shift(struct llama_context & lctx);
+    // buffer types used for the compute buffer of each backend
+    std::vector<ggml_backend_t>             backend_ptrs;
+    std::vector<ggml_backend_buffer_type_t> backend_buft;
 
-void llama_set_s_copy(struct llama_context & lctx);
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;
 
-void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);
+    // host buffer for the model output (logits and embeddings)
+    ggml_backend_buffer_ptr buf_output;
 
-// Make sure enough space is available for outputs.
-// Returns max number of outputs for which space was reserved.
-size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
+    bool has_evaluated_once = false;
 
-// make the outputs have the same order they had in the user-provided batch
-void llama_output_reorder(struct llama_context & ctx);
+    // perf
+    mutable int64_t t_start_us  = 0;
+    mutable int64_t t_load_us   = 0;
+    mutable int64_t t_p_eval_us = 0;
+    mutable int64_t t_eval_us   = 0;
 
-// For internal test use
-// TODO: remove
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
+    mutable int64_t t_compute_start_us = 0;
+    mutable int64_t n_queued_tokens    = 0;
+
+    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
+    mutable int32_t n_eval   = 0; // number of eval calls
+};
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
new file mode 100644 (file)
index 0000000..1e3f2ef
--- /dev/null
+++ b/src/llama-graph.cpp
@@ -0,0 +1,1695 @@
+#include "llama-graph.h"
+
+#include "llama-impl.h"
+#include "llama-batch.h"
+#include "llama-cparams.h"
+#include "llama-kv-cache.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstring>
+
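+// maps a relative position (x - y) to one of n_buckets buckets: exact for small
+// distances, log-spaced up to max_distance (T5-style relative attention bias)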
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}
+
+void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->token) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens));
+    }
+
+    if (ubatch->embd) {
+        const int64_t n_embd   = embd->ne[0];
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
+    }
+}
+
+void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->pos && pos) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
+    }
+}
+
+void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
+    if (pos_bucket) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
+        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+        int32_t * data = (int32_t *) pos_bucket->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_tokens; ++i) {
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
+                }
+            }
+        }
+    }
+}
+
+void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
+    if (pos_bucket) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
+        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+        int32_t * data = (int32_t *) pos_bucket->data;
+
+        const int64_t n_kv = kv_self->n;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+                }
+            }
+        }
+    }
+}
+
+void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
+    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
+
+        if (!out_ids) {
+            LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
+        } else {
+            const int64_t n_tokens = ubatch->n_tokens;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
+            int32_t * data = (int32_t *) out_ids->data;
+
+            if (n_outputs == n_tokens) {
+                for (int i = 0; i < n_tokens; ++i) {
+                    data[i] = i;
+                }
+            } else if (ubatch->output) {
+                int32_t n_out = 0;
+                for (int i = 0; i < n_tokens; ++i) {
+                    if (ubatch->output[i]) {
+                        data[n_out++] = i;
+                    }
+                }
+                // the graph needs to have been passed the correct number of outputs
+                GGML_ASSERT(n_out == n_outputs);
+            } else if (n_outputs == 1) {
+                // only keep last output
+                data[0] = n_tokens - 1;
+            } else {
+                GGML_ASSERT(n_outputs == 0);
+            }
+        }
+    }
+}
+
+void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
+
+        GGML_ASSERT(mean);
+        GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
+
+        float * data = (float *) mean->data;
+        memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
+
+        std::vector<uint64_t> sum(n_tokens, 0);
+
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
+
+            sum[seq_id] += ubatch->n_seq_tokens;
+        }
+
+        std::vector<float> div(n_tokens, 0.0f);
+        for (int i = 0; i < n_tokens; ++i) {
+            const uint64_t s = sum[i];
+            if (s > 0) {
+                div[i] = 1.0f/float(s);
+            }
+        }
+
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
+            }
+        }
+    }
+}
+
+void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
+    if (cparams.embeddings && (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
+
+        GGML_ASSERT(cls);
+        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
+
+        uint32_t * data = (uint32_t *) cls->data;
+        memset(cls->data, 0, n_tokens * ggml_element_size(cls));
+
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+
+                if (pos == 0) {
+                    data[seq_id] = s*n_seq_tokens + i;
+                }
+            }
+        }
+    }
+
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;
+
+        GGML_ASSERT(cls);
+        GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
+
+        uint32_t * data = (uint32_t *) cls->data;
+        memset(cls->data, 0, n_tokens * ggml_element_size(cls));
+
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);
+
+        for (int s = 0; s < n_seqs; ++s) {
+            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+
+            for (int i = 0; i < n_seq_tokens; ++i) {
+                const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
+
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
+                }
+            }
+        }
+
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+    }
+}
+
+void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    const int64_t n_kv = kv_self->n;
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self->head;
+
+            //////////////////////////////////////////////
+            // TODO: this should not mutate the KV cache !
+            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
+
+            // prevent out-of-bound sources
+            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
+                kv_cell.src = cell_id;
+            }
+
+            data[i] = kv_cell.src;
+
+            // TODO: do not mutate the KV cache
+            // ensure copy only happens once
+            if (kv_cell.src != (int32_t) cell_id) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
+
+void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    const int64_t n_kv = kv_self->n;
+
+    if (s_mask) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
+        float * data = (float *) s_mask->data;
+
+        // clear unused states
+        for (int i = 0; i < n_kv; ++i) {
+            const uint32_t  cell_id = i + kv_self->head;
+
+            //////////////////////////////////////////////
+            // TODO: this should not mutate the KV cache !
+            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
+
+            data[i] = (float) (kv_cell.src >= 0);
+
+            // only clear once
+            if (kv_cell.src < 0) {
+                kv_cell.src = cell_id;
+            }
+        }
+    }
+}
+
+void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    if (cross_embd && !cross->v_embd.empty()) {
+        assert(cross_embd->type == GGML_TYPE_F32);
+
+        ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd));
+    }
+}
+
+void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
+    if (kq_mask) {
+        if (cparams.causal_attn) {
+            const int64_t n_kv         = ubatch->n_tokens;
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+            float * data = (float *) kq_mask->data;
+
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
+                                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
+                            }
+                        }
+                    }
+                }
+            }
+        } else {
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+            const int64_t n_stride     = ubatch->n_tokens;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+            float * data = (float *) kq_mask->data;
+
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                    if (ubatch->seq_id[s0][s] == seq_id) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
+                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
+                            }
+                        }
+
+                        for (int i = n_tokens; i < n_stride; ++i) {
+                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask || self_kq_mask_swa) {
+        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+        if (cparams.causal_attn) {
+            const int64_t n_kv         = kv_self->n;
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+
+            float * data     = nullptr;
+            float * data_swa = nullptr;
+
+            if (self_kq_mask) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+                data = (float *) self_kq_mask->data;
+            }
+
+            if (self_kq_mask_swa) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+                data_swa = (float *) self_kq_mask_swa->data;
+            }
+
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the ubatch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
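+            // mask entries are 0.0f (or the ALiBi bias -|pos_k - pos_q|) where attention
+            // is allowed and -INFINITY where it is masked (comment added for clarity)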
+            for (int h = 0; h < 1; ++h) {
+                for (int s = 0; s < n_seqs; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+
+                        for (int i = 0; i < n_kv; ++i) {
+                            float f;
+                            if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
+                                f = -INFINITY;
+                            } else {
+                                if (hparams.use_alibi) {
+                                    f = -std::abs(kv_self->cells[i].pos - pos);
+                                } else {
+                                    f = 0.0f;
+                                }
+                            }
+
+                            if (data) {
+                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                            }
+
+                            // may need to cut off old tokens for sliding window
+                            if (data_swa) {
+                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                                    f = -INFINITY;
+                                }
+                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                            }
+                        }
+                    }
+                }
+
+                if (data) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+
+                if (data_swa) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+            }
+        } else {
+            const int64_t n_tokens     = ubatch->n_tokens;
+            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+            const int64_t n_seqs       = ubatch->n_seqs;
+            // when using kv cache, the mask needs to match the kv cache size
+            const int64_t n_stride     = n_tokens;
+
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+
+            float * data = (float *) self_kq_mask->data;
+
+            for (int h = 0; h < 1; ++h) {
+                for (int s1 = 0; s1 < n_seqs; ++s1) {
+                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
+
+                    for (int j = 0; j < n_seq_tokens; ++j) {
+                        const int32_t tj = s1*n_seq_tokens + j;
+
+                        for (int s0 = 0; s0 < n_seqs; ++s0) {
+                            for (int i = 0; i < n_seq_tokens; ++i) {
+                                const int32_t ti = s0*n_seq_tokens + i;
+                                float f = -INFINITY;
+
+                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
+                                    if (ubatch->seq_id[s0][s] == seq_id) {
+                                        if (hparams.use_alibi) {
+                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
+                                        } else {
+                                            f = 0.0f;
+                                        }
+                                        break;
+                                    }
+                                }
+
+                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
+                            }
+                        }
+
+                        for (int i = n_tokens; i < n_stride; ++i) {
+                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
+    if (cross_kq_mask) {
+        const int64_t n_enc    = cross_kq_mask->ne[0];
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+        float * data = (float *) cross_kq_mask->data;
+
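+        // a decoder token is allowed to attend to encoder position i only if the two share
+        // at least one sequence id (comment added for clarity)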
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_enc; ++i) {
+                    float f = -INFINITY;
+                    for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
+                        const llama_seq_id seq_id = ubatch->seq_id[j][s];
+                        if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
+                            f = 0.0f;
+                        }
+                    }
+                    data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
+                }
+            }
+
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_enc; ++j) {
+                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
+                }
+            }
+        }
+    }
+}
+
+//
+// llm_graph_context
+//
+
+llm_graph_context::llm_graph_context(const llm_graph_params & params) :
+    arch             (params.arch),
+    hparams          (params.hparams),
+    cparams          (params.cparams),
+    ubatch           (params.ubatch),
+    n_embd           (hparams.n_embd),
+    n_layer          (hparams.n_layer),
+    n_rot            (hparams.n_rot),
+    n_ctx            (cparams.n_ctx),
+    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
+    n_head           (hparams.n_head()),
+    n_head_kv        (hparams.n_head_kv()),
+    n_embd_head_k    (hparams.n_embd_head_k),
+    n_embd_k_gqa     (hparams.n_embd_k_gqa()),
+    n_embd_head_v    (hparams.n_embd_head_v),
+    n_embd_v_gqa     (hparams.n_embd_v_gqa()),
+    n_expert         (hparams.n_expert),
+    n_expert_used    (hparams.n_expert_used),
+    freq_base        (cparams.rope_freq_base),
+    freq_scale       (cparams.rope_freq_scale),
+    ext_factor       (cparams.yarn_ext_factor),
+    attn_factor      (cparams.yarn_attn_factor),
+    beta_fast        (cparams.yarn_beta_fast),
+    beta_slow        (cparams.yarn_beta_slow),
+    norm_eps         (hparams.f_norm_eps),
+    norm_rms_eps     (hparams.f_norm_rms_eps),
+    n_tokens         (ubatch.n_tokens),
+    n_outputs        (params.n_outputs),
+    n_ctx_orig       (cparams.n_ctx_orig_yarn),
+    pooling_type     (cparams.pooling_type),
+    rope_type        (hparams.rope_type),
+    ctx0             (params.ctx),
+    sched            (params.sched),
+    backend_cpu      (params.backend_cpu),
+    cvec             (params.cvec),
+    loras            (params.loras),
+    memory           (params.memory),
+    cross            (params.cross),
+    cb_func          (params.cb),
+    res              (std::make_unique<llm_graph_result>()) {
+    }
+
+int64_t llm_graph_context::n_pos_per_token() const {
+    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+}
+
+void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
+    if (cb_func) {
+        cb_func(ubatch, cur, name, il);
+    }
+}
+
+ggml_tensor * llm_graph_context::build_cvec(
+         ggml_tensor * cur,
+                 int   il) const {
+    return cvec->apply_to(ctx0, cur, il);
+}
+
+ggml_tensor * llm_graph_context::build_lora_mm(
+          ggml_tensor * w,
+          ggml_tensor * cur) const {
+    ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
+
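+    // accumulate the LoRA deltas on top of the base matmul (comment added for clarity):
+    //   res = W*cur + sum over adapters of scale * B*(A*cur)
+    // where the scale combines the adapter's alpha/rank with the user-provided adapter scale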
+    for (const auto & lora : *loras) {
+        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
+        if (lw == nullptr) {
+            continue;
+        }
+
+        const float adapter_scale = lora.second;
+        const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
+
+        ggml_tensor * ab_cur = ggml_mul_mat(
+                ctx0, lw->b,
+                ggml_mul_mat(ctx0, lw->a, cur)
+                );
+
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+
+    return res;
+}
+
+ggml_tensor * llm_graph_context::build_lora_mm_id(
+          ggml_tensor * w,   // ggml_tensor * as
+          ggml_tensor * cur, // ggml_tensor * b
+          ggml_tensor * ids) const {
+    ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (const auto & lora : *loras) {
+        llama_adapter_lora_weight * lw = lora.first->get_weight(w);
+        if (lw == nullptr) {
+            continue;
+        }
+
+        const float alpha = lora.first->alpha;
+        const float rank  = (float) lw->b->ne[0];
+        const float scale = alpha ? lora.second * alpha / rank : lora.second;
+
+        ggml_tensor * ab_cur = ggml_mul_mat_id(
+                ctx0, lw->b,
+                ggml_mul_mat_id(ctx0, lw->a, cur, ids),
+                ids
+                );
+
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+
+    return res;
+}
+
+ggml_tensor * llm_graph_context::build_norm(
+         ggml_tensor * cur,
+         ggml_tensor * mw,
+         ggml_tensor * mb,
+       llm_norm_type   type,
+                 int   il) const {
+    switch (type) {
+        case LLM_NORM:       cur = ggml_norm    (ctx0, cur, hparams.f_norm_eps);     break;
+        case LLM_NORM_RMS:   cur = ggml_rms_norm(ctx0, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM_GROUP:
+            {
+                cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], 1, cur->ne[1]);
+                cur = ggml_group_norm(ctx0, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
+                cur = ggml_reshape_2d(ctx0, cur, cur->ne[0],    cur->ne[2]);
+            } break;
+    }
+
+    if (mw || mb) {
+        cb(cur, "norm", il);
+    }
+
+    if (mw) {
+        cur = ggml_mul(ctx0, cur, mw);
+        if (mb) {
+            cb(cur, "norm_w", il);
+        }
+    }
+
+    if (mb) {
+        cur = ggml_add(ctx0, cur, mb);
+    }
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_ffn(
+         ggml_tensor * cur,
+         ggml_tensor * up,
+         ggml_tensor * up_b,
+         ggml_tensor * up_s,
+         ggml_tensor * gate,
+         ggml_tensor * gate_b,
+         ggml_tensor * gate_s,
+         ggml_tensor * down,
+         ggml_tensor * down_b,
+         ggml_tensor * down_s,
+         ggml_tensor * act_scales,
+     llm_ffn_op_type   type_op,
+   llm_ffn_gate_type   type_gate,
+                 int   il) const {
+    ggml_tensor * tmp = up ? build_lora_mm(up, cur) : cur;
+    cb(tmp, "ffn_up", il);
+
+    if (up_b) {
+        tmp = ggml_add(ctx0, tmp, up_b);
+        cb(tmp, "ffn_up_b", il);
+    }
+
+    if (up_s) {
+        tmp = ggml_mul(ctx0, tmp, up_s);
+        cb(tmp, "ffn_up_s", il);
+    }
+
+    if (gate) {
+        switch (type_gate) {
+            case LLM_FFN_SEQ:
+                {
+                    cur = build_lora_mm(gate, tmp);
+                    cb(cur, "ffn_gate", il);
+                } break;
+            case LLM_FFN_PAR:
+                {
+                    cur = build_lora_mm(gate, cur);
+                    cb(cur, "ffn_gate", il);
+                } break;
+        }
+
+        if (gate_b) {
+            cur = ggml_add(ctx0, cur, gate_b);
+            cb(cur, "ffn_gate_b", il);
+        }
+
+        if (gate_s) {
+            cur = ggml_mul(ctx0, cur, gate_s);
+            cb(cur, "ffn_gate_s", il);
+        }
+
+    } else {
+        cur = tmp;
+    }
+
+    switch (type_op) {
+        case LLM_FFN_SILU:
+            {
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_silu", il);
+            } break;
+        case LLM_FFN_GELU:
+            {
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_gelu", il);
+                if (act_scales != NULL) {
+                    cur = ggml_div(ctx0, cur, act_scales);
+                    cb(cur, "ffn_act", il);
+                }
+            } break;
+        case LLM_FFN_RELU:
+            {
+                cur = ggml_relu(ctx0, cur);
+                cb(cur, "ffn_relu", il);
+            } break;
+        case LLM_FFN_RELU_SQR:
+            {
+                cur = ggml_relu(ctx0, cur);
+                cb(cur, "ffn_relu", il);
+
+                cur = ggml_sqr(ctx0, cur);
+                cb(cur, "ffn_sqr(relu)", il);
+            } break;
+        case LLM_FFN_SWIGLU:
+            {
+                // Project to 4h. If using SwiGLU, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
+                int64_t split_point = cur->ne[0] / 2;
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
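+                // SwiGLU: the projection is split in half, the first half is passed through
+                // SiLU and multiplied element-wise with the second half (comment added for clarity)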
+                x0 = ggml_silu(ctx0, x0);
+                cb(cur, "ffn_silu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
+                cb(cur, "ffn_mul", il);
+            } break;
+    }
+
+    if (type_gate == LLM_FFN_PAR) {
+        cur = ggml_mul(ctx0, cur, tmp);
+        cb(cur, "ffn_gate_par", il);
+    }
+
+    if (down) {
+        cur = build_lora_mm(down, cur);
+    }
+
+    if (down_b) {
+        cb(cur, "ffn_down", il);
+        cur = ggml_add(ctx0, cur, down_b);
+    }
+
+    if (down_s) {
+        cur = ggml_mul(ctx0, cur, down_s);
+        cb(cur, "ffn_down_s", il);
+    }
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_moe_ffn(
+         ggml_tensor * cur,
+         ggml_tensor * gate_inp,
+         ggml_tensor * up_exps,
+         ggml_tensor * gate_exps,
+         ggml_tensor * down_exps,
+         ggml_tensor * exp_probs_b,
+             int64_t   n_expert,
+             int64_t   n_expert_used,
+     llm_ffn_op_type   type_op,
+                bool   norm_w,
+                bool   scale_w,
+               float   w_scale,
+         llama_expert_gating_func_type gating_op,
+                 int   il) const {
+    int64_t n_embd = cur->ne[0];
+    int64_t n_tokens = cur->ne[1];
+
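+    // overall flow (summary comment added for clarity): router logits -> gating probabilities ->
+    // top-k expert selection -> per-expert up/gate/down projections via indexed matmuls ->
+    // weighted sum over the selected experts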
+    ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
+    cb(logits, "ffn_moe_logits", il);
+
+    ggml_tensor * probs = nullptr;
+    switch (gating_op) {
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
+            {
+                probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
+            } break;
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
+            {
+                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+    cb(probs, "ffn_moe_probs", il);
+
+    // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = ggml_get_rows(ctx0,
+            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    cb(weights, "ffn_moe_weights", il);
+
+    if (norm_w) {
+        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
+        cb(weights_sum, "ffn_moe_weights_sum", il);
+
+        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights_norm", il);
+
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+    }
+    if (scale_w) {
+        weights = ggml_scale(ctx0, weights, w_scale);
+        cb(weights, "ffn_moe_weights_scaled", il);
+    }
+
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(gate, "ffn_moe_gate", il);
+
+    switch (type_op) {
+        case LLM_FFN_SILU:
+            {
+                gate = ggml_silu(ctx0, gate);
+                cb(gate, "ffn_moe_silu", il);
+            } break;
+        case LLM_FFN_GELU:
+            {
+                gate = ggml_gelu(ctx0, gate);
+                cb(gate, "ffn_moe_gelu", il);
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+
+    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
+    cb(par, "ffn_moe_gate_par", il);
+
+    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx0, experts, weights);
+
+    // aggregate experts
+    ggml_tensor * moe_out = nullptr;
+    for (int i = 0; i < n_expert_used; ++i) {
+        ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens,
+                experts->nb[2], i*experts->nb[1]);
+
+        if (i == 0) {
+            moe_out = cur_expert;
+        } else {
+            moe_out = ggml_add(ctx0, moe_out, cur_expert);
+        }
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx0, moe_out);
+    }
+
+    return moe_out;
+}
+
+// input embeddings with optional lora
+ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
+    const int64_t n_embd = hparams.n_embd;
+
+    auto inp = std::make_unique<llm_graph_input_embd>();
+
+    ggml_tensor * cur = nullptr;
+
+    if (ubatch.token) {
+        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+        //cb(inp->tokens, "inp_tokens", -1);
+        ggml_set_input(inp->tokens);
+
+        cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
+
+        // apply lora for embedding tokens if needed
+        for (const auto & lora : *loras) {
+            llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd);
+            if (lw == nullptr) {
+                continue;
+            }
+
+            const float adapter_scale = lora.second;
+            const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
+
+            ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat(
+                        ctx0, lw->b, // non-transposed lora_b
+                        ggml_get_rows(ctx0, lw->a, inp->tokens)
+                        ), scale);
+
+            cur = ggml_add(ctx0, cur, inpL_delta);
+        }
+    } else {
+        inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
+        ggml_set_input(inp->embd);
+
+        cur = inp->embd;
+    }
+
+    // For Granite architecture
+    if (hparams.f_embedding_scale != 0.0f) {
+        cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
+    }
+
+    cb(cur, "inp_embd", -1);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_pos() const {
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
+
+    auto & cur = inp->pos;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_out_ids() const {
+    auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
+
+    auto & cur = inp->out_ids;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_mean() const {
+    auto inp = std::make_unique<llm_graph_input_mean>(cparams);
+
+    auto & cur = inp->mean;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_cls() const {
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+
+    auto & cur = inp->cls;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
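+// note (comment added for clarity): s_copy/s_mask are inputs used by models with recurrent
+// state - see build_copy_mask_state() and build_rwkv_token_shift_load() further below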
+ggml_tensor * llm_graph_context::build_inp_s_copy() const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
+
+    const auto n_kv = kv_self->n;
+
+    auto & cur = inp->s_copy;
+
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_s_mask() const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
+
+    const auto n_kv = kv_self->n;
+
+    auto & cur = inp->s_mask;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
+    auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
+
+    auto & cur = inp->cross_embd;
+
+    // if we have the output embeddings from the encoder, use them directly
+    // TODO: needs more work to be correct, for now just use the tensor shape
+    //if (cross->t_embd) {
+    //    cur = ggml_view_tensor(ctx0, cross->t_embd);
+
+    //    return cur;
+    //}
+
+    const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd;
+    const auto n_enc  = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
+    auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);
+
+    auto & cur = inp->pos_bucket;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
+
+    const auto n_kv = kv_self->n;
+
+    auto & cur = inp->pos_bucket;
+
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const {
+    ggml_tensor * pos_bucket_1d = ggml_reshape_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1]);
+    cb(pos_bucket_1d, "pos_bucket_1d", -1);
+
+    ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
+
+    pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
+    pos_bias = ggml_permute   (ctx0, pos_bias, 2, 0, 1, 3);
+    pos_bias = ggml_cont      (ctx0, pos_bias);
+
+    cb(pos_bias, "pos_bias", -1);
+
+    return pos_bias;
+}
+
+ggml_tensor * llm_graph_context::build_attn_mha(
+         ggml_cgraph * gf,
+         ggml_tensor * q,
+         ggml_tensor * k,
+         ggml_tensor * v,
+         ggml_tensor * kq_b,
+         ggml_tensor * kq_mask,
+             bool      v_trans,
+             float     kq_scale) const {
+  //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+  //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+  //const int64_t n_head    = hparams.n_head(il);
+  //const int64_t n_head_kv = hparams.n_head_kv(il);
+
+  //const auto & n_embd_head_k = hparams.n_embd_head_k;
+  //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
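+    // expected shapes, inferred from the callers in this file (comment added for clarity):
+    //   q: [n_embd_head, n_tokens, n_head]
+    //   k: [n_embd_head, n_kv,     n_head_kv]
+    //   v: [n_embd_head, n_kv,     n_head_kv], or [n_kv, n_embd_head, n_head_kv] when v_trans is set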
+    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+
+    const auto n_tokens = q->ne[1];
+    const auto n_head   = q->ne[2];
+    const auto n_kv     = k->ne[1];
+
+    ggml_tensor * cur;
+
+    // TODO: replace hardcoded padding with ggml-provided padding
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
+        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");
+
+        if (v_trans) {
+            v = ggml_transpose(ctx0, v);
+        }
+
+        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+
+        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+    } else {
+        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+
+        // note: this op tends to require high floating point range
+        //       while for some models F16 is enough, for others it is not, so we default to F32 here
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+        if (arch == LLM_ARCH_GROK) {
+            // need to do the following:
+            // multiply by attn_output_multiplier of 0.08838834764831845
+            // and then:
+            // kq = 30 * tanh(kq / 30)
+            // before the softmax below
+
+            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
+            kq = ggml_scale(ctx0, kq, 30);
+        }
+
+        if (hparams.attn_soft_cap) {
+            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            kq = ggml_tanh (ctx0, kq);
+            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
+        }
+
+        if (kq_b) {
+            kq = ggml_add(ctx0, kq, kq_b);
+        }
+
+        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+
+        if (!v_trans) {
+            // note: avoid this branch
+            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+        }
+
+        ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+
+        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+
+        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+
+        if (!cparams.offload_kqv) {
+            // all nodes between the KV store and the attention output are run on the CPU
+            ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
+        }
+    }
+
+    ggml_build_forward_expand(gf, cur);
+
+    return cur;
+}
+
+llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
+    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
+
+    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
+    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    //cb(inp_kq_mask, "KQ_mask", -1);
+    ggml_set_input(inp->kq_mask);
+
+    inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
+
+    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_no_cache * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+            float     kq_scale,
+            int       il) const {
+    GGML_UNUSED(n_tokens);
+
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);
+
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+    //cb(k, "k", il);
+
+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(k, "v", il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
+                bool   causal,
+                bool   swa) const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
+
+    const auto n_kv = kv_self->n;
+
+    inp->self_kq_mask = causal
+        ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+        : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    //cb(inp->self_kq_mask, "KQ_mask", -1);
+    ggml_set_input(inp->self_kq_mask);
+
+    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+
+    if (swa) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        inp->self_kq_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const auto & n_ctx = cparams.n_ctx;
+
+    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    const auto n_tokens = q_cur->ne[2];
+
+    const bool v_trans = !cparams.flash_attn;
+
+    // store to KV cache
+    {
+        GGML_ASSERT(!kv_self->recurrent);
+
+        const auto kv_head = kv_self->head;
+
+        GGML_ASSERT(kv_self->size == n_ctx);
+
+        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
+        //cb(k_cache_view, "k_cache_view", il);
+
+        // note: storing RoPE-ed version of K in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+
+        assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
+
+        ggml_tensor * v_cache_view = nullptr;
+
+        if (!v_trans) {
+            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
+        } else {
+            // note: the V cache is transposed when not using flash attention
+            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
+                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
+                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
+
+            v_cur = ggml_transpose(ctx0, v_cur);
+        }
+        //cb(v_cache_view, "v_cache_view", il);
+
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+    }
+
+    // TODO: improve
+    bool is_sliding = false;
+
+    switch (arch) {
+        case LLM_ARCH_COHERE2:
+            {
+                const int32_t sliding_window_pattern = 4;
+                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            } break;
+        case LLM_ARCH_GEMMA2:
+            {
+                const int32_t sliding_window_pattern = 2;
+                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                const int32_t sliding_window_pattern = 6;
+                is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            } break;
+        case LLM_ARCH_PHI3:
+            {
+                is_sliding = hparams.n_swa > 0;
+            } break;
+        default:
+            {
+                is_sliding = false;
+            }
+    }
+
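+    // e.g. with GEMMA3's pattern of 6, layers with il % 6 < 5 use the sliding-window (SWA)
+    // mask and every 6th layer uses the full mask (comment added for clarity)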
+    const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+
+    const auto n_kv = kv_self->n;
+
+    const int64_t n_head_kv = hparams.n_head_kv(il);
+
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+    const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);
+
+    ggml_tensor * k =
+        ggml_view_3d(ctx0, kv_self->k_l[il],
+                n_embd_head_k, n_kv, n_head_kv,
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
+                0);
+    //cb(k, "k", il);
+
+    ggml_tensor * v = !v_trans ?
+        ggml_view_3d(ctx0, kv_self->v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
+                0) :
+        ggml_view_3d(ctx0, kv_self->v_l[il],
+                n_kv, n_embd_head_v, n_head_kv,
+                ggml_element_size(kv_self->v_l[il])*n_ctx,
+                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
+                0);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
+    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
+
+    const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
+
+    inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    ggml_set_input(inp->cross_kq_mask);
+
+    inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
+
+    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_cross * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto & kq_mask = inp->get_kq_mask_cross();
+
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);
+
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+    //cb(k, "k", il);
+
+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(k, "v", il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        //cb(cur, "kqv_wo", il);
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_copy_mask_state(
+         ggml_cgraph * gf,
+         ggml_tensor * s,
+         ggml_tensor * state_copy,
+         ggml_tensor * state_mask,
+             int32_t   n_state,
+             int32_t   n_seqs) const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    const auto n_kv    = kv_self->n;
+    const auto kv_head = kv_self->head;
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
+
+    // copy states
+    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+    // this shrinks the tensor's ne[1] to n_kv
+    states = ggml_get_rows(ctx0, states, state_copy);
+
+    // clear states of sequences which are starting at the beginning of this batch
+    // FIXME: zero-out NANs?
+    states = ggml_mul(ctx0, states, state_mask);
+
+    // copy states which won't be changed further (between n_seqs and n_kv)
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
+            ggml_view_1d(ctx0, s,      n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+
+    // the part of the states that will be used and modified
+    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+}
+
+ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
+         ggml_cgraph * gf,
+         ggml_tensor * state_copy,
+         ggml_tensor * state_mask,
+  const llama_ubatch & ubatch,
+                 int   il) const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    const auto token_shift_count = hparams.token_shift_count;
+
+    const int64_t n_seqs  = ubatch.n_seqs;
+
+    ggml_tensor * token_shift_all = kv_self->k_l[il];
+
+    ggml_tensor * token_shift = build_copy_mask_state(
+            gf, token_shift_all, state_copy, state_mask,
+            hparams.n_embd_k_s(), n_seqs);
+
+    token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
+
+    return token_shift;
+}
+
+ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
+         ggml_tensor * token_shift,
+  const llama_ubatch & ubatch,
+                 int   il) const {
+    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+    const auto token_shift_count = hparams.token_shift_count;
+    const auto n_embd = hparams.n_embd;
+
+    const int64_t n_seqs = ubatch.n_seqs;
+
+    const auto kv_head = kv_self->head;
+
+    return ggml_cpy(
+        ctx0,
+        ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
+        ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il]))
+    );
+}
+
+void llm_graph_context::build_pooling(
+        ggml_cgraph * gf,
+        ggml_tensor * cls,
+        ggml_tensor * cls_b,
+        ggml_tensor * cls_out,
+        ggml_tensor * cls_out_b) const {
+    if (!cparams.embeddings) {
+        return;
+    }
+
+    ggml_tensor * inp = res->t_embd;
+
+    //// find result_norm tensor for input
+    //for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+    //    inp = ggml_graph_node(gf, i);
+    //    if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+    //        break;
+    //    }
+
+    //    inp = nullptr;
+    //}
+
+    GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
+
+    ggml_tensor * cur;
+
+    switch (pooling_type) {
+        case LLAMA_POOLING_TYPE_NONE:
+            {
+                cur = inp;
+            } break;
+        case LLAMA_POOLING_TYPE_MEAN:
+            {
+                ggml_tensor * inp_mean = build_inp_mean();
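+                // the [n_tokens, n_tokens] matrix from build_inp_mean() averages the token
+                // embeddings per sequence (see llm_graph_input_mean::set_input)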
+                cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
+            } break;
+        case LLAMA_POOLING_TYPE_CLS:
+        case LLAMA_POOLING_TYPE_LAST:
+            {
+                ggml_tensor * inp_cls = build_inp_cls();
+                cur = ggml_get_rows(ctx0, inp, inp_cls);
+            } break;
+        case LLAMA_POOLING_TYPE_RANK:
+            {
+                ggml_tensor * inp_cls = build_inp_cls();
+                inp = ggml_get_rows(ctx0, inp, inp_cls);
+
+                // classification head
+                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                GGML_ASSERT(cls   != nullptr);
+                GGML_ASSERT(cls_b != nullptr);
+
+                cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
+                cur = ggml_tanh(ctx0, cur);
+
+                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                if (cls_out) {
+                    GGML_ASSERT(cls_out_b != nullptr);
+
+                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("unknown pooling type");
+            }
+    }
+
+    cb(cur, "result_embd_pooled", -1);
+    res->t_embd_pooled = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
diff --git a/src/llama-graph.h b/src/llama-graph.h
new file mode 100644 (file)
index 0000000..b7a66d1
--- /dev/null
@@ -0,0 +1,576 @@
+#pragma once
+
+#include "llama-arch.h"
+#include "llama-hparams.h"
+#include "llama-adapter.h"
+
+#include <cstdint>
+#include <vector>
+#include <memory>
+#include <set>
+#include <functional>
+
+struct ggml_cgraph;
+struct ggml_context;
+struct ggml_tensor;
+
+struct llama_ubatch;
+struct llama_cparams;
+
+class llama_memory_i;
+class llama_kv_cache_unified;
+
+// certain models (typically multi-modal) can produce different types of graphs
+enum llm_graph_type {
+    LLM_GRAPH_TYPE_DEFAULT,
+    LLM_GRAPH_TYPE_ENCODER,
+    LLM_GRAPH_TYPE_DECODER,
+};
+
+enum llm_ffn_op_type {
+    LLM_FFN_SILU,
+    LLM_FFN_GELU,
+    LLM_FFN_RELU,
+    LLM_FFN_RELU_SQR,
+    LLM_FFN_SWIGLU,
+};
+
+enum llm_ffn_gate_type {
+    LLM_FFN_SEQ,
+    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
+};
+
+enum llm_norm_type {
+    LLM_NORM,
+    LLM_NORM_RMS,
+    LLM_NORM_GROUP,
+};
+
+// TODO: tmp - need something better to pass the data from the encoder to the decoder
+struct llama_cross {
+    // the output embeddings from the encoder as a ggml tensor
+    // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
+    //ggml_tensor * t_embd = nullptr;
+
+    int64_t n_embd = 0;
+    int64_t n_enc  = 0;
+
+    // embeddings data copied to host memory (tmp)
+    std::vector<float> v_embd;
+
+    // needed to construct the cross-attention mask in the decoder
+    std::vector<std::set<llama_seq_id>> seq_ids_enc;
+};
+
+//
+// llm_graph_input
+//
+
+class llm_graph_input_i {
+public:
+    virtual ~llm_graph_input_i() = default;
+
+    virtual void set_input(const llama_ubatch * ubatch) = 0;
+};
+
+using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
+
+
+class llm_graph_input_embd : public llm_graph_input_i {
+public:
+    llm_graph_input_embd()          = default;
+    virtual ~llm_graph_input_embd() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * tokens = nullptr; // I32 [n_batch]
+    ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
+};
+
+class llm_graph_input_pos : public llm_graph_input_i {
+public:
+    llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
+    virtual ~llm_graph_input_pos() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * pos = nullptr; // I32 [n_batch]
+
+    const int64_t n_pos_per_token = 1;
+};
+
+class llm_graph_input_pos_bucket : public llm_graph_input_i {
+public:
+    llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
+    virtual ~llm_graph_input_pos_bucket() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
+
+    const llama_hparams & hparams;
+};
+
+class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
+public:
+    llm_graph_input_pos_bucket_kv(
+            const llama_hparams & hparams,
+            const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
+    virtual ~llm_graph_input_pos_bucket_kv() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_out_ids : public llm_graph_input_i {
+public:
+    llm_graph_input_out_ids(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
+    virtual ~llm_graph_input_out_ids() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * out_ids; // I32 [n_outputs]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const int32_t n_outputs;
+};
+
+class llm_graph_input_mean : public llm_graph_input_i {
+public:
+    llm_graph_input_mean(const llama_cparams & cparams) : cparams(cparams) {}
+    virtual ~llm_graph_input_mean() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * mean; // F32 [n_batch, n_batch]
+
+    const llama_cparams & cparams;
+};
+
+class llm_graph_input_cls : public llm_graph_input_i {
+public:
+    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    virtual ~llm_graph_input_cls() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * cls; // I32 [n_batch]
+
+    const llama_cparams & cparams;
+};
+
+class llm_graph_input_s_copy : public llm_graph_input_i {
+public:
+    llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_s_copy() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_s_mask : public llm_graph_input_i {
+public:
+    llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_s_mask() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_mask; // F32 [1, n_kv]
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_cross_embd : public llm_graph_input_i {
+public:
+    llm_graph_input_cross_embd(
+            const llama_cross * cross) : cross(cross) {}
+    virtual ~llm_graph_input_cross_embd() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
+
+    const llama_cross * cross;
+};
+
+class llm_graph_input_attn_no_cache : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
+        hparams(hparams),
+        cparams(cparams) {
+    }
+    ~llm_graph_input_attn_no_cache() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
+
+    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch]
+    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+};
+
+class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
+    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
+
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_cross : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
+    ~llm_graph_input_attn_cross() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
+
+    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch]
+    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+
+    const llama_cross * cross = nullptr;
+};
+
+//
+// llm_graph_result
+//
+
+// these objects deliver the result from the graph build process back to the llama_context
+// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
+//   specific data by calling the set_inputs() method
+// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
+//   these are used by the llama_context to extract the relevant data, based on the compute parameters
+
+class llm_graph_result_i {
+public:
+    virtual ~llm_graph_result_i() = default;
+
+    virtual ggml_tensor * get_logits()      = 0;
+    virtual ggml_tensor * get_embd()        = 0;
+    virtual ggml_tensor * get_embd_pooled() = 0;
+
+    virtual void set_inputs(const llama_ubatch * ubatch) = 0;
+};
+
+using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
+
+
+class llm_graph_result : public llm_graph_result_i {
+public:
+    virtual ~llm_graph_result() = default;
+
+    ggml_tensor * get_logits()      override { return t_logits; }
+    ggml_tensor * get_embd()        override { return t_embd; }
+    ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
+
+    void set_inputs(const llama_ubatch * ubatch) override {
+        for (auto & input : inputs) {
+            input->set_input(ubatch);
+        }
+    }
+
+    llm_graph_input_i * add_input(llm_graph_input_ptr input) {
+        inputs.emplace_back(std::move(input));
+        return inputs.back().get();
+    }
+
+    // important graph nodes
+    ggml_tensor * t_logits      = nullptr;
+    ggml_tensor * t_embd        = nullptr;
+    ggml_tensor * t_embd_pooled = nullptr;
+
+    std::vector<llm_graph_input_ptr> inputs;
+};
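
Illustrative sketch (not part of this diff): how a consumer such as llama_context is expected to use a llm_graph_result - populate the referenced input tensors via set_inputs() and read the common output tensors back once the graph has been computed. The function name is hypothetical and the compute step is elided.

static void toy_consume_result(llm_graph_result_i * res, const llama_ubatch & ubatch) {
    // fill the data of all input tensors referenced by the result for this ubatch
    res->set_inputs(&ubatch);

    // ... schedule + compute the graph here (e.g. via ggml_backend_sched) ...

    // read back the commonly used outputs, if the graph produced them
    ggml_tensor * t_logits = res->get_logits();
    ggml_tensor * t_embd   = res->get_embd();

    GGML_UNUSED(t_logits);
    GGML_UNUSED(t_embd);
}
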
+
+//
+// llm_graph_context
+//
+
+// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
+using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
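
A minimal sketch of such a callback (illustrative only; the callback actually installed by llama_context also handles backend assignment / offloading):

llm_graph_cb toy_cb = [](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
    GGML_UNUSED(ubatch);

    // just name the tensor, optionally suffixed with the layer index
    if (il >= 0) {
        ggml_format_name(cur, "%s-%d", name, il);
    } else {
        ggml_set_name(cur, name);
    }
};
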
+
+struct llm_graph_params {
+    ggml_context * ctx;
+
+    const llm_arch arch;
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+    const llama_ubatch  & ubatch;
+
+    ggml_backend_sched * sched;
+    ggml_backend * backend_cpu;
+
+    const llama_adapter_cvec  * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_i      * memory;
+    const llama_cross         * cross;
+
+    int32_t n_outputs;
+
+    const llm_graph_cb & cb;
+};
+
+struct llm_graph_context {
+    const llm_arch arch;
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+    const llama_ubatch  & ubatch;
+
+    const int64_t n_embd;
+    const int64_t n_layer;
+    const int64_t n_rot;
+    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
+    const int64_t n_ctx_per_seq;
+    const int64_t n_head;
+    const int64_t n_head_kv;
+    const int64_t n_embd_head_k;
+    const int64_t n_embd_k_gqa;
+    const int64_t n_embd_head_v;
+    const int64_t n_embd_v_gqa;
+    const int64_t n_expert;
+    const int64_t n_expert_used;
+
+    const float freq_base;
+    const float freq_scale;
+    const float ext_factor;
+    const float attn_factor;
+    const float beta_fast;
+    const float beta_slow;
+    const float norm_eps;
+    const float norm_rms_eps;
+
+    const int32_t n_tokens;
+    const int32_t n_outputs;
+    const int32_t n_ctx_orig; // yarn
+
+    const enum llama_pooling_type pooling_type;
+    const enum llama_rope_type    rope_type;
+
+    ggml_context * ctx0 = nullptr;
+
+    ggml_backend_sched * sched;
+
+    ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+
+    const llama_adapter_cvec  * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_i      * memory;
+    const llama_cross         * cross;
+
+    const llm_graph_cb & cb_func;
+
+    std::unique_ptr<llm_graph_result> res;
+
+    llm_graph_context(const llm_graph_params & params);
+
+    int64_t n_pos_per_token() const;
+
+    void cb(ggml_tensor * cur, const char * name, int il) const;
+
+    //
+    // common
+    //
+
+    ggml_tensor * build_cvec(
+             ggml_tensor * cur,
+                     int   il) const;
+
+    // do mat_mul, optionally applying lora
+    ggml_tensor * build_lora_mm(
+              ggml_tensor * w,
+              ggml_tensor * cur) const;
+
+    // do mat_mul_id, optionally applying lora
+    ggml_tensor * build_lora_mm_id(
+              ggml_tensor * w,   // ggml_tensor * as
+              ggml_tensor * cur, // ggml_tensor * b
+              ggml_tensor * ids) const;
+
+    ggml_tensor * build_norm(
+             ggml_tensor * cur,
+             ggml_tensor * mw,
+             ggml_tensor * mb,
+           llm_norm_type   type,
+                     int   il) const;
+
+    ggml_tensor * build_ffn(
+             ggml_tensor * cur,
+             ggml_tensor * up,
+             ggml_tensor * up_b,
+             ggml_tensor * up_s,
+             ggml_tensor * gate,
+             ggml_tensor * gate_b,
+             ggml_tensor * gate_s,
+             ggml_tensor * down,
+             ggml_tensor * down_b,
+             ggml_tensor * down_s,
+             ggml_tensor * act_scales,
+         llm_ffn_op_type   type_op,
+       llm_ffn_gate_type   type_gate,
+                     int   il) const;
+
+    ggml_tensor * build_moe_ffn(
+             ggml_tensor * cur,
+             ggml_tensor * gate_inp,
+             ggml_tensor * up_exps,
+             ggml_tensor * gate_exps,
+             ggml_tensor * down_exps,
+             ggml_tensor * exp_probs_b,
+                 int64_t   n_expert,
+                 int64_t   n_expert_used,
+         llm_ffn_op_type   type_op,
+                    bool   norm_w,
+                    bool   scale_w,
+                   float   w_scale,
+            llama_expert_gating_func_type gating_op,
+                     int   il) const;
+
+    //
+    // inputs
+    //
+
+    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
+    ggml_tensor * build_inp_pos() const;
+    ggml_tensor * build_inp_out_ids() const;
+    ggml_tensor * build_inp_mean() const;
+    ggml_tensor * build_inp_cls() const;
+    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_mask() const;
+
+    ggml_tensor * build_inp_cross_embd() const;
+    ggml_tensor * build_inp_pos_bucket_enc() const;
+    ggml_tensor * build_inp_pos_bucket_dec() const;
+    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
+
+    //
+    // attention
+    //
+
+    ggml_tensor * build_attn_mha(
+             ggml_cgraph * gf,
+             ggml_tensor * q,
+             ggml_tensor * k,
+             ggml_tensor * v,
+             ggml_tensor * kq_b,
+             ggml_tensor * kq_mask,
+                    bool   v_trans,
+                   float   kq_scale) const;
+
+    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_no_cache * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_b,
+                  float   kq_scale,
+                    int   il) const;
+
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
+            bool causal,
+            bool swa) const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_b,
+                  float   kq_scale,
+                    int   il) const;
+
+    llm_graph_input_attn_cross * build_attn_inp_cross() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_cross * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_b,
+                  float   kq_scale,
+                    int   il) const;
+
+    //
+    // recurrent
+    //
+
+    ggml_tensor * build_copy_mask_state(
+             ggml_cgraph * gf,
+             ggml_tensor * s,
+             ggml_tensor * state_copy,
+             ggml_tensor * state_mask,
+                 int32_t   n_state,
+                 int32_t   n_seqs) const;
+
+    ggml_tensor * build_rwkv_token_shift_load(
+             ggml_cgraph * gf,
+             ggml_tensor * state_copy,
+             ggml_tensor * state_mask,
+      const llama_ubatch & ubatch,
+                     int   il) const;
+
+    ggml_tensor * build_rwkv_token_shift_store(
+             ggml_tensor * token_shift,
+      const llama_ubatch & ubatch,
+                     int   il) const;
+
+    //
+    // pooling
+    //
+
+    void build_pooling(
+            ggml_cgraph * gf,
+            ggml_tensor * cls,
+            ggml_tensor * cls_b,
+            ggml_tensor * cls_out,
+            ggml_tensor * cls_out_b) const;
+};
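
To make the interface above more concrete, here is a hedged sketch of a per-layer feed-forward build step using a few of the llm_graph_context helpers. It assumes the usual llama.cpp layer tensor names (ffn_norm, ffn_up, ffn_gate, ffn_down) and a SwiGLU-style FFN; the function itself is hypothetical and not taken from the commit.

ggml_tensor * toy_build_ffn_layer(const llm_graph_context & g, const llama_model & model, ggml_tensor * cur, int il) {
    // RMS norm with the layer's ffn_norm weight (no bias)
    cur = g.build_norm(cur, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
    g.cb(cur, "ffn_norm", il);

    // parallel (SwiGLU) FFN; optional bias/scale tensors are passed as nullptr
    cur = g.build_ffn(cur,
            model.layers[il].ffn_up,   nullptr, nullptr,
            model.layers[il].ffn_gate, nullptr, nullptr,
            model.layers[il].ffn_down, nullptr, nullptr,
            nullptr,
            LLM_FFN_SILU, LLM_FFN_PAR, il);
    g.cb(cur, "ffn_out", il);

    return cur;
}
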
diff --git a/src/llama-io.cpp b/src/llama-io.cpp
new file mode 100644 (file)
index 0000000..7ad70d1
--- /dev/null
@@ -0,0 +1,15 @@
+#include "llama-io.h"
+
+void llama_io_write_i::write_string(const std::string & str) {
+    uint32_t str_size = str.size();
+
+    write(&str_size,  sizeof(str_size));
+    write(str.data(), str_size);
+}
+
+void llama_io_read_i::read_string(std::string & str) {
+    uint32_t str_size;
+    read_to(&str_size, sizeof(str_size));
+
+    str.assign((const char *) read(str_size), str_size);
+}
diff --git a/src/llama-io.h b/src/llama-io.h
new file mode 100644 (file)
index 0000000..ce9216b
--- /dev/null
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+struct ggml_tensor;
+
+class llama_io_write_i {
+public:
+    llama_io_write_i() = default;
+    virtual ~llama_io_write_i() = default;
+
+    virtual void write(const void * src, size_t size) = 0;
+    virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
+
+    // bytes written so far
+    virtual size_t n_bytes() = 0;
+
+    void write_string(const std::string & str);
+};
+
+class llama_io_read_i {
+public:
+    llama_io_read_i() = default;
+    virtual ~llama_io_read_i() = default;
+
+    virtual const uint8_t * read(size_t size) = 0;
+    virtual void read_to(void * dst, size_t size) = 0;
+
+    // bytes read so far
+    virtual size_t n_bytes() = 0;
+
+    void read_string(std::string & str);
+};
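
For reference, a minimal in-memory implementation of the write interface (illustrative, not part of this commit): it shows how write_string() above composes with write(), and uses ggml_backend_tensor_get() to realize write_tensor(). The struct name is hypothetical, and <vector> plus ggml-backend.h are assumed to be available.

struct toy_io_write_buffer : public llama_io_write_i {
    std::vector<uint8_t> data;

    void write(const void * src, size_t size) override {
        const uint8_t * p = (const uint8_t *) src;
        data.insert(data.end(), p, p + size);
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        // copy the requested tensor range to host memory, then append it to the buffer
        std::vector<uint8_t> tmp(size);
        ggml_backend_tensor_get(tensor, tmp.data(), offset, size);
        write(tmp.data(), size);
    }

    size_t n_bytes() override {
        return data.size();
    }
};

A matching reader would implement llama_io_read_i over the same byte layout, so that read_string() recovers exactly what write_string() produced: a uint32_t length followed by the raw bytes.
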
index feffdf0de52cf641f414c2c0da3659d849eb7dda..14c8933b4d6c4f65d0ad0818b2443ad1c676e249 100644 (file)
@@ -6,86 +6,92 @@
 #include "llama-model.h"
 
 #include <algorithm>
+#include <cassert>
 #include <limits>
 #include <map>
+#include <stdexcept>
 
 static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
 
-uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
+llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
-bool llama_kv_cache_init(
-             struct llama_kv_cache & cache,
-                 const llama_model & model,
-               const llama_cparams & cparams,
-                         ggml_type   type_k,
-                         ggml_type   type_v,
-                          uint32_t   kv_size,
-                              bool   offload) {
-    const struct llama_hparams & hparams = model.hparams;
-
+bool llama_kv_cache_unified::init(
+        const llama_model & model,
+      const llama_cparams & cparams,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                 uint32_t   kv_size,
+                     bool   offload) {
     const int32_t n_layer = hparams.n_layer;
 
-    cache.has_shift = false;
+    has_shift = false;
 
-    cache.recurrent = llama_model_is_recurrent(&model);
-    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
-    cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+    recurrent = llama_model_is_recurrent(&model);
+    v_trans   = !recurrent && !cparams.flash_attn;
+    can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
 
     LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
-            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift);
+            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
 
-    cache.head = 0;
-    cache.size = kv_size;
-    cache.used = 0;
+    head = 0;
+    size = kv_size;
+    used = 0;
 
-    cache.type_k = type_k;
-    cache.type_v = type_v;
+    this->type_k = type_k;
+    this->type_v = type_v;
 
-    cache.cells.clear();
-    cache.cells.resize(kv_size);
+    cells.clear();
+    cells.resize(kv_size);
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
-            struct ggml_init_params params = {
+            ggml_init_params params = {
                 /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
+
             ggml_context * ctx = ggml_init(params);
             if (!ctx) {
                 return nullptr;
             }
+
             ctx_map[buft] = ctx;
-            cache.ctxs.emplace_back(ctx);
+            ctxs.emplace_back(ctx);
+
             return ctx;
         }
+
         return it->second;
     };
 
-    cache.k_l.reserve(n_layer);
-    cache.v_l.reserve(n_layer);
+    k_l.reserve(n_layer);
+    v_l.reserve(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
-        LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa);
+        const char * dev_name = "CPU";
 
         ggml_backend_buffer_type_t buft;
         if (offload) {
             auto * dev = model.dev_layer(i);
             buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
         } else {
             buft = ggml_backend_cpu_buffer_type();
         }
-        ggml_context * ctx = ctx_for_buft(buft);
 
+        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
+                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+
+        ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
             LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
             return false;
@@ -95,8 +101,8 @@ bool llama_kv_cache_init(
         ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
-        cache.k_l.push_back(k);
-        cache.v_l.push_back(v);
+        k_l.push_back(k);
+        v_l.push_back(v);
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -111,20 +117,346 @@ bool llama_kv_cache_init(
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        cache.bufs.emplace_back(buf);
+        bufs.emplace_back(buf);
     }
 
     return true;
 }
 
-struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
-           struct llama_kv_cache & cache,
-       const struct llama_ubatch & ubatch) {
+int32_t llama_kv_cache_unified::get_n_tokens() const {
+    int32_t result = 0;
+
+    for (uint32_t i = 0; i < size; i++) {
+        result += cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+uint32_t llama_kv_cache_unified::get_used_cells() const {
+    return used;
+}
+
+size_t llama_kv_cache_unified::total_size() const {
+    size_t size = 0;
+    for (const auto & buf : bufs) {
+        size += ggml_backend_buffer_get_size(buf.get());
+    }
+
+    return size;
+}
+
+llama_pos llama_kv_cache_unified::pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cell : cells) {
+        pos_max = std::max(pos_max, cell.pos);
+    }
+
+    return pos_max;
+}
+
+void llama_kv_cache_unified::clear() {
+    for (int32_t i = 0; i < (int32_t) size; ++i) {
+        cells[i].pos = -1;
+        cells[i].seq_id.clear();
+        cells[i].src = -1;
+        cells[i].tail = -1;
+    }
+    head = 0;
+    used = 0;
+
+    for (auto & buf : bufs) {
+        ggml_backend_buffer_clear(buf.get(), 0);
+    }
+}
+
+bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    uint32_t new_head = size;
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // models like Mamba or RWKV can't have a state partially erased
+    if (recurrent) {
+        if (seq_id >= (int64_t) size) {
+            // could be fatal
+            return false;
+        }
+        if (0 <= seq_id) {
+            int32_t & tail_id = cells[seq_id].tail;
+            if (tail_id >= 0) {
+                const llama_kv_cell & cell = cells[tail_id];
+                // partial intersection is invalid
+                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                    return false;
+                }
+                // invalidate tails which will be cleared
+                if (p0 <= cell.pos && cell.pos < p1) {
+                    tail_id = -1;
+                }
+            }
+        } else {
+            // if seq_id is negative, the range should include everything or nothing
+            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+                return false;
+            }
+        }
+    }
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].pos >= p0 && cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cells[i].seq_id.clear();
+            } else if (cells[i].has_seq_id(seq_id)) {
+                cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cells[i].is_empty()) {
+                // keep count of the number of used cells
+                if (cells[i].pos >= 0) {
+                    used--;
+                }
+
+                cells[i].pos = -1;
+                cells[i].src = -1;
+
+                if (new_head == size) {
+                    new_head = i;
+                }
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+
+    return true;
+}
+
+void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    if (seq_id_src == seq_id_dst) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    if (recurrent) {
+        if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
+            llama_kv_cell & tail_src = cells[seq_id_src];
+            llama_kv_cell & tail_dst = cells[seq_id_dst];
+            if (tail_dst.tail >= 0) {
+                // clear destination seq_id if it wasn't empty
+                llama_kv_cell & cell_dst = cells[tail_dst.tail];
+
+                cell_dst.seq_id.erase(seq_id_dst);
+                tail_dst.tail = -1;
+                if (cell_dst.seq_id.empty()) {
+                    cell_dst.pos = -1;
+                    cell_dst.delta = -1;
+                    cell_dst.src = -1;
+                    used -= 1;
+                }
+            }
+            if (tail_src.tail >= 0) {
+                llama_kv_cell & cell_src = cells[tail_src.tail];
+
+                cell_src.seq_id.insert(seq_id_dst);
+                tail_dst.tail = tail_src.tail;
+            }
+        }
+
+        return;
+    }
+
+    // otherwise, this is the KV of a Transformer-like model
+    head = 0;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
+            cells[i].seq_id.insert(seq_id_dst);
+        }
+    }
+}
+
+void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+    uint32_t new_head = size;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (recurrent && (llama_seq_id) i != seq_id) {
+            cells[i].tail = -1;
+        }
+
+        if (!cells[i].has_seq_id(seq_id)) {
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+            cells[i].seq_id.clear();
+
+            if (new_head == size){
+                new_head = i;
+            }
+        } else {
+            cells[i].seq_id.clear();
+            cells[i].seq_id.insert(seq_id);
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    if (delta == 0) {
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    if (recurrent) {
+        // for Mamba-like or RWKV models, only the pos needs to be shifted
+        if (0 <= seq_id && seq_id < (int64_t) size) {
+            const int32_t tail_id = cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos += delta;
+                }
+            }
+        }
+        return;
+    }
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
+            has_shift = true;
+            cells[i].pos   += delta;
+            cells[i].delta += delta;
+
+            if (cells[i].pos < 0) {
+                if (!cells[i].is_empty()) {
+                    used--;
+                }
+                cells[i].pos = -1;
+                cells[i].seq_id.clear();
+                if (new_head == size) {
+                    new_head = i;
+                }
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    head = new_head != size ? new_head : 0;
+}
+
+void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    if (d == 1) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    if (recurrent) {
+        // for Mamba-like or RWKV models, only the pos needs to be changed
+        if (0 <= seq_id && seq_id < (int64_t) size) {
+            const int32_t tail_id = cells[seq_id].tail;
+            if (tail_id >= 0) {
+                llama_kv_cell & cell = cells[tail_id];
+                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                    cell.pos /= d;
+                }
+            }
+        }
+
+        return;
+    }
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
+            has_shift = true;
+
+            {
+                llama_pos p_old = cells[i].pos;
+                cells[i].pos   /= d;
+                cells[i].delta += cells[i].pos - p_old;
+            }
+        }
+    }
+}
+
+llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
+    llama_pos result = 0;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id)) {
+            result = std::max(result, cells[i].pos);
+        }
+    }
+
+    return result;
+}
+
+void llama_kv_cache_unified::defrag() {
+    if (!recurrent) {
+        do_defrag = true;
+    }
+}
+
+bool llama_kv_cache_unified::get_can_shift() const {
+    return can_shift;
+}
+
+llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+       const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
-    if (cache.recurrent) {
+    if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
         // A slot should always be contiguous.
@@ -132,7 +464,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // can only process batches with an equal number of new tokens in each sequence
         GGML_ASSERT(ubatch.equal_seqs);
 
-        int32_t min = cache.size - 1;
+        int32_t min = size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
@@ -141,16 +473,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
             for (uint32_t j = 0; j < n_seq_id; ++j) {
                 const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
-                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
+                if (seq_id < 0 || (uint32_t) seq_id >= size) {
                     // too big seq_id
                     // TODO: would it be possible to resize the cache instead?
-                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
+                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
                     return llama_kv_cache_slot_info_failed;
                 }
                 if (j > 0) {
-                    llama_kv_cell & seq = cache.cells[seq_id];
+                    llama_kv_cell & seq = cells[seq_id];
                     if (seq.tail >= 0) {
-                        llama_kv_cell & cell = cache.cells[seq.tail];
+                        llama_kv_cell & cell = cells[seq.tail];
                         // clear cells from seq_ids that become shared
                         // (should not normally happen, but let's handle it anyway)
                         cell.seq_id.erase(seq_id);
@@ -158,7 +490,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
                         if (cell.seq_id.empty()) {
                             cell.pos = -1;
                             cell.src = -1;
-                            cache.used -= 1;
+                            used -= 1;
                         }
                     }
                 }
@@ -168,9 +500,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 #ifndef NDEBUG
         {
             std::vector<int32_t> tails_verif;
-            tails_verif.assign(cache.size, -1);
-            for (uint32_t i = 0; i < cache.size; ++i) {
-                llama_kv_cell & cell = cache.cells[i];
+            tails_verif.assign(size, -1);
+            for (uint32_t i = 0; i < size; ++i) {
+                llama_kv_cell & cell = cells[i];
                 for (llama_seq_id seq_id : cell.seq_id) {
                     if (tails_verif[seq_id] != -1) {
                         LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -178,20 +510,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
                     tails_verif[seq_id] = i;
                 }
             }
-            for (uint32_t i = 0; i < cache.size; ++i) {
-                if (tails_verif[i] != cache.cells[i].tail) {
-                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
+            for (uint32_t i = 0; i < size; ++i) {
+                if (tails_verif[i] != cells[i].tail) {
+                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
                 }
             }
         }
 #endif
 
         // find next empty cell
-        uint32_t next_empty_cell = cache.head;
+        uint32_t next_empty_cell = head;
 
-        for (uint32_t i = 0; i < cache.size; ++i) {
-            if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
-            llama_kv_cell & cell = cache.cells[next_empty_cell];
+        for (uint32_t i = 0; i < size; ++i) {
+            if (next_empty_cell >= size) { next_empty_cell -= size; }
+            llama_kv_cell & cell = cells[next_empty_cell];
             if (cell.is_empty()) { break; }
             next_empty_cell += 1;
         }
@@ -199,20 +531,20 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch.seq_id[s][0];
-            llama_kv_cell & seq_meta = cache.cells[seq_id];
+            llama_kv_cell & seq_meta = cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
-                llama_kv_cell & cell = cache.cells[seq_meta.tail];
+                llama_kv_cell & cell = cells[seq_meta.tail];
                 GGML_ASSERT(cell.has_seq_id(seq_id));
                 // does this seq_id "own" the cell?
                 if (cell.seq_id.size() == 1) { has_cell = true; }
             }
             if (!has_cell) {
-                llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
+                llama_kv_cell & empty_cell = cells[next_empty_cell];
                 GGML_ASSERT(empty_cell.is_empty());
                 // copy old tail into the empty cell
                 if (seq_meta.tail >= 0) {
-                    llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
+                    llama_kv_cell & orig_cell = cells[seq_meta.tail];
                     empty_cell.pos = orig_cell.pos;
                     empty_cell.src = orig_cell.src;
                     orig_cell.seq_id.erase(seq_id);
@@ -222,9 +554,9 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
                 // find next empty cell
                 if (s + 1 < n_seqs) {
                     next_empty_cell += 1;
-                    for (uint32_t i = 0; i < cache.size; ++i) {
-                        if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
-                        llama_kv_cell & cell = cache.cells[next_empty_cell];
+                    for (uint32_t i = 0; i < size; ++i) {
+                        if (next_empty_cell >= size) { next_empty_cell -= size; }
+                        llama_kv_cell & cell = cells[next_empty_cell];
                         if (cell.is_empty()) { break; }
                         next_empty_cell += 1;
                     }
@@ -237,10 +569,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
+            int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
-                llama_kv_cell & dst_cell = cache.cells[dst_id];
-                llama_kv_cell & src_cell = cache.cells[src_id];
+                llama_kv_cell & dst_cell = cells[dst_id];
+                llama_kv_cell & src_cell = cells[src_id];
 
                 std::swap(dst_cell.pos, src_cell.pos);
                 std::swap(dst_cell.src, src_cell.src);
@@ -248,10 +580,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
                 // swap tails (assuming they NEVER overlap)
                 for (const llama_seq_id seq_id : src_cell.seq_id) {
-                    cache.cells[seq_id].tail = src_id;
+                    cells[seq_id].tail = src_id;
                 }
                 for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                    cache.cells[seq_id].tail = dst_id;
+                    cells[seq_id].tail = dst_id;
                 }
             }
         }
@@ -260,7 +592,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
-            llama_kv_cell & cell = cache.cells[cell_id];
+            llama_kv_cell & cell = cells[cell_id];
 
             if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
                 // What should happen when the pos backtracks or skips a value?
@@ -273,41 +605,42 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
             for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
                 const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
-                cache.cells[seq_id].tail = cell_id;
+                cells[seq_id].tail = cell_id;
             }
         }
 
         // allow getting the range of used cells, from head to head + n
-        cache.head = min;
-        cache.n    = max - min + 1;
-        cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
+        head = min;
+        n    = max - min + 1;
+        used = std::count_if(cells.begin(), cells.end(),
             [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return llama_kv_cache_slot_info(cache.n >= n_seqs);
+        return llama_kv_cache_slot_info(n >= n_seqs);
     }
+
     // otherwise, one cell per token.
 
-    if (n_tokens > cache.size) {
-        LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
+    if (n_tokens > size) {
+        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
         return llama_kv_cache_slot_info_failed;
     }
 
     uint32_t n_tested = 0;
 
     while (true) {
-        if (cache.head + n_tokens > cache.size) {
-            n_tested += cache.size - cache.head;
-            cache.head = 0;
+        if (head + n_tokens > size) {
+            n_tested += size - head;
+            head = 0;
             continue;
         }
 
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            if (cache.cells[cache.head + i].pos >= 0) {
+            if (cells[head + i].pos >= 0) {
                 found = false;
-                cache.head += i + 1;
-                n_tested   += i + 1;
+                head     += i + 1;
+                n_tested += i + 1;
                 break;
             }
         }
@@ -316,7 +649,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
             break;
         }
 
-        if (n_tested >= cache.size) {
+        if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
             return llama_kv_cache_slot_info_failed;
         }
@@ -325,22 +658,27 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = ubatch.pos[k];
+            cells[head + k].pos = ubatch.pos[k];
 
             for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
+                cells[head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
 
-    cache.used += n_tokens;
+    used += n_tokens;
 
-    return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
+    return llama_kv_cache_slot_info(head, head + n_tokens);
 }
 
-uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size; i > 0; --i) {
-        const llama_kv_cell & cell = cache.cells[i - 1];
+uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
+}
+
+uint32_t llama_kv_cache_unified::cell_max() const {
+    for (uint32_t i = size; i > 0; --i) {
+        const llama_kv_cell & cell = cells[i - 1];
 
         if (cell.pos >= 0 && !cell.is_empty()) {
             return i;
@@ -350,289 +688,659 @@ uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
     return 0;
 }
 
-void llama_kv_cache_clear(struct llama_kv_cache & cache) {
-    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
-        cache.cells[i].pos = -1;
-        cache.cells[i].seq_id.clear();
-        cache.cells[i].src = -1;
-        cache.cells[i].tail = -1;
+size_t llama_kv_cache_unified::size_k_bytes() const {
+    size_t size_k_bytes = 0;
+
+    for (const auto & k : k_l) {
+        size_k_bytes += ggml_nbytes(k);
     }
-    cache.head = 0;
-    cache.used = 0;
 
-    for (auto & buf : cache.bufs) {
-        ggml_backend_buffer_clear(buf.get(), 0);
+    return size_k_bytes;
+}
+
+size_t llama_kv_cache_unified::size_v_bytes() const {
+    size_t size_v_bytes = 0;
+
+    for (const auto & v : v_l) {
+        size_v_bytes += ggml_nbytes(v);
     }
+
+    return size_v_bytes;
 }
 
-bool llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1) {
-    uint32_t new_head = cache.size;
+bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+    const uint32_t n_layer = hparams.n_layer;
 
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+    const uint32_t n_kv   = cell_max();
+    const uint32_t n_used = used;
 
-    // models like Mamba or RWKV can't have a state partially erased
-    if (cache.recurrent) {
-        if (seq_id >= (int64_t) cache.size) {
-            // could be fatal
-            return false;
+    assert(n_used <= n_kv);
+
+    //const int64_t t_start = ggml_time_us();
+
+    // number of cells moved
+    uint32_t n_moves = 0;
+
+    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
+    //   - source view, destination view, copy operation
+    //   - x2 for keys and values
+    //const uint32_t max_moves = max_nodes()/(6*n_layer);
+    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
+    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+
+    // determine which KV cells to move where
+    //
+    //  cell i moves to ids[i]
+    //
+    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
+    //
+    auto & ids = defrag_info.ids;
+
+    ids.clear();
+    ids.resize(n_kv, n_kv);
+
+    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
+        const auto & cell0 = cells[i0];
+
+        if (!cell0.is_empty()) {
+            ids[i0] = i0;
+
+            continue;
         }
-        if (0 <= seq_id) {
-            int32_t & tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                const llama_kv_cell & cell = cache.cells[tail_id];
-                // partial intersection is invalid
-                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                    return false;
-                }
-                // invalidate tails which will be cleared
-                if (p0 <= cell.pos && cell.pos < p1) {
-                    tail_id = -1;
-                }
+
+        // found a hole - fill it with data from the end of the cache
+
+        uint32_t nh = 1;
+
+        // determine the size of the hole
+        while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
+            nh++;
+        }
+
+        uint32_t nf = 0;
+        uint32_t is = n_kv - 1;
+
+        // starting from the end, find nh non-empty cells
+        for (; is > i0; --is) {
+            const auto & cell1 = cells[is];
+
+            if (cell1.is_empty() || ids[is] != n_kv) {
+                continue;
             }
-        } else {
-            // seq_id is negative, then the range should include everything or nothing
-            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
-                return false;
+
+            // non-empty cell which is not yet moved
+            nf++;
+
+            if (nf == nh) {
+                break;
             }
         }
-    }
 
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            if (seq_id < 0) {
-                cache.cells[i].seq_id.clear();
-            } else if (cache.cells[i].has_seq_id(seq_id)) {
-                cache.cells[i].seq_id.erase(seq_id);
-            } else {
+        // this can only happen if `n_used` is not accurate, which would be a bug
+        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
+
+        nf = 0;
+
+        uint32_t i1 = is;
+
+        // are we moving a contiguous block of memory?
+        bool cont = false;
+
+        // should we stop searching for the next move?
+        bool stop = false;
+
+        // go back and move the nf cells to the hole
+        for (; i1 < n_kv; ++i1) {
+            auto & cell1 = cells[i1];
+
+            if (cell1.is_empty() || ids[i1] != n_kv) {
+                if (n_moves == max_moves) {
+                    stop = true;
+                    break;
+                }
+
+                cont = false;
                 continue;
             }
-            if (cache.cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cache.cells[i].pos >= 0) cache.used--;
 
-                cache.cells[i].pos = -1;
-                cache.cells[i].src = -1;
-                if (new_head == cache.size) new_head = i;
+            // this cell goes to (i0 + nf)
+            ids[i1] = i0 + nf;
+
+            // move the cell meta data
+            cells[i0 + nf] = cell1;
+
+            // clear the old cell and move the head there
+            cell1 = llama_kv_cell();
+            head = n_used;
+
+            if (!cont) {
+                n_moves++;
+                cont = true;
+            }
+
+            nf++;
+
+            if (nf == nh) {
+                break;
             }
         }
+
+        if (stop || n_moves == max_moves) {
+            break;
+        }
+
+        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
+
+        i0 += nh - 1;
     }
 
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
+    if (n_moves == 0) {
+        return false;
+    }
+
+    LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
+
+    LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer);
 
     return true;
 }
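
To illustrate the ids mapping computed by defrag_prepare() above (toy example, not part of the commit): with 6 cells where cells 2 and 3 are holes, the last two used cells are pulled forward. The snippet below applies such a mapping to dummy per-cell data; all names are hypothetical.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_kv = 6;

    // toy state: cells 0,1,4,5 are used, cells 2,3 are holes (-1 marks an empty cell)
    std::vector<int> cell_data = { 10, 11, -1, -1, 14, 15 };

    // mapping defrag_prepare() would produce for this state:
    //   ids[i] == i    -> cell i stays in place
    //   ids[i] == n_kv -> cell i is not a move source
    std::vector<uint32_t> ids = { 0, 1, n_kv, n_kv, 2, 3 };

    for (uint32_t i = 0; i < n_kv; ++i) {
        if (ids[i] != i && ids[i] != n_kv) {
            cell_data[ids[i]] = cell_data[i]; // cell i moves to ids[i]
            cell_data[i] = -1;
        }
    }

    // prints: 10 11 14 15 -1 -1 - the used cells are now contiguous at the front
    for (int v : cell_data) {
        printf("%d ", v);
    }
    printf("\n");

    return 0;
}
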
 
-void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id_src,
-                 llama_seq_id   seq_id_dst,
-                    llama_pos   p0,
-                    llama_pos   p1) {
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
-
-    if (cache.recurrent) {
-        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            llama_kv_cell & tail_src = cache.cells[seq_id_src];
-            llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
-            if (tail_dst.tail >= 0) {
-                // clear destination seq_id if it wasn't empty
-                llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
-
-                cell_dst.seq_id.erase(seq_id_dst);
-                tail_dst.tail = -1;
-                if (cell_dst.seq_id.empty()) {
-                    cell_dst.pos = -1;
-                    cell_dst.delta = -1;
-                    cell_dst.src = -1;
-                    cache.used -= 1;
-                }
+void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
+    uint32_t cell_count = 0;
+
+    // Count the number of cells with the specified seq_id
+    // Find all the ranges of cells with this seq id (or all, when -1)
+    uint32_t cell_range_begin = size;
+    for (uint32_t i = 0; i < size; ++i) {
+        const auto & cell = cells[i];
+        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+            ++cell_count;
+            if (cell_range_begin == size) {
+                cell_range_begin = i;
             }
-            if (tail_src.tail >= 0) {
-                llama_kv_cell & cell_src = cache.cells[tail_src.tail];
-
-                cell_src.seq_id.insert(seq_id_dst);
-                tail_dst.tail = tail_src.tail;
+        } else {
+            if (cell_range_begin != size) {
+                cell_ranges.emplace_back(cell_range_begin, i);
+                cell_range_begin = size;
             }
         }
+    }
+    if (cell_range_begin != size) {
+        cell_ranges.emplace_back(cell_range_begin, size);
+    }
 
-        return;
+    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+    uint32_t cell_count_check = 0;
+    for (const auto & range : cell_ranges) {
+        cell_count_check += range.second - range.first;
     }
-    // otherwise, this is the KV cache of a Transformer-like model
+    GGML_ASSERT(cell_count == cell_count_check);
+
+    io.write(&cell_count, sizeof(cell_count));
+
+    state_write_meta(io, cell_ranges, seq_id);
+    state_write_data(io, cell_ranges);
+}
+
+void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    uint32_t cell_count;
+    io.read_to(&cell_count, sizeof(cell_count));
 
-    cache.head = 0;
+    bool res = true;
+    res = res && state_read_meta(io, cell_count, seq_id);
+    res = res && state_read_data(io, cell_count);
 
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].seq_id.insert(seq_id_dst);
+    if (!res) {
+        if (seq_id == -1) {
+            clear();
+        } else {
+            seq_rm(seq_id, -1, -1);
         }
+        throw std::runtime_error("failed to restore kv cache");
     }
 }
 
-void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
-    uint32_t new_head = cache.size;
+void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+    for (const auto & range : cell_ranges) {
+        for (uint32_t i = range.first; i < range.second; ++i) {
+            const auto & cell = cells[i];
+            const llama_pos pos      = cell.pos;
+            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+
+            io.write(&pos,      sizeof(pos));
+            io.write(&n_seq_id, sizeof(n_seq_id));
 
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.recurrent && (llama_seq_id) i != seq_id) {
-            cache.cells[i].tail = -1;
+            if (n_seq_id) {
+                for (auto seq_id : cell.seq_id) {
+                    io.write(&seq_id, sizeof(seq_id));
+                }
+            }
         }
-        if (!cache.cells[i].has_seq_id(seq_id)) {
-            if (cache.cells[i].pos >= 0) cache.used--;
-            cache.cells[i].pos = -1;
-            cache.cells[i].src = -1;
-            cache.cells[i].seq_id.clear();
-            if (new_head == cache.size) new_head = i;
-        } else {
-            cache.cells[i].seq_id.clear();
-            cache.cells[i].seq_id.insert(seq_id);
+    }
+}
+
+void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+    const uint32_t v_trans = this->v_trans ? 1 : 0;
+    const uint32_t n_layer = hparams.n_layer;
+
+    io.write(&v_trans, sizeof(v_trans));
+    io.write(&n_layer, sizeof(n_layer));
+
+    std::vector<uint8_t> tmp_buf;
+
+    // Iterate and write all the keys first, each row is a cell
+    // Get whole range at a time
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
+        // Write key type
+        const int32_t k_type_i = (int32_t)k_l[il]->type;
+        io.write(&k_type_i, sizeof(k_type_i));
+
+        // Write row size of key
+        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        io.write(&k_size_row, sizeof(k_size_row));
+
+        // For each range of cells, write k_size_row bytes per cell directly from the key tensor via write_tensor()
+        for (const auto & range : cell_ranges) {
+            const size_t range_size = range.second - range.first;
+            const size_t buf_size = range_size * k_size_row;
+            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
         }
     }
 
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
-}
+    if (!v_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
-void llama_kv_cache_seq_add(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                    llama_pos   delta) {
-    uint32_t new_head = cache.size;
-
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) return;
+            // Write value type
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            io.write(&v_type_i, sizeof(v_type_i));
 
-    if (cache.recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be shifted
-        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            const int32_t tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cache.cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos += delta;
+            // Write row size of value
+            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            io.write(&v_size_row, sizeof(v_size_row));
+
+            // For each range of cells, write v_size_row bytes per cell directly from the value tensor via write_tensor()
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                const size_t buf_size = range_size * v_size_row;
+                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
+            }
+        }
+    } else {
+        // When v is transposed, we also need the element size and get the element ranges from each row
+        const uint32_t kv_size = size;
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Write value type
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write element size
+            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
+            io.write(&v_size_el, sizeof(v_size_el));
+
+            // Write GQA embedding size
+            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+            // For each row, we get the element values of each cell
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                // For each contiguous range of cells, write v_size_el bytes per cell from this row of the tensor
+                for (const auto & range : cell_ranges) {
+                    const size_t range_size = range.second - range.first;
+                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                    const size_t buf_size = range_size * v_size_el;
+                    io.write_tensor(v_l[il], src_offset, buf_size);
                 }
             }
         }
-        return;
     }
+}
+
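For reference, a minimal standalone sketch of the byte-offset arithmetic used by the transposed-V branch above (and by state_read_data below). The helper name v_offset_bytes is hypothetical; kv_size and v_size_el follow the names in this hunk, with the writer using range.first and the reader using head as the cell index:

    #include <cstddef>
    #include <cstdint>

    // byte offset of (cell c, value dimension j) in a transposed V tensor
    static size_t v_offset_bytes(uint32_t c, uint32_t j, uint32_t kv_size, size_t v_size_el) {
        return (static_cast<size_t>(c) + static_cast<size_t>(j) * kv_size) * v_size_el;
    }

    // e.g. kv_size = 4096, F16 values (v_size_el = 2):
    //   cell 10, dim 0 -> (10 +    0) * 2 =   20 bytes
    //   cell 10, dim 1 -> (10 + 4096) * 2 = 8212 bytes
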
+bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    if (dest_seq_id != -1) {
+        // single sequence
 
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.has_shift = true;
-            cache.cells[i].pos   += delta;
-            cache.cells[i].delta += delta;
+        seq_rm(dest_seq_id, -1, -1);
+
+        llama_sbatch sbatch;
+        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+
+        batch.n_tokens = cell_count;
+        batch.n_seq_tokens = cell_count;
+        batch.n_seqs = 1;
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id != 0) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                return false;
+            }
 
-            if (cache.cells[i].pos < 0) {
-                if (!cache.cells[i].is_empty()) {
-                    cache.used--;
+            batch.pos[i] = pos;
+        }
+        batch.n_seq_id[0] = 1;
+        batch.seq_id[0] = &dest_seq_id;
+        if (!find_slot(batch)) {
+            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+            return false;
+        }
+
+        // DEBUG CHECK: head should be our first cell, head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+        // Assume that this is one contiguous block of cells
+        GGML_ASSERT(head + cell_count <= size);
+        GGML_ASSERT(cells[head].pos == batch.pos[0]);
+        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
+        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+    } else {
+        // whole KV cache restore
+
+        if (cell_count > size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+            return false;
+        }
+
+        clear();
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_kv_cell & cell = cells[i];
+
+            llama_pos pos;
+            uint32_t  n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            cell.pos = pos;
+
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+
+                // TODO: llama_kv_cache_unified should have a notion of max sequences
+                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                if (seq_id < 0) {
+                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+                    return false;
                 }
-                cache.cells[i].pos = -1;
-                cache.cells[i].seq_id.clear();
-                if (new_head == cache.size) {
-                    new_head = i;
+
+                cell.seq_id.insert(seq_id);
+
+                if (recurrent) {
+                    int32_t & tail = cells[seq_id].tail;
+                    if (tail != -1) {
+                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
+                        return false;
+                    }
+                    tail = i;
                 }
             }
         }
+
+        head = 0;
+        used = cell_count;
     }
 
-    // If we freed up a slot, set head to it so searching can start there.
-    // Otherwise we just start the next search from the beginning.
-    cache.head = new_head != cache.size ? new_head : 0;
+    if (recurrent) {
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            uint32_t cell_id = head + i;
+            // make sure the recurrent states will keep their restored state
+            cells[cell_id].src = cell_id;
+        }
+    }
+
+    return true;
 }
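The single-sequence branch above expects each saved cell to carry no sequence ids (they are re-assigned to dest_seq_id via find_slot), while the whole-cache branch restores the original ids. A hedged sketch of the per-cell metadata record both branches consume, reconstructed from the reads above (write_cell_meta is a hypothetical mirror of the private state_write_meta):

    // record layout per cell: llama_pos pos; uint32_t n_seq_id; llama_seq_id seq_id[n_seq_id];
    static void write_cell_meta(llama_io_write_i & io, llama_pos pos, const std::set<llama_seq_id> & seq_ids) {
        const uint32_t n_seq_id = (uint32_t) seq_ids.size(); // must be 0 when saving a single sequence

        io.write(&pos,      sizeof(pos));
        io.write(&n_seq_id, sizeof(n_seq_id));

        for (const llama_seq_id seq_id : seq_ids) {
            io.write(&seq_id, sizeof(seq_id));
        }
    }
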
 
-void llama_kv_cache_seq_div(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                          int   d) {
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) return;
+bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+    uint32_t v_trans;
+    uint32_t n_layer;
+    io.read_to(&v_trans, sizeof(v_trans));
+    io.read_to(&n_layer, sizeof(n_layer));
 
-    if (cache.recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be changed
-        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            const int32_t tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cache.cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos /= d;
+    if (n_layer != hparams.n_layer) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+        return false;
+    }
+    if (cell_count > size) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+        return false;
+    }
+    if (this->v_trans != (bool) v_trans) {
+        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+        return false;
+    }
+
+    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
+        // Read type of key
+        int32_t k_type_i_ref;
+        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+        const int32_t k_type_i = (int32_t) k_l[il]->type;
+        if (k_type_i != k_type_i_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+            return false;
+        }
+
+        // Read row size of key
+        uint64_t k_size_row_ref;
+        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        if (k_size_row != k_size_row_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+            return false;
+        }
+
+        if (cell_count) {
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+        }
+    }
+
+    if (!v_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of value
+            uint64_t v_size_row_ref;
+            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            if (v_size_row != v_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+            }
+        }
+    } else {
+        // For each layer, read the values for each cell (transposed)
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read element size of value
+            uint32_t v_size_el_ref;
+            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+            const size_t v_size_el = ggml_type_size(v_l[il]->type);
+            if (v_size_el != v_size_el_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                return false;
+            }
+
+            // Read GQA embedding size
+            uint32_t n_embd_v_gqa_ref;
+            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (head + j * size) * v_size_el;
+                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
             }
         }
+    }
+
+    return true;
+}
+
+//
+// interface implementation
+//
+
+int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_n_tokens();
+}
+
+int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_used_cells();
+}
+
+void llama_kv_cache_clear(llama_kv_cache * kv) {
+    if (!kv) {
         return;
     }
 
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.has_shift = true;
+    kv->clear();
+}
 
-            {
-                llama_pos p_old = cache.cells[i].pos;
-                cache.cells[i].pos   /= d;
-                cache.cells[i].delta += cache.cells[i].pos - p_old;
-            }
-        }
+bool llama_kv_cache_seq_rm(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1) {
+    if (!kv) {
+        return true;
     }
+
+    return kv->seq_rm(seq_id, p0, p1);
 }
 
-llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
-    llama_pos result = 0;
+void llama_kv_cache_seq_cp(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id_src,
+          llama_seq_id   seq_id_dst,
+             llama_pos   p0,
+             llama_pos   p1) {
+    if (!kv) {
+        return;
+    }
 
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id)) {
-            result = std::max(result, cache.cells[i].pos);
-        }
+    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
+    if (!kv) {
+        return;
     }
 
-    return result;
+    kv->seq_keep(seq_id);
 }
 
-void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
-    if (!cache.recurrent) {
-        cache.do_defrag = true;
+void llama_kv_cache_seq_add(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1,
+             llama_pos   delta) {
+    if (!kv) {
+        return;
     }
+
+    kv->seq_add(seq_id, p0, p1, delta);
 }
 
-int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) {
-    int result = 0;
+void llama_kv_cache_seq_div(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1,
+                   int   d) {
+    if (!kv) {
+        return;
+    }
 
-    for (uint32_t i = 0; i < kv.size; i++) {
-        result += kv.cells[i].seq_id.size();
+    kv->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
+    if (!kv) {
+        return 0;
     }
 
-    return result;
+    return kv->seq_pos_max(seq_id);
 }
 
-int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) {
-    return kv.used;
+void llama_kv_cache_defrag(llama_kv_cache * kv) {
+    if (!kv) {
+        return;
+    }
+
+    kv->defrag();
 }
 
-bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) {
-    return kv.can_shift;
+bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
+    if (!kv) {
+        return false;
+    }
+
+    return kv->get_can_shift();
 }
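These wrappers simply forward to the virtual interface after a null check, so callers do not have to special-case contexts without a KV cache. A hypothetical caller-side sketch (shift_context is not part of this commit) showing the usual discard-and-shift pattern built on top of them:

    // drop the oldest n_discard positions of a sequence and shift the rest back to start at 0
    static void shift_context(llama_kv_cache * kv, llama_seq_id seq_id, llama_pos n_discard) {
        if (!llama_kv_cache_seq_rm(kv, seq_id, 0, n_discard)) {
            return; // e.g. recurrent caches may refuse partial removals
        }

        // p1 < 0 means "to the end"; positions in [n_discard, end) move back by n_discard
        llama_kv_cache_seq_add(kv, seq_id, n_discard, -1, -n_discard);
    }
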
 
 //
 // kv cache view
 //
 
-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) {
-    struct llama_kv_cache_view result = {
+llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) {
+    llama_kv_cache_view result = {
         /*.n_cells            = */ 0,
         /*.n_seq_max          = */ n_seq_max,
         /*.token_count        = */ 0,
-        /*.used_cells         = */ llama_get_kv_cache_used_cells(kv),
+        /*.used_cells         = */ llama_kv_cache_used_cells(&kv),
         /*.max_contiguous     = */ 0,
         /*.max_contiguous_idx = */ -1,
         /*.cells              = */ nullptr,
@@ -642,7 +1350,7 @@ struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache
     return result;
 }
 
-void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+void llama_kv_cache_view_free(llama_kv_cache_view * view) {
     if (view->cells != nullptr) {
         free(view->cells);
         view->cells = nullptr;
@@ -653,18 +1361,25 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
     }
 }
 
-void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) {
-    if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) {
-        view->n_cells = int32_t(kv.size);
-        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) {
+    // TODO: rework this in the future; quick hack for now
+    const llama_kv_cache_unified * kvu = dynamic_cast<const llama_kv_cache_unified *>(kv);
+    if (kvu == nullptr) {
+        LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__);
+        return;
+    }
+
+    if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) {
+        view->n_cells = int32_t(kvu->size);
+        void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells);
         GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
-        view->cells = (struct llama_kv_cache_view_cell *)p;
+        view->cells = (llama_kv_cache_view_cell *)p;
         p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
         GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
         view->cells_sequences = (llama_seq_id *)p;
     }
 
-    const std::vector<llama_kv_cell> & kv_cells = kv.cells;
+    const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
     llama_kv_cache_view_cell * c_curr = view->cells;
     llama_seq_id * cs_curr = view->cells_sequences;
     int32_t used_cells = 0;
@@ -673,7 +1388,7 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
     uint32_t max_contig = 0;
     int32_t max_contig_idx = -1;
 
-    for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) {
+    for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) {
         const size_t curr_size = kv_cells[i].seq_id.size();
         token_count += curr_size;
         c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@@ -711,8 +1426,8 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct
     view->max_contiguous_idx = max_contig_idx;
     view->token_count = token_count;
     view->used_cells = used_cells;
-    if (uint32_t(used_cells) != kv.used) {
+    if (uint32_t(used_cells) != kvu->used) {
         LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, kv.used, used_cells);
+            __func__, kvu->used, used_cells);
     }
 }
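A minimal usage sketch of the view helpers above (the n_seq_max value and the kv pointer are assumptions); note that after this change the update call only works with llama_kv_cache_unified and otherwise logs an error and returns:

    llama_kv_cache_view view = llama_kv_cache_view_init(*kv, /*n_seq_max=*/4);

    llama_kv_cache_view_update(&view, kv);

    for (int32_t i = 0; i < view.n_cells; ++i) {
        // sequence ids of cell i occupy view.cells_sequences[i*view.n_seq_max .. i*view.n_seq_max + view.n_seq_max)
        const llama_kv_cache_view_cell & cell = view.cells[i];
        (void) cell; // e.g. inspect cell.pos here
    }

    llama_kv_cache_view_free(&view);
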
index 1ce0850ec81bb60721f72552902b48816c917334..0a7ff8a4ea3e685589215932eec0534aa4959f19 100644 (file)
@@ -1,12 +1,29 @@
 #pragma once
 
 #include "llama.h"
+#include "llama-io.h"
+#include "llama-memory.h"
 
 #include "ggml-cpp.h"
 
+#include <functional>
 #include <set>
 #include <vector>
-#include <algorithm>
+
+struct llama_cparams;
+struct llama_hparams;
+struct llama_ubatch;
+
+struct llama_kv_cache : public llama_memory_i {
+    using llama_memory_i::llama_memory_i;
+
+    virtual int32_t  get_n_tokens()   const = 0;
+    virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too specific to the unified cache
+
+    virtual bool get_can_shift() const = 0;
+
+    bool get_can_edit() const override { return get_can_shift(); }
+};
 
 struct llama_kv_cell {
     llama_pos pos   = -1;
@@ -29,55 +46,6 @@ struct llama_kv_cell {
     }
 };
 
-// ring-buffer of cached KV data
-struct llama_kv_cache {
-    bool has_shift = false;
-    bool do_defrag = false;
-    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
-    bool v_trans   = true;  // the value tensor is transposed
-    bool can_shift = false;
-
-    // Note: The value of head isn't only used to optimize searching
-    // for a free KV slot. llama_decode_impl also uses it, so it
-    // cannot be freely changed after a slot has been allocated.
-    uint32_t head = 0;
-    uint32_t size = 0;
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
-
-    // computed before each graph build
-    uint32_t n = 0;
-
-    ggml_type type_k = GGML_TYPE_F16;
-    ggml_type type_v = GGML_TYPE_F16;
-
-    std::vector<llama_kv_cell> cells;
-
-    std::vector<struct ggml_tensor *> k_l; // per layer
-    std::vector<struct ggml_tensor *> v_l;
-
-    std::vector<ggml_context_ptr> ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    size_t total_size() const {
-        size_t size = 0;
-        for (const auto & buf : bufs) {
-            size += ggml_backend_buffer_get_size(buf.get());
-        }
-
-        return size;
-    }
-
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos max_pos() const {
-        llama_pos max_pos = -1;
-        for (const auto & cell : cells) {
-            max_pos = std::max(max_pos, cell.pos);
-        }
-
-        return max_pos;
-    }
-};
-
 // a structure that holds information about the slot found in llama_kv_cache_find_slot
 struct llama_kv_cache_slot_info {
     std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
@@ -89,82 +57,131 @@ struct llama_kv_cache_slot_info {
     operator bool() const { return found; }
 };
 
-// TODO: maybe not needed
-uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams);
-
-bool llama_kv_cache_init(
-        struct llama_kv_cache & cache,
-            const llama_model & model,
+// ring-buffer of cached KV data
+// TODO: pimpl
+// TODO: add notion of max sequences
+class llama_kv_cache_unified : public llama_kv_cache {
+public:
+    // can be used to query data from the model if needed
+    struct callbacks {
+        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
+    };
+
+    llama_kv_cache_unified(
+            const llama_hparams & hparams,
+            callbacks             cbs);
+
+    virtual ~llama_kv_cache_unified() = default;
+
+    // TODO: become constructor
+    bool init(
+            const llama_model & model,   // TODO: do not reference the model
           const llama_cparams & cparams,
                     ggml_type   type_k,
                     ggml_type   type_v,
                      uint32_t   kv_size,
                          bool   offload);
 
-// find an empty slot of size "n_tokens" in the cache
-// updates the cache head
-// returns a structure holding information about the slot found
-// Note: On success, it's important that cache.head points
-// to the first cell of the slot.
-struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
-           struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch);
+    int32_t  get_n_tokens()   const override;
+    uint32_t get_used_cells() const override;
 
-// find how many cells are currently in use
-uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache);
+    size_t total_size() const;
 
-void llama_kv_cache_clear(struct llama_kv_cache & cache);
+    // TODO: better data structures to reduce the cost of this operation
+    llama_pos pos_max() const;
 
-bool llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1);
+    void clear() override;
+    void defrag() override;
 
-void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id_src,
-                 llama_seq_id   seq_id_dst,
-                    llama_pos   p0,
-                    llama_pos   p1);
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
-void llama_kv_cache_seq_keep(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id);
+    llama_pos seq_pos_max(llama_seq_id seq_id) override;
 
-void llama_kv_cache_seq_add(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                    llama_pos   delta);
+    bool get_can_shift() const override;
 
-void llama_kv_cache_seq_div(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                          int   d);
+    // find an empty slot of size "n_tokens" in the cache
+    // updates the cache head
+    // returns a structure holding information about the slot found
+    // Note: On success, it's important that cache.head points
+    // to the first cell of the slot.
+    llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
 
-llama_pos llama_kv_cache_seq_pos_max(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id);
+    // TODO: maybe not needed
+    uint32_t get_padding(const llama_cparams & cparams) const;
 
-void llama_kv_cache_defrag(struct llama_kv_cache & cache);
+    // find how many cells are currently in use
+    uint32_t cell_max() const;
 
-int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv);
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
 
-int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv);
+    // defrag
 
-bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv);
+    struct {
+        std::vector<uint32_t> ids;
+    } defrag_info;
 
-//
-// kv cache view
-//
+    // return true if cells have been moved
+    bool defrag_prepare(int32_t n_max_nodes);
+
+    // state save/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
 
-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max);
+    // members
 
-void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv);
+    const llama_hparams & hparams;
+
+    callbacks cbs;
+
+    bool has_shift = false;
+    bool do_defrag = false;
+
+    // TODO: remove this and implement llama_kv_cache_recurrent instead
+    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+
+    bool v_trans   = true;  // the value tensor is transposed
+    bool can_shift = false;
+
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_impl also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
+    uint32_t head = 0;
+    uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<llama_kv_cell> cells;
+
+    std::vector<ggml_tensor *> k_l; // per layer
+    std::vector<ggml_tensor *> v_l;
+
+private:
+    ggml_type type_k = GGML_TYPE_F16;
+    ggml_type type_v = GGML_TYPE_F16;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+};
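A hypothetical construction sketch for the class declared above; the lambda, kv_size and cache types are placeholder assumptions, and model / cparams are expected to come from the usual loading path:

    llama_kv_cache_unified::callbacks cbs {
        /*.get_rope_factors =*/ [](uint32_t /*n_ctx_per_seq*/, int /*il*/) -> ggml_tensor * {
            return nullptr; // no rope scaling factors in this sketch
        },
    };

    llama_kv_cache_unified kv(model.hparams, std::move(cbs));

    if (!kv.init(model, cparams, GGML_TYPE_F16, GGML_TYPE_F16, /*kv_size=*/4096, /*offload=*/true)) {
        // handle KV cache allocation failure
    }
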
+
+// TODO: temporarily reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
+//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
+//public:
+//    using llama_kv_cache_unified::llama_kv_cache_unified;
+//};
 
 //
 // kv cache restore
@@ -184,13 +201,15 @@ struct llama_kv_slot_restorer {
 
     bool do_restore = false;
 
-    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
+    llama_kv_cache_unified & cache;
+
+    explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
         old_state.head = cache.head;
         old_state.n    = cache.n;
     }
 
     // saves a slot information for future restoration
-    void save(const struct llama_kv_cache_slot_info & slot) {
+    void save(const llama_kv_cache_slot_info & slot) {
         if (slot) {
             do_restore = true;
             if (slot.boundaries.first != slot.boundaries.second) {
@@ -201,19 +220,68 @@ struct llama_kv_slot_restorer {
 
     // must be explicitly called to restore the kv_cache state
     // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore(struct llama_kv_cache & cache) {
+    void restore() {
         if (do_restore) {
             cache.head = old_state.head;
             cache.n    = old_state.n;
 
             if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
-                llama_kv_cache_seq_rm(cache, -1, -1, -1);
+                cache.seq_rm(-1, -1, -1);
             } else {
                 for (auto & slot : slot_boundaries) {
-                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
+                    cache.seq_rm(-1, slot.first, slot.second);
                 }
             }
         }
     }
 };
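A hypothetical sketch of the save/restore pattern this helper supports; decode_ubatches and the surrounding loop are assumptions, not part of this commit:

    static bool decode_ubatches(llama_kv_cache_unified & kv, const std::vector<llama_ubatch> & ubatches) {
        llama_kv_slot_restorer restorer(kv);

        for (const llama_ubatch & ubatch : ubatches) {
            const auto slot = kv.find_slot(ubatch);
            if (!slot) {
                restorer.restore(); // roll back the slots saved so far and reset head/n
                return false;
            }

            restorer.save(slot);

            // ... build and compute the graph for this ubatch ...
        }

        // on success the restorer is simply discarded and the new cells stay in the cache
        return true;
    }
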
 
+// TODO: maybe become part of the public llama_kv_cache in the future
+int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
+
+int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
+
+void llama_kv_cache_clear(llama_kv_cache * kv);
+
+bool llama_kv_cache_seq_rm(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1);
+
+void llama_kv_cache_seq_cp(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id_src,
+          llama_seq_id   seq_id_dst,
+             llama_pos   p0,
+             llama_pos   p1);
+
+void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
+
+void llama_kv_cache_seq_add(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1,
+             llama_pos   delta);
+
+void llama_kv_cache_seq_div(
+        llama_kv_cache * kv,
+          llama_seq_id   seq_id,
+             llama_pos   p0,
+             llama_pos   p1,
+                   int   d);
+
+llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
+
+void llama_kv_cache_defrag(llama_kv_cache * kv);
+
+bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
+
+//
+// kv cache view
+//
+
+llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
+
+void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
diff --git a/src/llama-memory.cpp b/src/llama-memory.cpp
new file mode 100644 (file)
index 0000000..1017325
--- /dev/null
@@ -0,0 +1 @@
+#include "llama-memory.h"
diff --git a/src/llama-memory.h b/src/llama-memory.h
new file mode 100644 (file)
index 0000000..69e6e34
--- /dev/null
@@ -0,0 +1,21 @@
+#pragma once
+
+#include "llama.h"
+
+// general concept of LLM memory
+// the KV cache is a type of LLM memory, but there can be other types
+class llama_memory_i {
+public:
+    virtual void clear() = 0;
+    virtual void defrag() = 0;
+
+    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_keep(llama_seq_id seq_id) = 0;
+    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0;
+    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
+
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
+
+    virtual bool get_can_edit() const = 0;
+};
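For illustration only, a minimal hypothetical implementation of this interface: a no-op memory type that satisfies the contract (the name llama_memory_null is not part of this commit):

    class llama_memory_null : public llama_memory_i {
    public:
        void clear()  override {}
        void defrag() override {}

        bool seq_rm  (llama_seq_id, llama_pos, llama_pos)                override { return true; }
        void seq_cp  (llama_seq_id, llama_seq_id, llama_pos, llama_pos) override {}
        void seq_keep(llama_seq_id)                                     override {}
        void seq_add (llama_seq_id, llama_pos, llama_pos, llama_pos)    override {}
        void seq_div (llama_seq_id, llama_pos, llama_pos, int)          override {}

        llama_pos seq_pos_max(llama_seq_id) override { return -1; } // no cells -> no max position

        bool get_can_edit() const override { return false; }
    };
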
index 9f75589d805a9f3bb2793d9e1c2fdd4ed4680eea..522219c0122428efeb593301e2beabbc6218aa07 100644 (file)
@@ -2,12 +2,17 @@
 
 #include "llama-impl.h"
 #include "llama-mmap.h"
+#include "llama-batch.h"
+#include "llama-cparams.h"
 #include "llama-model-loader.h"
+#include "llama-kv-cache.h"
 
 #include "ggml-cpp.h"
 
 #include <algorithm>
 #include <cassert>
+#include <cmath>
+#include <cfloat>
 #include <cstring>
 #include <cmath>
 #include <functional>
@@ -244,6 +249,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara
             return cur_buft;
         }
     }
+
     return nullptr;
 }
 
@@ -302,7 +308,7 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
 }
 
 // GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU
-static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) {
+static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode split_mode, const float * tensor_split) {
     buft_list_t buft_list;
 
     // add the device split buffer type if requested and available
@@ -369,7 +375,7 @@ struct llama_model::impl {
     std::vector<layer_dev> dev_layer;
 };
 
-llama_model::llama_model(const struct llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
+llama_model::llama_model(const llama_model_params & params) : params(params), pimpl(std::make_unique<impl>()) {
 }
 
 llama_model::~llama_model() {}
@@ -391,7 +397,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     // get metadata as string
     for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
-        enum gguf_type type = gguf_get_kv_type(ctx, i);
+        gguf_type type = gguf_get_kv_type(ctx, i);
         if (type == GGUF_TYPE_ARRAY) {
             continue;
         }
@@ -1444,7 +1450,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
             // skip unused tensors
             if (info.op == GGML_OP_NONE) {
-                LLAMA_LOG_WARN("model has unused tensor %s -- ignoring\n", tn.str().c_str());
+                const size_t nbytes = ggml_nbytes(t_meta);
+                LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes);
+
+                ml.size_data -= nbytes;
                 ml.n_created++;
 
                 return nullptr;
@@ -3631,8 +3640,8 @@ size_t llama_model::size() const {
     return pimpl->n_bytes;
 }
 
-size_t llama_model::max_nodes() const {
-    return std::max<size_t>(8192, tensors_by_name.size()*5);
+size_t llama_model::n_tensors() const {
+    return tensors_by_name.size();
 }
 
 size_t llama_model::n_devices() const {
@@ -3745,7 +3754,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",     __func__, hparams.expert_weights_norm);
-        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((enum llama_expert_gating_func_type) hparams.expert_gating_func));
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
 
@@ -3821,9 +3830,9 @@ ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
             });
 }
 
-const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
+const ggml_tensor * llama_model::get_tensor(const char * name) const {
     auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
-            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
+            [name](const std::pair<std::string, ggml_tensor *> & it) {
                 return it.first == name;
             });
     if (it == tensors_by_name.end()) {
@@ -3833,255 +3842,7556 @@ const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }
 
-//
-// interface implementation
-//
+struct llm_build_llama : public llm_graph_context {
+    llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
 
-struct llama_model_params llama_model_default_params() {
-    struct llama_model_params result = {
-        /*.devices                     =*/ nullptr,
-        /*.n_gpu_layers                =*/ 0,
-        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
-        /*.main_gpu                    =*/ 0,
-        /*.tensor_split                =*/ nullptr,
-        /*.progress_callback           =*/ nullptr,
-        /*.progress_callback_user_data =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-        /*.vocab_only                  =*/ false,
-        /*.use_mmap                    =*/ true,
-        /*.use_mlock                   =*/ false,
-        /*.check_tensors               =*/ false,
-    };
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-#ifdef GGML_USE_METAL
-    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
-#endif
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
 
-    return result;
-}
+        inpL = build_inp_embd(model.tok_embd);
 
-const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model) {
-    return &model->vocab;
-}
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
 
-void llama_free_model(struct llama_model * model) {
-    llama_model_free(model);
-}
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
 
-void llama_model_free(struct llama_model * model) {
-    delete model;
-}
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
 
-int32_t llama_model_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
 
-int32_t llama_model_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
 
-int32_t llama_model_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
 
-int32_t llama_model_n_head(const struct llama_model * model) {
-    return model->hparams.n_head();
-}
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
 
-int32_t llama_model_n_head_kv(const struct llama_model * model) {
-    return model->hparams.n_head_kv();
-}
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+            }
 
-// deprecated
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return llama_model_n_ctx_train(model);
-}
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
 
-// deprecated
-int32_t llama_n_embd(const struct llama_model * model) {
-    return llama_model_n_embd(model);
-}
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
 
-// deprecated
-int32_t llama_n_layer(const struct llama_model * model) {
-    return llama_model_n_layer(model);
-}
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
 
-// deprecated
-int32_t llama_n_head(const struct llama_model * model) {
-    return llama_model_n_head(model);
-}
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
 
-enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
-    switch (model->arch) {
-        // these models do not use RoPE
-        case LLM_ARCH_GPT2:
-        case LLM_ARCH_GPTJ:
-        case LLM_ARCH_MPT:
-        case LLM_ARCH_REFACT:
-        case LLM_ARCH_BLOOM:
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_JINA_BERT_V2:
-        case LLM_ARCH_T5:
-        case LLM_ARCH_T5ENCODER:
-        case LLM_ARCH_JAIS:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_WAVTOKENIZER_DEC:
-            return LLAMA_ROPE_TYPE_NONE;
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
 
-        // use what we call a normal RoPE, operating on pairs of consecutive head values
-        case LLM_ARCH_LLAMA:
-        case LLM_ARCH_DECI:
-        case LLM_ARCH_BAICHUAN:
-        case LLM_ARCH_STARCODER:
-        case LLM_ARCH_PLAMO:
-        case LLM_ARCH_ORION:
-        case LLM_ARCH_INTERNLM2:
-        case LLM_ARCH_MINICPM:
-        case LLM_ARCH_XVERSE:
-        case LLM_ARCH_COMMAND_R:
-        case LLM_ARCH_COHERE2:
-        case LLM_ARCH_OLMO:
-        case LLM_ARCH_ARCTIC:
-        case LLM_ARCH_DEEPSEEK:
-        case LLM_ARCH_DEEPSEEK2:
-        case LLM_ARCH_CHATGLM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
-        case LLM_ARCH_CHAMELEON:
-            return LLAMA_ROPE_TYPE_NORM;
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(cur, "ffn_moe_out", il);
+            }
 
-        // the pairs of head values are offset by n_rot/2
-        case LLM_ARCH_FALCON:
-        case LLM_ARCH_GROK:
-        case LLM_ARCH_DBRX:
-        case LLM_ARCH_BERT:
-        case LLM_ARCH_NOMIC_BERT:
-        case LLM_ARCH_STABLELM:
-        case LLM_ARCH_BITNET:
-        case LLM_ARCH_QWEN:
-        case LLM_ARCH_QWEN2:
-        case LLM_ARCH_QWEN2MOE:
-        case LLM_ARCH_OLMO2:
-        case LLM_ARCH_OLMOE:
-        case LLM_ARCH_PHI2:
-        case LLM_ARCH_PHI3:
-        case LLM_ARCH_PHIMOE:
-        case LLM_ARCH_GEMMA:
-        case LLM_ARCH_GEMMA2:
-        case LLM_ARCH_GEMMA3:
-        case LLM_ARCH_STARCODER2:
-        case LLM_ARCH_OPENELM:
-        case LLM_ARCH_GPTNEOX:
-        case LLM_ARCH_CODESHELL:
-        case LLM_ARCH_NEMOTRON:
-        case LLM_ARCH_EXAONE:
-        case LLM_ARCH_MINICPM3:
-            return LLAMA_ROPE_TYPE_NEOX;
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
 
-        case LLM_ARCH_QWEN2VL:
-            return LLAMA_ROPE_TYPE_MROPE;
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
 
-        // all model arches should be listed explicitly here
-        case LLM_ARCH_UNKNOWN:
-            GGML_ABORT("unknown architecture");
-    }
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
 
-    return LLAMA_ROPE_TYPE_NONE;
-}
+            // input for next layer
+            inpL = cur;
+        }
 
-float llama_model_rope_freq_scale_train(const struct llama_model * model) {
-    return model->hparams.rope_freq_scale_train;
-}
+        cur = inpL;
 
-int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
-    const auto & it = model->gguf_kv.find(key);
-    if (it == model->gguf_kv.end()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
         }
-        return -1;
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
     }
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
+};
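Builders like llm_build_llama construct the whole compute graph in their constructor and record the output tensors in res->t_embd and res->t_logits. A hypothetical sketch of how such a builder might be selected per architecture (make_graph_builder is an assumption; the actual dispatch lives elsewhere in llama_model):

    #include <memory>

    static std::unique_ptr<llm_graph_context> make_graph_builder(
            const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) {
        switch (model.arch) {
            case LLM_ARCH_LLAMA: return std::make_unique<llm_build_llama>(model, params, gf);
            // ... one case per llm_build_* struct ...
            default:             GGML_ABORT("unknown architecture");
        }
    }
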
 
-int32_t llama_model_meta_count(const struct llama_model * model) {
-    return (int)model->gguf_kv.size();
-}
+struct llm_build_deci : public llm_graph_context {
+    llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
 
-int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head    = hparams.n_head(il);
+
+            if (n_head == 0) {
+                // attention-free layer of Llama-3_1-Nemotron-51B
+                cur = inpL;
+            } else {
+                // norm
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            if (n_head > 0 && n_head_kv == 0) {
+                // "linear attention" of Llama-3_1-Nemotron-51B
+                cur = build_lora_mm(model.layers[il].wo, cur);
+                cb(cur, "wo", il);
+            } else if (n_head > 0) {
+                // self-attention
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
+            // modified to support attention-free layer of Llama-3_1-Nemotron-51B
+            ggml_tensor * ffn_inp = cur;
+            if (n_head > 0) {
+                ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+            }
+
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->first.c_str());
-}
 
-int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
         }
-        return -1;
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
     }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
+};
 
-int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s", model->desc().c_str());
-}
+struct llm_build_baichuan : public llm_graph_context {
+    llm_build_baichuan(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
 
-uint64_t llama_model_size(const struct llama_model * model) {
-    return model->size();
-}
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
-    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
-        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
-    const auto & it = model->gguf_kv.find(key);
-    if (it == model->gguf_kv.end()) {
-        return nullptr;
-    }
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
 
-    return it->second.c_str();
-}
+        inpL = build_inp_embd(model.tok_embd);
 
-uint64_t llama_model_n_params(const struct llama_model * model) {
-    return model->n_elements();
-}
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
-bool llama_model_has_encoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5:        return true;
-        case LLM_ARCH_T5ENCODER: return true;
-        default:                 return false;
-    }
-}
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
 
-bool llama_model_has_decoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5ENCODER: return false;
-        default:                 return true;
-    }
-}
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
 
-llama_token llama_model_decoder_start_token(const struct llama_model * model) {
-    return model->hparams.dec_start_token_id;
-}
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
 
-bool llama_model_is_recurrent(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA:  return true;
-        case LLM_ARCH_RWKV6:  return true;
-        case LLM_ARCH_RWKV6QWEN2: return true;
-        default:              return false;
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                switch (model.type) {
+                    case LLM_TYPE_7B:
+                        Qcur = ggml_rope_ext(
+                                ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                ext_factor, attn_factor, beta_fast, beta_slow
+                                );
+                        Kcur = ggml_rope_ext(
+                                ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                ext_factor, attn_factor, beta_fast, beta_slow
+                                );
+                        break;
+                    case LLM_TYPE_13B:
+                        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
+                        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
     }
+};
+
+struct llm_build_xverse : public llm_graph_context {
+    llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_falcon : public llm_graph_context {
+    llm_build_falcon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * attn_norm;
+
+            attn_norm = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(attn_norm, "attn_norm", il);
+
+            // self-attention
+            {
+                if (model.layers[il].attn_norm_2) {
+                    // Falcon-40B
+                    cur = build_norm(inpL,
+                            model.layers[il].attn_norm_2,
+                            model.layers[il].attn_norm_2_b,
+                            LLM_NORM, il);
+                    cb(cur, "attn_norm_2", il);
+                } else {
+                    cur = attn_norm;
+                }
+
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
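+                // split the fused QKV projection into Q, K and V views (byte offsets assume f32 activations)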
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                // NeoX-style RoPE (mode 2, selected via rope_type)
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur       = ggml_get_rows(ctx0,       cur, inp_out_ids);
+                inpL      = ggml_get_rows(ctx0,      inpL, inp_out_ids);
+                attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = cur;
+
+            // feed forward
+            {
+                cur = build_ffn(attn_norm, // !! use the attn norm, not the result
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
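+            // parallel residual: add both the attention output (ffn_inp) and the layer input (inpL)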
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // norm
+        cur = build_norm(cur,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_grok : public llm_graph_context {
+    llm_build_grok(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // multiply by embedding_multiplier_scale of 78.38367176906169
+        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
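+                // note: the KQ scale here is 1.0f rather than 1/sqrt(n_embd_head)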
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // Grok
+            // if attn_out_norm is present then apply it before adding the input
+            if (model.layers[il].attn_out_norm) {
+                cur = build_norm(cur,
+                        model.layers[il].attn_out_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_out_norm", il);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_GELU, true,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            // Grok
+            // if layer_out_norm is present then apply it before adding the input
+            // Idea: maybe ffn_out_norm is a better name
+            if (model.layers[il].layer_out_norm) {
+                cur = build_norm(cur,
+                        model.layers[il].layer_out_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "layer_out_norm", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // Grok
+        // multiply logits by output_multiplier_scale of 0.5773502691896257
+
+        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_dbrx : public llm_graph_context {
+    llm_build_dbrx(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
+
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
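+                // clamp the fused QKV activations to [-f_clamp_kqv, f_clamp_kqv]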
+                cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                cb(cur, "wqkv_clamped", il);
+
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].attn_out_norm, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_out_norm", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, true,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_starcoder : public llm_graph_context {
+    llm_build_starcoder(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
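+        // learned absolute position embeddings are added to the token embeddings (no RoPE in the layers below)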
+        ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+        cb(pos, "pos_embd", -1);
+
+        inpL = ggml_add(ctx0, inpL, pos);
+        cb(inpL, "inpL", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_refact : public llm_graph_context {
+    llm_build_refact(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                cb(Kcur, "Kcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                cb(Qcur, "Qcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_bert : public llm_graph_context {
+    llm_build_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inp_pos = nullptr;
+
+        if (model.arch != LLM_ARCH_JINA_BERT_V2) {
+            inp_pos = build_inp_pos();
+        }
+
+        // construct input embeddings (token, type, position)
+        inpL = build_inp_embd(model.tok_embd);
+
+        // token types are hardcoded to zero ("Sentence A")
+        ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
+        inpL = ggml_add(ctx0, inpL, type_row0);
+        if (model.arch == LLM_ARCH_BERT) {
+            inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
+        }
+        cb(inpL, "inp_embd", -1);
+
+        // embed layer norm
+        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+        cb(inpL, "inp_norm", -1);
+
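+        // encoder-only graph: build the attention input without a KV cache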
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        // iterate layers
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * cur = inpL;
+
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            // self-attention
+            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
+                Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            model.layers[il].attn_q_norm_b,
+                            LLM_NORM, il);
+                }
+
+                Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
+
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            model.layers[il].attn_k_norm_b,
+                            LLM_NORM, il);
+                }
+
+                Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            } else {
+                // compute Q and K and RoPE them
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            }
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
+
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // re-add the layer input
+            cur = ggml_add(ctx0, cur, inpL);
+
+            // attention layer norm
+            cur = build_norm(cur, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, il);
+
+            if (model.layers[il].attn_norm_2 != nullptr) {
+                cur = ggml_add(ctx0, cur, inpL); // re-add the layer input
+                cur = build_norm(cur, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, il);
+            }
+
+            ggml_tensor * ffn_inp = cur;
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            if (model.arch == LLM_ARCH_BERT) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+            } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL,                        NULL,
+                        model.layers[il].ffn_gate, NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+            } else {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+            }
+            cb(cur, "ffn_out", il);
+
+            // residual connection: the attention output bypasses the intermediate (FFN) layer
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            // output layer norm
+            cur = build_norm(cur, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cb(cur, "result_embd", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_bloom : public llm_graph_context {
+    llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
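+        // normalize the input embeddings with a LayerNorm before the first block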
+        inpL = build_norm(inpL,
+                model.tok_norm,
+                model.tok_norm_b,
+                LLM_NORM, -1);
+        cb(inpL, "inp_norm", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // Add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_mpt : public llm_graph_context {
+    llm_build_mpt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * pos;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        if (model.pos_embd) {
+            // inp_pos - contains the positions
+            ggml_tensor * inp_pos = build_inp_pos();
+            pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+            cb(pos, "pos_embd", -1);
+
+            inpL = ggml_add(ctx0, inpL, pos);
+            cb(inpL, "inpL", -1);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * attn_norm;
+
+            attn_norm = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(attn_norm, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = attn_norm;
+
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                if (model.layers[il].bqkv) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+                }
+
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(cur, "wqkv_clamped", il);
+                }
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // Q/K Layernorm
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            model.layers[il].attn_q_norm_b,
+                            LLM_NORM, il);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            model.layers[il].attn_k_norm_b,
+                            LLM_NORM, il);
+                    cb(Kcur, "Kcur", il);
+
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                    cur = build_attn(inp_attn, gf,
+                            model.layers[il].wo, model.layers[il].bo,
+                            Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                } else {
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                    cur = build_attn(inp_attn, gf,
+                            model.layers[il].wo, model.layers[il].bo,
+                            Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                }
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // Add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed forward
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        model.layers[il].ffn_act,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_stablelm : public llm_graph_context {
+    llm_build_stablelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            ggml_tensor * inpSA = cur;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                cb(Qcur, "Qcur", il);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                cb(Kcur, "Kcur", il);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            NULL,
+                            LLM_NORM, il);
+                    cb(Qcur, "Qcur", il);
+                }
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            NULL,
+                            LLM_NORM, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpL  = ggml_get_rows(ctx0,  inpL, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                if (model.layers[il].ffn_norm) {
+                    cur = build_norm(ffn_inp,
+                            model.layers[il].ffn_norm,
+                            model.layers[il].ffn_norm_b,
+                            LLM_NORM, il);
+                    cb(cur, "ffn_norm", il);
+                } else {
+                    // parallel residual
+                    cur = inpSA;
+                }
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen : public llm_graph_context {
+    llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
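+                // Q, K and V views all have width n_embd here, so the byte offsets are multiples of n_embd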
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                // NeoX-style RoPE (mode 2, selected via rope_type)
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen2 : public llm_graph_context {
+    llm_build_qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen2vl : public llm_graph_context {
+    llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
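+        // Qwen2-VL uses multi-section RoPE (ggml_rope_multi): copy the per-section sizes used for the rotations below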
+        int sections[4];
+        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_multi(
+                        ctx0,
+                        ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_multi(
+                        ctx0,
+                        ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_qwen2moe : public llm_graph_context {
+    llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * moe_out =
+                build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, false,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // FFN shared expert
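+            // (runs on every token; its output is gated by a sigmoid and added to the routed experts' output)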
+            {
+                ggml_tensor * cur_gate_inp = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
+                cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
+
+                // sigmoid gate: silu(x) / x == sigmoid(x)
+                ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
+                cb(cur_gate, "ffn_shexp_gate", il);
+
+                ggml_tensor * cur_ffn = build_ffn(cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur_ffn, "ffn_shexp", il);
+
+                ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate);
+                cb(ffn_shexp_out, "ffn_shexp_out", il);
+
+                moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out);
+                cb(moe_out, "ffn_out", il);
+
+                cur = moe_out;
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_phi2 : public llm_graph_context {
+    llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * attn_norm_output;
+        ggml_tensor * ffn_output;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            attn_norm_output = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(attn_norm_output, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
+
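+                // use the fused QKV projection when present, otherwise fall back to separate, biased Q/K/V projections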
+                if (model.layers[il].wqkv) {
+                    cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output);
+                    cb(cur, "wqkv", il);
+
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                } else {
+                    Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                // with phi2, we scale the Q to avoid precision issues
+                // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
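+                // attention scale is 1.0f because Qcur was already scaled by 1/sqrt(n_embd_head) above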
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur              = ggml_get_rows(ctx0,              cur, inp_out_ids);
+                inpL             = ggml_get_rows(ctx0,             inpL, inp_out_ids);
+                attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
+            }
+
+            // FF
+            {
+                ffn_output = build_ffn(attn_norm_output,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(ffn_output, "ffn_out", il);
+            }
+
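+            // parallel residual: sum the attention output, the FFN output and the layer input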
+            cur = ggml_add(ctx0, cur, ffn_output);
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output_no_bias", -1);
+
+        cur = ggml_add(ctx0, cur, model.output_b);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_phi3 : public llm_graph_context {
+    llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+
+        for (int il = 0; il < n_layer; ++il) {
+            auto * residual = inpL;
+
+            // self-attention
+            {
+                // rope freq factors for 128k context
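+                // (looked up via the KV cache callbacks so the factors appropriate for n_ctx_per_seq can be selected)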
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                ggml_tensor * attn_norm_output = build_norm(inpL,
+                        model.layers[il].attn_norm,
+                        model.layers[il].attn_norm_b,
+                        LLM_NORM_RMS, il);
+                cb(attn_norm_output, "attn_norm", il);
+
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
+
+                if (model.layers[il].wqkv) {
+                    cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output);
+                    cb(cur, "wqkv", il);
+
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
+                } else {
+                    Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur      = ggml_get_rows(ctx0, cur,      inp_out_ids);
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+            }
+
+            cur = ggml_add(ctx0, cur, residual);
+            residual = cur;
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(cur, "ffn_moe_out", il);
+            }
+
+            cur = ggml_add(ctx0, residual, cur);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        if (model.output_b != nullptr) {
+            cb(cur, "result_output_no_bias", -1);
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_plamo : public llm_graph_context {
+    llm_build_plamo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            ggml_tensor * attention_norm = cur;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos, nullptr,
+                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+            ggml_tensor * sa_out = cur;
+
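+            // parallel block: the FFN below reuses the same normalized input as attention, and both outputs are added to the residual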
+            cur = attention_norm;
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur    = ggml_get_rows(ctx0,    cur, inp_out_ids);
+                sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
+                inpL   = ggml_get_rows(ctx0,   inpL, inp_out_ids);
+            }
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, sa_out);
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gpt2 : public llm_graph_context {
+    llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * pos;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
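+        // GPT-2 uses learned absolute position embeddings: look them up per position and add them to the token embeddings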
+        pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+        cb(pos, "pos_embd", -1);
+
+        inpL = ggml_add(ctx0, inpL, pos);
+        cb(inpL, "inpL", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_codeshell : public llm_graph_context {
+    llm_build_codeshell(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(tmpq, "tmpq", il);
+                cb(tmpk, "tmpk", il);
+                cb(Vcur, "Vcur", il);
+
+                ggml_tensor * Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_orion : public llm_graph_context {
+    llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                // if (model.layers[il].bq) {
+                //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                //     cb(Qcur, "Qcur", il);
+                // }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                // if (model.layers[il].bk) {
+                //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                //     cb(Kcur, "Kcur", il);
+                // }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                // if (model.layers[il].bv) {
+                //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                //     cb(Vcur, "Vcur", il);
+                // }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_internlm2 : public llm_graph_context {
+    llm_build_internlm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_minicpm3 : public llm_graph_context {
+    llm_build_minicpm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        // TODO: if these parameters vary between models, they need to be read from the model
+        const int64_t n_embd_base = 256;
+        const float scale_embd  = 12.0f;
+        const float scale_depth = 1.4f;
+        const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // scale the input embeddings
+        inpL = ggml_scale(ctx0, inpL, scale_embd);
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
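+            // (low-rank query projection via wq_a/wq_b and compressed latent KV via wkv_a_mqa/wkv_b, split into RoPE and non-RoPE parts below)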
+            {
+                ggml_tensor * q = NULL;
+                // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
+                q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+                cb(q, "q", il);
+
+                q = build_norm(q,
+                        model.layers[il].attn_q_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(q, "q", il);
+
+                // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
+                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                cb(q, "q", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * kv_pe_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compressed, "kv_pe_compressed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compressed, kv_lora_rank, n_tokens,
+                        kv_pe_compressed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compressed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compressed->nb[1],
+                        kv_pe_compressed->nb[1],
+                        ggml_row_size(kv_pe_compressed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
+                kv_compressed = build_norm(kv_compressed,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                        0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                q_pe = ggml_rope_ext(
+                        ctx0, q_pe, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(q_pe, "q_pe", il);
+
+                // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                k_pe = ggml_rope_ext(
+                        ctx0, k_pe, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(k_pe, "k_pe", il);
+
+                ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        q_states, k_states, v_states, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // scale_res - scale the hidden states for residual connection
+            const float scale_res = scale_depth/sqrtf(float(n_layer));
+            cur = ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled", il);
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // scale the hidden states for residual connection
+            cur = ggml_scale(ctx0, cur, scale_res);
+            cb(cur, "hidden_scaled_ffn", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head scaling
+        const float scale_lmhead = float(n_embd_base)/float(n_embd);
+        cur = ggml_scale(ctx0, cur, scale_lmhead);
+        cb(cur, "lmhead_scaling", -1);
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gemma : public llm_graph_context {
+    llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
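+        // Gemma scales the token embeddings by sqrt(n_embd) before the first layer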
+        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
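+                // the 1/sqrt(n_embd_head_k) attention scale is folded into Q here,
+                // so build_attn below is called with kq_scale = 1.0f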
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                cb(Qcur, "Qcur_scaled", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gemma2 : public llm_graph_context {
+    llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
+                switch (model.type) {
+                    case LLM_TYPE_2B:
+                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    default: GGML_ABORT("fatal error");
+                };
+                cb(Qcur, "Qcur_scaled", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
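+        // i.e. logits = softcap * tanh(logits / softcap), which bounds the logits to (-softcap, softcap)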
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gemma3 : public llm_graph_context {
+    llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // important: do not apply the sqrt(n_embd) scaling to raw embeddings input (i.e. encoded image embeddings)
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // TODO: is causal == true correct? might need some changes
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+
+        // "5-to-1 interleaved attention"
+        // 5 layers of local attention followed by 1 layer of global attention
+        static const int sliding_window_pattern = 6;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
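+            // with a pattern of 6, layers 0..4 in each group use sliding-window (local) attention and layer 5 uses global attention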
+
+            const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
+            const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens);
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
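+                // note: the attention scale comes from hparams.f_attention_scale instead of 1/sqrt(n_embd_head_k)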
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = build_norm(sa_out,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// TODO: move up next to build_starcoder
+struct llm_build_starcoder2 : public llm_graph_context {
+    llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    NULL,                      NULL,                        NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_GELU, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_mamba : public llm_graph_context {
+    const llama_model & model;
+
+    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_mask = build_inp_s_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
+            cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        // final rmsnorm
+        cur = build_norm(inpL,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    // TODO: split
+    ggml_tensor * build_mamba_layer(
+             ggml_cgraph * gf,
+             ggml_tensor * cur,
+             ggml_tensor * state_copy,
+             ggml_tensor * state_mask,
+      const llama_ubatch & ubatch,
+                     int   il) const {
+        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+        const auto kv_head = kv_self->head;
+
+        const int64_t d_conv  = hparams.ssm_d_conv;
+        const int64_t d_inner = hparams.ssm_d_inner;
+        const int64_t d_state = hparams.ssm_d_state;
+        const int64_t dt_rank = hparams.ssm_dt_rank;
+        const int64_t n_seqs  = ubatch.n_seqs;
+        // Some variants of the Mamba arch (e.g. FalconMamba) apply RMS norm on the B, C and Dt layers
+        const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+        // Use the same RMS norm as the final layer norm
+        const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        ggml_tensor * conv_states_all = kv_self->k_l[il];
+        ggml_tensor * ssm_states_all  = kv_self->v_l[il];
+
+        // (ab)using the KV cache to store the states
+        ggml_tensor * conv = build_copy_mask_state(
+                gf, conv_states_all, state_copy, state_mask,
+                hparams.n_embd_k_s(), n_seqs);
+        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
+        ggml_tensor * ssm = build_copy_mask_state(
+                gf, ssm_states_all, state_copy, state_mask,
+                hparams.n_embd_v_s(), n_seqs);
+        ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
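+        // per-sequence states gathered from the cache for this ubatch:
+        //   conv: {d_conv - 1, d_inner, n_seqs}, ssm: {d_state, d_inner, n_seqs}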
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+        // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+        ggml_tensor * xz = build_lora_mm(model.layers[il].ssm_in, cur);
+        // split the above in two
+        // => {d_inner, n_seq_tokens, n_seqs}
+        ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
+        ggml_tensor * z = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz));
+
+        // conv
+        {
+            // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+            ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+
+            // copy last (d_conv - 1) columns back into the state cache
+            ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
+
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0, last_conv,
+                    ggml_view_1d(ctx0, conv_states_all,
+                        (d_conv - 1)*(d_inner)*(n_seqs),
+                        kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+
+            // 1D convolution
+            // The equivalent is to make a self-overlapping view of conv_x
+            // over d_conv columns at each stride in the 3rd dimension,
+            // then element-wise multiply that with the conv1d weight,
+            // then sum the elements of each row,
+            // (the last two steps are a dot product over rows (also doable with mul_mat))
+            // then permute away the ne[0] dimension,
+            // and then you're left with the resulting x tensor.
+            // For simultaneous sequences, all sequences need to have the same length.
+            x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+
+            // bias
+            x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+
+            x = ggml_silu(ctx0, x);
+        }
+
+        // ssm
+        {
+            // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+            ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x);
+            // split
+            ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
+            ggml_tensor * B  = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
+            ggml_tensor * C  = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
+
+            // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
+            if (ssm_dt_b_c_rms) {
+                dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
+                B = ggml_rms_norm(ctx0, B, norm_rms_eps);
+                C = ggml_rms_norm(ctx0, C, norm_rms_eps);
+            }
+
+            // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+            dt = build_lora_mm(model.layers[il].ssm_dt, dt);
+            dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
+
+            // Custom operator to optimize the parallel associative scan
+            // as described in Annex D of the Mamba paper.
+            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+            ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C);
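+            // roughly: h_t = exp(dt*A)*h_{t-1} + dt*B*x_t,  y_t = C*h_t  (per channel; the D*x skip term is added below)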
+
+            // store last states
+            ggml_build_forward_expand(gf,
+                ggml_cpy(ctx0,
+                    ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
+                    ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
+
+            ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
+
+            // TODO: skip computing output earlier for unused tokens
+
+            // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
+            y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+            y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+
+            // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+            cur = build_lora_mm(model.layers[il].ssm_out, y);
+        }
+
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+        //cb(cur, "mamba_out", il);
+
+        return cur;
+    }
+};
+
+struct llm_build_command_r : public llm_graph_context {
+    llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
+                            ggml_element_size(Qcur) * n_embd_head,
+                            ggml_element_size(Qcur) * n_embd_head * n_head,
+                            0);
+                    cb(Qcur, "Qcur", il);
+                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
+                            ggml_element_size(Kcur) * n_embd_head,
+                            ggml_element_size(Kcur) * n_embd_head * n_head_kv,
+                            0);
+                    cb(Kcur, "Kcur", il);
+
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            NULL,
+                            LLM_NORM, il);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            NULL,
+                            LLM_NORM, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0,     cur, inp_out_ids);
+                inpL    = ggml_get_rows(ctx0,    inpL, inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
+            ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = build_ffn(ffn_inp,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_cohere2 : public llm_graph_context {
+    llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
+
+        // sliding window switch pattern
+        const int32_t sliding_window_pattern = 4;
+
+        for (int il = 0; il < n_layer; ++il) {
+            // three layers of sliding-window attention (window size 4096) with RoPE,
+            // then a fourth layer of global attention without positional embeddings
+            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
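+            // with a pattern of 4, layers 0..2 in each group are sliding-window (RoPE) and layer 3 is global (no RoPE)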
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // rope freq factors for 128k context
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                if (is_sliding) {
+                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+                            beta_fast, beta_slow);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                            rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                            attn_factor, beta_fast, beta_slow);
+                    cb(Kcur, "Kcur", il);
+                } else {
+                    // For non-sliding layers, just reshape without applying RoPE
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur     = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL    = ggml_get_rows(ctx0, inpL, inp_out_ids);
+                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
+            ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = build_ffn(ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
+                        NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
+                        il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// ref: https://allenai.org/olmo
+// based on the original build_llama() function, changes:
+//   * non-parametric layer norm
+//   * clamp qkv
+//   * removed bias
+//   * removed MoE
+struct llm_build_olmo : public llm_graph_context {
+    llm_build_olmo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    NULL, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, nullptr,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    NULL, NULL,
+                    LLM_NORM, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                NULL, NULL,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_olmo2 : public llm_graph_context {
+    llm_build_olmo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = inpL;
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            cur = build_norm(cur,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_ffn(ffn_inp,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// based on the build_qwen2moe() function, changes:
+//   * removed shared experts
+//   * removed bias
+//   * added q, k norm
+struct llm_build_olmoe : public llm_graph_context {
+    llm_build_olmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, false,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_openelm : public llm_graph_context {
+    llm_build_openelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head    = hparams.n_head(il);
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_qkv = 2*n_head_kv + n_head;
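+            // fused QKV projection: each token's rows are laid out as
+            //   [n_head Q heads | n_head_kv K heads | n_head_kv V heads], each head of size n_embd_head_k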
+
+            cur = inpL;
+            ggml_tensor * residual = cur;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+                cb(Vcur, "Vcur", il);
+
+                Qcur = build_norm(Qcur,
+                        model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = build_norm(Kcur,
+                        model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // norm
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_gptneox : public llm_graph_context {
+    llm_build_gptneox(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
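+                // split the fused QKV output: each row is [Q: n_embd | K: n_embd_gqa | V: n_embd_gqa]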
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // ffn
+            if (hparams.use_par_res) {
+                // attention and ffn are computed in parallel
+                // x = x + attn(ln1(x)) + ffn(ln2(x))
+
+                ggml_tensor * attn_out = cur;
+
+                cur = build_norm(inpL,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, inpL);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, attn_out);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            } else {
+                // attention and ffn are computed sequentially
+                // x = x + attn(ln1(x))
+                // x = x + ffn(ln2(x))
+
+                ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+                cb(ffn_inp, "ffn_inp", il);
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_arctic : public llm_graph_context {
+    llm_build_arctic(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp);
+            cb(ffn_out, "ffn_out", il);
+
+            // MoE
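+            // Arctic runs this residual MoE branch on the normalized layer input (inpSA), in parallel
+            // with the dense FFN above; the two contributions are summed below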
+            cur = build_norm(inpSA,
+                    model.layers[il].ffn_norm_exps, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm_exps", il);
+
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, true,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_out);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_deepseek : public llm_graph_context {
+    llm_build_deepseek(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors; may return nullptr for models that do not use them
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                ggml_tensor * moe_out =
+                    build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            nullptr,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, false,
+                            false, hparams.expert_weights_scale,
+                            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                            il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // FFN shared expert
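+                // the shared expert is always active; its output is added to the routed experts' output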
+                {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_deepseek2 : public llm_graph_context {
+    llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        bool is_lite = (hparams.n_layer == 27);
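+        // the lite variant projects Q directly through wq; the full model uses the low-rank wq_a/wq_b pair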
+
+        // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
+        // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+        const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
+        const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
+        const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
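+        // attn_factor_scaled cancels the default 1 + 0.1*log(1/freq_scale) magnitude correction applied
+        // inside ggml_rope_ext, so the model-specific correction (rope_yarn_log_mul) is applied only once,
+        // via the mscale^2 factor folded into kq_scale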
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // {n_embd, n_tokens}
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
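+                // multi-head latent attention (MLA): Q and KV pass through low-rank bottlenecks
+                // (q_lora_rank / kv_lora_rank), and the RoPE part of each head is kept separate
+                // from the non-RoPE ("nope") part until they are concatenated below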
+                ggml_tensor * q = NULL;
+                if (!is_lite) {
+                    // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+                    cb(q, "q", il);
+
+                    q = build_norm(q,
+                            model.layers[il].attn_q_a_norm, NULL,
+                            LLM_NORM_RMS, il);
+                    cb(q, "q", il);
+
+                    // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                    cb(q, "q", il);
+                } else {
+                    q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                    cb(q, "q", il);
+                }
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * kv_pe_compressed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compressed, "kv_pe_compressed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compressed, kv_lora_rank, n_tokens,
+                        kv_pe_compressed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compressed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compressed->nb[1],
+                        kv_pe_compressed->nb[1],
+                        ggml_row_size(kv_pe_compressed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
+                kv_compressed = build_norm(kv_compressed,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                        ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                        0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                q_pe = ggml_rope_ext(
+                        ctx0, q_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        );
+                cb(q_pe, "q_pe", il);
+
+                // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                k_pe = ggml_rope_ext(
+                        ctx0, k_pe, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        );
+                cb(k_pe, "k_pe", il);
+
+                ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        q_states, k_states, v_states, nullptr, kq_scale, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                ggml_tensor * moe_out =
+                    build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
+                            true, hparams.expert_weights_scale,
+                            (llama_expert_gating_func_type) hparams.expert_gating_func,
+                            il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // FFN shared expert
+                {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_bitnet : public llm_graph_context {
+    llm_build_bitnet(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
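+                // BitNet quantized projections carry optional per-tensor scales (wq_scale, wk_scale, ...)
+                // that are applied after each matmul when present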
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                if (model.layers[il].wq_scale) {
+                    Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+                }
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                // B1.K
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                if (model.layers[il].wk_scale) {
+                    Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+                }
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                // B1.V
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                if (model.layers[il].wv_scale) {
+                    Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+                }
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        NULL, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+
+                cur = build_norm(cur,
+                        model.layers[il].attn_sub_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_sub_norm", il);
+
+                cur = build_lora_mm(model.layers[il].wo, cur);
+                if (model.layers[il].wo_scale) {
+                    cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+                }
+                if (model.layers[il].bo) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bo);
+                }
+                cb(cur, "attn_o_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
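+            // gate/up only here; the down projection is applied manually below, after ffn_sub_norm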
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_scale,
+                    model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
+                    NULL,                      NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_sub_out", il);
+
+            cur = build_norm(cur,
+                    model.layers[il].ffn_sub_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_sub_norm", il);
+
+            cur = build_lora_mm(model.layers[il].ffn_down, cur);
+            if (model.layers[il].ffn_down_scale) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            }
+            cb(cur, "ffn_down", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        // FIXME: do not use model.tok_embd directly, duplicate as model.output
+        cur = build_lora_mm(model.tok_embd, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_t5_enc : public llm_graph_context {
+    llm_build_t5_enc(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc();
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm_enc, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
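+                // T5 relative position bias: layers without their own attn_rel_b_enc reuse the one from
+                // layer 0; the bucketed bias is added to the attention scores (kq_b) instead of scaling
+                // by 1/sqrt(d) (scale = 1.0f)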
+                ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
+                ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo_enc, nullptr,
+                        Qcur, Kcur, Vcur, kq_b, 1.0f, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm_enc, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                // T5 uses relu, flan-T5 uses gelu-gated
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up_enc,   NULL, NULL,
+                        model.layers[il].ffn_gate_enc, NULL, NULL,
+                        model.layers[il].ffn_down_enc, NULL, NULL,
+                        NULL,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
+                        il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        cb(cur, "result_embd", -1);
+
+        cur = build_norm(cur,
+                model.output_norm_enc, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_t5_dec : public llm_graph_context {
+    llm_build_t5_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        //const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * embd_enc       = build_inp_cross_embd();
+        ggml_tensor * pos_bucket_dec = build_inp_pos_bucket_dec();
+
+        const int64_t n_outputs_enc = embd_enc->ne[1];
+
+        auto * inp_attn_self  = build_attn_inp_kv_unified(true, false);
+        auto * inp_attn_cross = build_attn_inp_cross();
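+        // the decoder alternates self-attention over previously generated tokens (unified KV cache)
+        // with cross-attention over the encoder output (embd_enc, n_outputs_enc columns)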
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
+                ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b);
+
+                cur = build_attn(inp_attn_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, kq_b, 1.0f, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, inpSA);
+            cb(cur, "cross_inp", il);
+
+            ggml_tensor * inpCA = cur;
+
+            // norm
+            cur = build_norm(cur,
+                    model.layers[il].attn_norm_cross, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm_cross", il);
+
+            // cross-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc);
+
+                cur = build_attn(inp_attn_cross, gf,
+                        model.layers[il].wo_cross, nullptr,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f, il);
+                cb(cur, "kqv_out", il);
+
+                //ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                //ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+                //ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                //cb(kq, "kq", il);
+
+                //kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
+                //cb(kq, "kq_soft_max_ext", il);
+
+                //ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
+                //cb(v, "v", il);
+
+                //ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
+                //cb(kqv, "kqv", il);
+
+                //ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                //cb(kqv_merged, "kqv_merged", il);
+
+                //cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                //cb(cur, "kqv_merged_cont", il);
+
+                //ggml_build_forward_expand(gf, cur);
+
+                //cur = build_lora_mm(model.layers[il].wo_cross, cur);
+                //cb(cur, "kqv_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                // T5 uses relu, flan-T5 uses gelu-gated
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
+                        model.layers[il].ffn_gate ? LLM_FFN_PAR  : LLM_FFN_SEQ,
+                        il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        cb(cur, "result_embd", -1);
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_jais : public llm_graph_context {
+    llm_build_jais(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = build_lora_mm(model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
+                ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
+                ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
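+                // note: attention scores are scaled by 1/n_embd_head here, not 1/sqrt(n_embd_head)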
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_chatglm : public llm_graph_context {
+    llm_build_chatglm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = nullptr;
+                ggml_tensor * Kcur = nullptr;
+                ggml_tensor * Vcur = nullptr;
+
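+                // some GLM checkpoints ship separate wq/wk/wv projections, others a single fused wqkv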
+                if (model.layers[il].wqkv == nullptr) {
+                    Qcur = build_lora_mm(model.layers[il].wq, cur);
+                    if (model.layers[il].bq) {
+                        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    }
+                    Kcur = build_lora_mm(model.layers[il].wk, cur);
+                    if (model.layers[il].bk) {
+                        Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    }
+                    Vcur = build_lora_mm(model.layers[il].wv, cur);
+                    if (model.layers[il].bv) {
+                        Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    }
+                } else {
+                    cur = build_lora_mm(model.layers[il].wqkv, cur);
+                    cb(cur, "wqkv", il);
+                    if (model.layers[il].bqkv) {
+                        cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                        cb(cur, "bqkv", il);
+                    }
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // Add the input
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm,
+                        NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
+
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = build_norm(inpL,
+                model.output_norm,
+                NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_nemotron : public llm_graph_context {
+    llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        //GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm,
+                    model.layers[il].ffn_norm_b,
+                    LLM_NORM, il);
+            cb(cur, "ffn_norm", il);
+
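+            // Nemotron uses a squared-ReLU FFN without a gate projection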
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    NULL,                      NULL,                        NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_exaone : public llm_graph_context {
+    llm_build_exaone(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope frequency factors (e.g. llama3-style scaling); may be nullptr for models without them
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_rwkv6_base : public llm_graph_context {
+    const llama_model & model;
+
+    llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {
+    }
+
+    ggml_tensor * build_rwkv6_channel_mix(
+            const llama_layer * layer,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            llm_arch arch) const {
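+        // sx is the token-shift delta: the previous token's activations minus the current ones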
+        ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+        switch (arch) {
+            case LLM_ARCH_RWKV6:
+                {
+                    ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
+                    ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur);
+
+                    ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr));
+                    ggml_tensor * k = ggml_sqr(
+                            ctx0,
+                            ggml_relu(
+                                ctx0,
+                                build_lora_mm(layer->channel_mix_key, xk)
+                                )
+                            );
+                    cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k));
+                } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
+
+        return cur;
+    }
+
+    ggml_tensor * build_rwkv6_time_mix(
+            ggml_cgraph * gf,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int   il) const {
+        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+        const auto n_tokens = ubatch.n_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+        const auto n_embd = hparams.n_embd;
+        const auto head_size = hparams.wkv_head_size;
+        const auto n_head = n_embd / head_size;
+        const auto n_head_kv = hparams.n_head_kv(il);
+
+        const auto kv_head = kv_self->head;
+
+        const auto & layer = model.layers[il];
+
+        bool is_qrwkv = layer.time_mix_first == nullptr;
+
+        ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
+        ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer.time_mix_lerp_x), cur);
+
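+        // project the shift delta through time_mix_w1/w2 to obtain 5 sets of lerp coefficients (for w, k, v, r and g)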
+        xxx = ggml_reshape_4d(
+                ctx0,
+                ggml_tanh(
+                    ctx0,
+                    ggml_mul_mat(ctx0, layer.time_mix_w1, xxx)
+                    ),
+                layer.time_mix_w1->ne[1] / 5, 1, 5, n_tokens
+                );
+
+        xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));
+
+        xxx = ggml_mul_mat(
+                ctx0,
+                ggml_reshape_4d(
+                    ctx0,
+                    layer.time_mix_w2,
+                    layer.time_mix_w2->ne[0], layer.time_mix_w2->ne[1], 1, 5
+                    ),
+                xxx
+                );
+
+        ggml_tensor *xw, *xk, *xv, *xr, *xg;
+        if (layer.time_mix_lerp_fused) {
+            // fusing these weights gives a small performance improvement
+            sx  = ggml_reshape_3d(ctx0, sx,  n_embd, 1, n_tokens);
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+            xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer.time_mix_lerp_fused), sx), cur);
+            xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+            xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+            xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+            xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+            xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+        } else {
+            // for backward compatibility
+            xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+            xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+            xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+            xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+            xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+
+            xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer.time_mix_lerp_w), sx), cur);
+            xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer.time_mix_lerp_k), sx), cur);
+            xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer.time_mix_lerp_v), sx), cur);
+            xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer.time_mix_lerp_r), sx), cur);
+            xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer.time_mix_lerp_g), sx), cur);
+        }
+
+        ggml_tensor * r = build_lora_mm(layer.time_mix_receptance, xr);
+        ggml_tensor * k = build_lora_mm(layer.time_mix_key,        xk);
+        ggml_tensor * v = build_lora_mm(layer.time_mix_value,      xv);
+        if (layer.time_mix_receptance_b) {
+            r = ggml_add(ctx0, r, layer.time_mix_receptance_b);
+        }
+        if (layer.time_mix_key_b) {
+            k = ggml_add(ctx0, k, layer.time_mix_key_b);
+        }
+        if (layer.time_mix_value_b) {
+            v = ggml_add(ctx0, v, layer.time_mix_value_b);
+        }
+
+        ggml_tensor * g = build_lora_mm(layer.time_mix_gate, xg);
+        if (is_qrwkv) {
+            g = ggml_sigmoid(ctx0, g);
+        } else {
+            g = ggml_silu(ctx0, g);
+        }
+
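+        // if the model has fewer K/V heads than attention heads, repeat them GQA-style to match n_head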
+        if (n_head_kv != 0 && n_head_kv != n_head) {
+            GGML_ASSERT(n_head % n_head_kv == 0);
+            k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
+            v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
+            ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
+            k = ggml_repeat(ctx0, k, tmp);
+            v = ggml_repeat(ctx0, v, tmp);
+        }
+
+        k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
+        v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
+        r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);
+
+        ggml_tensor * w = ggml_mul_mat(
+                ctx0,
+                layer.time_mix_decay_w2,
+                ggml_tanh(
+                    ctx0,
+                    ggml_mul_mat(ctx0, layer.time_mix_decay_w1, xw)
+                    )
+                );
+
+        w = ggml_add(ctx0, w, layer.time_mix_decay);
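+        // per-channel decay in (0, 1): w = exp(-exp(w))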
+        w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
+        w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);
+
+        if (is_qrwkv) {
+            // k = k * (1 - w)
+            k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
+        }
+
+        ggml_tensor * wkv_state = build_copy_mask_state(
+                gf, kv_self->v_l[il], state_copy, state_mask,
+                hparams.n_embd_v_s(), n_seqs);
+
+        ggml_tensor * wkv_output;
+        if (is_qrwkv) {
+            wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
+        } else {
+            wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer.time_mix_first, w, wkv_state);
+        }
+        cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
+        wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
+
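+        // copy the updated wkv state back into this layer's recurrent state slot in the cache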
+        ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_state,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self->v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il])
+                        )
+                    )
+                );
+
+        if (!is_qrwkv) {
+            // group norm with head_count groups
+            cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
+            cur = ggml_norm(ctx0, cur, 64e-5f);
+
+            // Convert back to regular vectors.
+            cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.time_mix_ln), layer.time_mix_ln_b);
+        } else {
+            cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        }
+
+        cur = ggml_mul(ctx0, cur, g);
+        cur = build_lora_mm(layer.time_mix_output, cur);
+
+        return cur;
+    }
+};
+
+struct llm_build_rwkv6 : public llm_build_rwkv6_base {
+    llm_build_rwkv6(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
+        GGML_ASSERT(hparams.token_shift_count == 2);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
+
+        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_mask = build_inp_s_mask();
+
+        const auto n_embd = hparams.n_embd;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(
+                    gf, state_copy, state_mask, ubatch, il
+                    );
+
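+            // the loaded token shift holds two states per sequence: one for the time mix (attn) and one for the channel mix (ffn)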
+            ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
+            ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
+
+            ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM, il);
+            cb(att_norm, "attn_norm", il);
+
+            ggml_tensor * x_prev = ggml_concat(
+                    ctx0,
+                    att_shift,
+                    ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
+                    1
+                    );
+
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            ggml_tensor * ffn_norm = build_norm(ffn_inp, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, il);
+            cb(ffn_norm, "ffn_norm", il);
+
+            x_prev = ggml_concat(
+                    ctx0,
+                    ffn_shift,
+                    ggml_view_3d(ctx0, ffn_norm, n_embd, n_seq_tokens - 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], 0),
+                    1
+                    );
+
+            cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            token_shift = ggml_concat(ctx0,
+                    ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)),
+                    ggml_view_3d(ctx0, ffn_norm, n_embd, 1, n_seqs, ffn_norm->nb[1], ffn_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(ffn_norm)),
+                    1
+                    );
+            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
+
+            if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
+                cur = ggml_scale(ctx0, cur, 0.5F);
+            }
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
+struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
+    llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        ggml_tensor * state_copy = build_inp_s_copy();
+        ggml_tensor * state_mask = build_inp_s_mask();
+
+        const auto n_embd = hparams.n_embd;
+        const auto n_seq_tokens = ubatch.n_seq_tokens;
+        const auto n_seqs = ubatch.n_seqs;
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(
+                    gf, state_copy, state_mask, ubatch, il
+                    );
+
+            ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
+            cb(att_norm, "attn_norm", il);
+
+            ggml_tensor * x_prev = ggml_concat(
+                    ctx0,
+                    token_shift,
+                    ggml_view_3d(ctx0, att_norm, n_embd, n_seq_tokens - 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 0),
+                    1
+                    );
+
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+
+            token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
+            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = build_norm(cur, model.output_norm, model.output_norm_b, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+// ref: https://github.com/facebookresearch/chameleon
+// based on the original build_llama() function, changes:
+//   * qk-norm
+//   * swin-norm
+//   * removed bias
+//   * removed MoE
+struct llm_build_chameleon : public llm_graph_context {
+    llm_build_chameleon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified(true, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            if (hparams.swin_norm) {
+                cur = inpL;
+            } else {
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
+                            ggml_element_size(Qcur) * n_embd_head,
+                            ggml_element_size(Qcur) * n_embd_head * n_head,
+                            0);
+                    cb(Qcur, "Qcur", il);
+
+                    Qcur = build_norm(Qcur,
+                            model.layers[il].attn_q_norm,
+                            model.layers[il].attn_q_norm_b,
+                            LLM_NORM, il);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
+                            ggml_element_size(Kcur) * n_embd_head,
+                            ggml_element_size(Kcur) * n_embd_head * n_head_kv,
+                            0);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = build_norm(Kcur,
+                            model.layers[il].attn_k_norm,
+                            model.layers[il].attn_k_norm_b,
+                            LLM_NORM, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, nullptr,
+                        Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+
+                if (hparams.swin_norm) {
+                    cur = build_norm(cur,
+                            model.layers[il].attn_norm, NULL,
+                            LLM_NORM_RMS, il);
+                }
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            if (!hparams.swin_norm) {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+            }
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            if (hparams.swin_norm) {
+                cur = build_norm(cur,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output_with_img_logits", -1);
+
+        // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.
+        // Needs to be removed once image outputs are supported.
+        int img_token_end_idx = 8196;
+        int img_token_start_idx = 4;
+        int num_img_tokens = img_token_end_idx - img_token_start_idx;
+        // create a 1d tensor of size num_img_tokens filled with -FLT_MAX,
+        // which ensures that the text token logits are always larger than the image token logits
+        ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens);
+        img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX);
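+        // clamping with min == max == -FLT_MAX overwrites the uninitialized values, so no explicit fill is needed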
+        cb(img_logits, "img_logits", -1);
+
+        cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_wavtokenizer_dec : public llm_graph_context {
+    llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
+
+        cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1);
+        cur = ggml_add(ctx0, cur, model.conv1d_b);
+
+        // posnet
+        for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
+            const auto & layer = model.layers[il].posnet;
+
+            inpL = cur;
+
+            switch (il) {
+                case 0:
+                case 1:
+                case 3:
+                case 4:
+                    {
+                        cur = build_norm(cur,
+                                layer.norm1,
+                                layer.norm1_b,
+                                LLM_NORM_GROUP, 0);
+
+                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.conv1_b);
+
+                        cur = build_norm(cur,
+                                layer.norm2,
+                                layer.norm2_b,
+                                LLM_NORM_GROUP, 0);
+
+                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.conv2_b);
+
+                        cur = ggml_add(ctx0, cur, inpL);
+                    } break;
+                case 2:
+                    {
+                        cur = build_norm(cur,
+                                layer.attn_norm,
+                                layer.attn_norm_b,
+                                LLM_NORM_GROUP, 0);
+
+                        ggml_tensor * q;
+                        ggml_tensor * k;
+                        ggml_tensor * v;
+
+                        q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1);
+                        k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1);
+                        v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1);
+
+                        q = ggml_add(ctx0, q, layer.attn_q_b);
+                        k = ggml_add(ctx0, k, layer.attn_k_b);
+                        v = ggml_add(ctx0, v, layer.attn_v_b);
+
+                        q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
+                        k = ggml_cont(ctx0, ggml_transpose(ctx0, k));
+
+                        ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+
+                        kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);
+
+                        cur = ggml_mul_mat(ctx0, kq, v);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.attn_o_b);
+
+                        cur = ggml_add(ctx0, cur, inpL);
+                    } break;
+                case 5:
+                    {
+                        cur = build_norm(cur,
+                                layer.norm,
+                                layer.norm_b,
+                                LLM_NORM_GROUP, 0);
+                    } break;
+                default: GGML_ABORT("unknown posnet layer");
+            };
+        }
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        cur = build_norm(cur,
+                model.tok_norm,
+                model.tok_norm_b,
+                LLM_NORM, -1);
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        inpL = cur;
+
+        // convnext
+        for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
+            const auto & layer = model.layers[il].convnext;
+
+            cur = inpL;
+
+            cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
+            cur = ggml_add(ctx0, cur, layer.dw_b);
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            cur = build_norm(cur,
+                    layer.norm,
+                    layer.norm_b,
+                    LLM_NORM, -1);
+
+            cur = build_ffn(cur,
+                    layer.pw1, layer.pw1_b, NULL,
+                    NULL,      NULL,        NULL,
+                    layer.pw2, layer.pw2_b, NULL,
+                    NULL,
+                    LLM_FFN_GELU, LLM_FFN_SEQ, il);
+
+            cur = ggml_mul(ctx0, cur, layer.gamma);
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            inpL = ggml_add(ctx0, cur, inpL);
+        }
+
+        cur = inpL;
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        cur = build_norm(cur,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, -1);
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cur = ggml_add(ctx0, cur, model.output_b);
+
+        cb(cur, "result_embd", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+llama_memory_i * llama_model::create_memory() const {
+    llama_memory_i * res;
+
+    switch (arch) {
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_MAMBA:
+            {
+                res = new llama_kv_cache_unified(hparams, {
+                    /*.get_rope_factors =*/ nullptr
+                });
+            } break;
+        default:
+            {
+                res = new llama_kv_cache_unified(hparams, {
+                    /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
+                        // choose long/short freq factors based on the context size
+                        if (layers[il].rope_freqs != nullptr) {
+                            return layers[il].rope_freqs;
+                        }
+
+                        if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+                            return layers[il].rope_long;
+                        }
+
+                        return layers[il].rope_short;
+                    }
+                });
+            }
+    }
+
+    return res;
+}
+
+llm_graph_result_ptr llama_model::build_graph(
+        const llm_graph_params & params,
+                   ggml_cgraph * gf,
+                llm_graph_type   type) const {
+    std::unique_ptr<llm_graph_context> llm;
+
+    switch (arch) {
+        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+            {
+                llm = std::make_unique<llm_build_llama>(*this, params, gf);
+            } break;
+        case LLM_ARCH_DECI:
+            {
+                llm = std::make_unique<llm_build_deci>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BAICHUAN:
+            {
+                llm = std::make_unique<llm_build_baichuan>(*this, params, gf);
+            } break;
+        case LLM_ARCH_FALCON:
+            {
+                llm = std::make_unique<llm_build_falcon>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GROK:
+            {
+                llm = std::make_unique<llm_build_grok>(*this, params, gf);
+            } break;
+        case LLM_ARCH_STARCODER:
+            {
+                llm = std::make_unique<llm_build_starcoder>(*this, params, gf);
+            } break;
+        case LLM_ARCH_REFACT:
+            {
+                llm = std::make_unique<llm_build_refact>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_NOMIC_BERT:
+            {
+                llm = std::make_unique<llm_build_bert>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BLOOM:
+            {
+                llm = std::make_unique<llm_build_bloom>(*this, params, gf);
+            } break;
+        case LLM_ARCH_MPT:
+            {
+                llm = std::make_unique<llm_build_mpt>(*this, params, gf);
+            } break;
+        case LLM_ARCH_STABLELM:
+            {
+                llm = std::make_unique<llm_build_stablelm>(*this, params, gf);
+            } break;
+        case LLM_ARCH_QWEN:
+            {
+                llm = std::make_unique<llm_build_qwen>(*this, params, gf);
+            } break;
+        case LLM_ARCH_QWEN2:
+            {
+                llm = std::make_unique<llm_build_qwen2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_QWEN2VL:
+            {
+                llm = std::make_unique<llm_build_qwen2vl>(*this, params, gf);
+            } break;
+        case LLM_ARCH_QWEN2MOE:
+            {
+                llm = std::make_unique<llm_build_qwen2moe>(*this, params, gf);
+            } break;
+        case LLM_ARCH_PHI2:
+            {
+                llm = std::make_unique<llm_build_phi2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
+            {
+                llm = std::make_unique<llm_build_phi3>(*this, params, gf);
+            } break;
+        case LLM_ARCH_PLAMO:
+            {
+                llm = std::make_unique<llm_build_plamo>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GPT2:
+            {
+                llm = std::make_unique<llm_build_gpt2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_CODESHELL:
+            {
+                llm = std::make_unique<llm_build_codeshell>(*this, params, gf);
+            } break;
+        case LLM_ARCH_ORION:
+            {
+                llm = std::make_unique<llm_build_orion>(*this, params, gf);
+            } break;
+        case LLM_ARCH_INTERNLM2:
+            {
+                llm = std::make_unique<llm_build_internlm2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_MINICPM3:
+            {
+                llm = std::make_unique<llm_build_minicpm3>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GEMMA:
+            {
+                llm = std::make_unique<llm_build_gemma>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GEMMA2:
+            {
+                llm = std::make_unique<llm_build_gemma2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GEMMA3:
+            {
+                llm = std::make_unique<llm_build_gemma3>(*this, params, gf);
+            } break;
+        case LLM_ARCH_STARCODER2:
+            {
+                llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_MAMBA:
+            {
+                llm = std::make_unique<llm_build_mamba>(*this, params, gf);
+            } break;
+        case LLM_ARCH_XVERSE:
+            {
+                llm = std::make_unique<llm_build_xverse>(*this, params, gf);
+            } break;
+        case LLM_ARCH_COMMAND_R:
+            {
+                llm = std::make_unique<llm_build_command_r>(*this, params, gf);
+            } break;
+        case LLM_ARCH_COHERE2:
+            {
+                llm = std::make_unique<llm_build_cohere2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_DBRX:
+            {
+                llm = std::make_unique<llm_build_dbrx>(*this, params, gf);
+            } break;
+        case LLM_ARCH_OLMO:
+            {
+                llm = std::make_unique<llm_build_olmo>(*this, params, gf);
+            } break;
+        case LLM_ARCH_OLMO2:
+            {
+                llm = std::make_unique<llm_build_olmo2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_OLMOE:
+            {
+                llm = std::make_unique<llm_build_olmoe>(*this, params, gf);
+            } break;
+        case LLM_ARCH_OPENELM:
+            {
+                llm = std::make_unique<llm_build_openelm>(*this, params, gf);
+            } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                llm = std::make_unique<llm_build_gptneox>(*this, params, gf);
+            } break;
+        case LLM_ARCH_ARCTIC:
+            {
+                llm = std::make_unique<llm_build_arctic>(*this, params, gf);
+            } break;
+        case LLM_ARCH_DEEPSEEK:
+            {
+                llm = std::make_unique<llm_build_deepseek>(*this, params, gf);
+            } break;
+        case LLM_ARCH_DEEPSEEK2:
+            {
+                llm = std::make_unique<llm_build_deepseek2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_CHATGLM:
+            {
+                llm = std::make_unique<llm_build_chatglm>(*this, params, gf);
+            } break;
+        case LLM_ARCH_BITNET:
+            {
+                llm = std::make_unique<llm_build_bitnet>(*this, params, gf);
+            } break;
+        case LLM_ARCH_T5:
+            {
+                switch (type) {
+                    case LLM_GRAPH_TYPE_ENCODER:
+                        llm = std::make_unique<llm_build_t5_enc>(*this, params, gf);
+                        break;
+                    case LLM_GRAPH_TYPE_DEFAULT:
+                    case LLM_GRAPH_TYPE_DECODER:
+                        llm = std::make_unique<llm_build_t5_dec>(*this, params, gf);
+                        break;
+                    default:
+                        GGML_ABORT("invalid graph type");
+                };
+            } break;
+            //case LLM_ARCH_T5ENCODER:
+            //    {
+            //        llm.build_t5_enc(gf);
+            //    } break;
+        case LLM_ARCH_JAIS:
+            {
+                llm = std::make_unique<llm_build_jais>(*this, params, gf);
+            } break;
+        case LLM_ARCH_NEMOTRON:
+            {
+                llm = std::make_unique<llm_build_nemotron>(*this, params, gf);
+            } break;
+        case LLM_ARCH_EXAONE:
+            {
+                llm = std::make_unique<llm_build_exaone>(*this, params, gf);
+            } break;
+        case LLM_ARCH_RWKV6:
+            {
+                llm = std::make_unique<llm_build_rwkv6>(*this, params, gf);
+            } break;
+        case LLM_ARCH_RWKV6QWEN2:
+            {
+                llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params, gf);
+            } break;
+        case LLM_ARCH_CHAMELEON:
+            {
+                llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
+            } break;
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            {
+                llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+
+    // add on pooling layer
+    llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b);
+
+    return std::move(llm->res);
+}
+
+//
+// interface implementation
+//
+
+llama_model_params llama_model_default_params() {
+    llama_model_params result = {
+        /*.devices                     =*/ nullptr,
+        /*.n_gpu_layers                =*/ 0,
+        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
+        /*.main_gpu                    =*/ 0,
+        /*.tensor_split                =*/ nullptr,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
+        /*.kv_overrides                =*/ nullptr,
+        /*.vocab_only                  =*/ false,
+        /*.use_mmap                    =*/ true,
+        /*.use_mlock                   =*/ false,
+        /*.check_tensors               =*/ false,
+    };
+
+#ifdef GGML_USE_METAL
+    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
+    result.n_gpu_layers = 999;
+#endif
+
+    return result;
+}
+
+const llama_vocab * llama_model_get_vocab(const llama_model * model) {
+    return &model->vocab;
+}
+
+void llama_free_model(llama_model * model) {
+    llama_model_free(model);
+}
+
+void llama_model_free(llama_model * model) {
+    delete model;
+}
+
+int32_t llama_model_n_ctx_train(const llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_model_n_embd(const llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_model_n_layer(const llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+int32_t llama_model_n_head(const llama_model * model) {
+    return model->hparams.n_head();
+}
+
+int32_t llama_model_n_head_kv(const llama_model * model) {
+    return model->hparams.n_head_kv();
+}
+
+// deprecated
+int32_t llama_n_ctx_train(const llama_model * model) {
+    return llama_model_n_ctx_train(model);
+}
+
+// deprecated
+int32_t llama_n_embd(const llama_model * model) {
+    return llama_model_n_embd(model);
+}
+
+// deprecated
+int32_t llama_n_layer(const llama_model * model) {
+    return llama_model_n_layer(model);
+}
+
+// deprecated
+int32_t llama_n_head(const llama_model * model) {
+    return llama_model_n_head(model);
+}
+
+llama_rope_type llama_model_rope_type(const llama_model * model) {
+    switch (model->arch) {
+        // these models do not use RoPE
+        case LLM_ARCH_GPT2:
+        case LLM_ARCH_GPTJ:
+        case LLM_ARCH_MPT:
+        case LLM_ARCH_REFACT:
+        case LLM_ARCH_BLOOM:
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_T5:
+        case LLM_ARCH_T5ENCODER:
+        case LLM_ARCH_JAIS:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            return LLAMA_ROPE_TYPE_NONE;
+
+        // use what we call a normal RoPE, operating on pairs of consecutive head values
+        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_DECI:
+        case LLM_ARCH_BAICHUAN:
+        case LLM_ARCH_STARCODER:
+        case LLM_ARCH_PLAMO:
+        case LLM_ARCH_ORION:
+        case LLM_ARCH_INTERNLM2:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_XVERSE:
+        case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_COHERE2:
+        case LLM_ARCH_OLMO:
+        case LLM_ARCH_ARCTIC:
+        case LLM_ARCH_DEEPSEEK:
+        case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_CHATGLM:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_CHAMELEON:
+            return LLAMA_ROPE_TYPE_NORM;
+
+        // the pairs of head values are offset by n_rot/2
+        case LLM_ARCH_FALCON:
+        case LLM_ARCH_GROK:
+        case LLM_ARCH_DBRX:
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_STABLELM:
+        case LLM_ARCH_BITNET:
+        case LLM_ARCH_QWEN:
+        case LLM_ARCH_QWEN2:
+        case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_OLMO2:
+        case LLM_ARCH_OLMOE:
+        case LLM_ARCH_PHI2:
+        case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
+        case LLM_ARCH_GEMMA:
+        case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_GEMMA3:
+        case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_OPENELM:
+        case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_CODESHELL:
+        case LLM_ARCH_NEMOTRON:
+        case LLM_ARCH_EXAONE:
+        case LLM_ARCH_MINICPM3:
+            return LLAMA_ROPE_TYPE_NEOX;
+
+        case LLM_ARCH_QWEN2VL:
+            return LLAMA_ROPE_TYPE_MROPE;
+
+        // all model arches should be listed explicitly here
+        case LLM_ARCH_UNKNOWN:
+            GGML_ABORT("unknown architecture");
+    }
+
+    return LLAMA_ROPE_TYPE_NONE;
+}
+
+float llama_model_rope_freq_scale_train(const llama_model * model) {
+    return model->hparams.rope_freq_scale_train;
+}
+
+int32_t llama_model_meta_val_str(const llama_model * model, const char * key, char * buf, size_t buf_size) {
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_model_meta_count(const llama_model * model) {
+    return (int)model->gguf_kv.size();
+}
+
+int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_model_meta_val_str_by_index(const llama_model * model, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_model_desc(const llama_model * model, char * buf, size_t buf_size) {
+    return snprintf(buf, buf_size, "%s", model->desc().c_str());
+}
+
+uint64_t llama_model_size(const llama_model * model) {
+    return model->size();
+}
+
+const char * llama_model_chat_template(const llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        return nullptr;
+    }
+
+    return it->second.c_str();
+}
+
+uint64_t llama_model_n_params(const llama_model * model) {
+    return model->n_elements();
+}
+
+bool llama_model_has_encoder(const llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5:        return true;
+        case LLM_ARCH_T5ENCODER: return true;
+        default:                 return false;
+    }
+}
+
+bool llama_model_has_decoder(const llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5ENCODER: return false;
+        default:                 return true;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const llama_model * model) {
+    return model->hparams.dec_start_token_id;
+}
+
+bool llama_model_is_recurrent(const llama_model * model) {
+    switch (model->arch) {
+        case     LLM_ARCH_MAMBA:      return true;
+        case     LLM_ARCH_RWKV6:      return true;
+        case     LLM_ARCH_RWKV6QWEN2: return true;
+        default: return false;
+    }
+}
+
+const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
+    return model->tensors_by_name;
 }
index a7c30444786fdf5aeeb255c3c35d01d0fcf4cb7a..55c26a92b02d25656142badd7c0d5e8781bce606 100644 (file)
@@ -2,7 +2,9 @@
 
 #include "llama.h"
 #include "llama-arch.h"
+#include "llama-graph.h"
 #include "llama-hparams.h"
+#include "llama-memory.h"
 #include "llama-vocab.h"
 
 #include <memory>
@@ -10,6 +12,8 @@
 #include <unordered_map>
 #include <vector>
 
+struct llama_cparams;
+struct llama_ubatch;
 struct llama_model_loader;
 
 // available models
@@ -347,7 +351,7 @@ struct llama_model {
     std::string desc() const;
 
     size_t size() const;
-    size_t max_nodes() const;
+    size_t n_tensors() const;
     size_t n_devices() const;
 
     // total number of parameters in the model
@@ -362,9 +366,22 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
+    // TODO: move this to new llm_arch_model_i interface
+    llama_memory_i * create_memory() const; // TODO: params
+
+    // TODO: move this to new llm_arch_model_i interface
+    llm_graph_result_ptr build_graph(
+            const llm_graph_params & params,
+                       ggml_cgraph * gf,
+                    llm_graph_type   type) const;
+
 private:
     struct impl;
     std::unique_ptr<impl> pimpl;
 };
 
 const char * llm_type_name(llm_type type);
+
+// For internal test use
+// TODO: remove
+const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);
index 4a4e91490107c1f9384ca601d81a092443c60610..81e1dd1d0873a1ecc9617ccdc60b25082711f41e 100644 (file)
 
 #include "llama-chat.h"
 #include "llama-mmap.h"
-#include "llama-context.h"
 #include "llama-vocab.h"
-#include "llama-sampling.h"
-#include "llama-kv-cache.h"
 #include "llama-model-loader.h"
 #include "llama-model.h"
 
 #include "ggml.h"
-#include "ggml-alloc.h"
 #include "ggml-backend.h"
-#include "ggml-cpp.h"
 
 #include <algorithm>
-#include <array>
-#include <cassert>
-#include <cfloat>
-#include <cmath>
 #include <cstddef>
 #include <cstdint>
 #include <cstdio>
 #include <cstring>
 #include <ctime>
-#include <functional>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
-    // loading time will be recalculated after the first eval, so
-    // we take page faults deferred by mmap() into consideration
-    model.t_load_us = 0;
-    time_meas tm(model.t_load_us);
-
-    model.t_start_us = tm.t_start_us;
-
-    try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
-
-        ml.print_info();
-
-        model.hparams.vocab_only = params.vocab_only;
-
-        try {
-            model.load_arch(ml);
-        } catch(const std::exception & e) {
-            throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
-        }
-        try {
-            model.load_hparams(ml);
-        } catch(const std::exception & e) {
-            throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
-        }
-        try {
-            model.load_vocab(ml);
-        } catch(const std::exception & e) {
-            throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
-        }
-
-        model.load_stats(ml);
-        model.print_info();
-
-        if (params.vocab_only) {
-            LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
-            return 0;
-        }
-
-        if (!model.load_tensors(ml)) {
-            return -2;
-        }
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
-        return -1;
-    }
-
-    return 0;
-}
-
-//
-// llm_build
-//
-
-using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
-
-enum llm_ffn_op_type {
-    LLM_FFN_SILU,
-    LLM_FFN_GELU,
-    LLM_FFN_RELU,
-    LLM_FFN_RELU_SQR,
-    LLM_FFN_SWIGLU,
-};
-
-enum llm_ffn_gate_type {
-    LLM_FFN_SEQ,
-    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
-};
-
-enum llm_norm_type {
-    LLM_NORM,
-    LLM_NORM_RMS,
-    LLM_NORM_GROUP,
-};
-
-static struct ggml_tensor * llm_build_inp_embd(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-        const llama_hparams & hparams,
-         const llama_ubatch & ubatch,
-         struct ggml_tensor * tok_embd,
-         const llm_build_cb & cb) {
-    const int64_t n_embd = hparams.n_embd;
-
-    struct ggml_tensor * inpL;
-
-    if (ubatch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
-        cb(lctx.inp_tokens, "inp_tokens", -1);
-        ggml_set_input(lctx.inp_tokens);
-
-        inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
-
-        // apply lora for embedding tokens if needed
-        for (auto & it : lctx.lora) {
-            struct llama_adapter_lora_weight * lw = it.first->get_weight(tok_embd);
-            if (lw == nullptr) {
-                continue;
-            }
-            const float adapter_scale = it.second;
-            const float scale = lw->get_scale(it.first->alpha, adapter_scale);
-            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
-                ctx, lw->b, // non-transposed lora_b
-                ggml_get_rows(ctx, lw->a, lctx.inp_tokens)
-            ), scale);
-            inpL = ggml_add(ctx, inpL, inpL_delta);
-        }
-    } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
-        inpL = lctx.inp_embd;
-        ggml_set_input(lctx.inp_embd);
-    }
-
-    // For Granite architecture
-    if (hparams.f_embedding_scale != 0.0f) {
-        inpL = ggml_scale(ctx, inpL, hparams.f_embedding_scale);
-    }
-
-    cb(inpL, "inp_embd", -1);
-
-    return inpL;
-}
-
-static void llm_build_kv_store(
-        struct ggml_context * ctx,
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-       const llama_kv_cache & kv,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * k_cur,
-         struct ggml_tensor * v_cur,
-                    int32_t   n_tokens,
-                    int32_t   kv_head,
-         const llm_build_cb & cb,
-                    int64_t   il) {
-    const int64_t n_ctx = cparams.n_ctx;
-
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-    GGML_ASSERT(kv.size == n_ctx);
-
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa)*kv_head);
-    cb(k_cache_view, "k_cache_view", il);
-
-    // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
-
-    assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
-
-    struct ggml_tensor * v_cache_view = nullptr;
-
-    if (cparams.flash_attn) {
-        v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head);
-    } else {
-        // note: the V cache is transposed when not using flash attention
-        v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
-                (  n_ctx)*ggml_element_size(kv.v_l[il]),
-                (kv_head)*ggml_element_size(kv.v_l[il]));
-
-        v_cur = ggml_transpose(ctx, v_cur);
-    }
-    cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
-}
-
-// do mat_mul, while optionally applying lora
-static struct ggml_tensor * llm_build_lora_mm(
-        struct llama_context & lctx,
-         struct ggml_context * ctx0,
-          struct ggml_tensor * w,
-          struct ggml_tensor * cur) {
-    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
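-    // add the low-rank LoRA update of each active adapter: res += scale * (B @ (A @ cur))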
-    for (auto & it : lctx.lora) {
-        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
-        if (lw == nullptr) {
-            continue;
-        }
-        const float adapter_scale = it.second;
-        const float scale = lw->get_scale(it.first->alpha, adapter_scale);
-        struct ggml_tensor * ab_cur = ggml_mul_mat(
-            ctx0, lw->b,
-            ggml_mul_mat(ctx0, lw->a, cur)
-        );
-        ab_cur = ggml_scale(ctx0, ab_cur, scale);
-        res = ggml_add(ctx0, res, ab_cur);
-    }
-    return res;
-}
-
-// do mat_mul_id, while optionally applying lora
-static struct ggml_tensor * llm_build_lora_mm_id(
-        struct llama_context & lctx,
-         struct ggml_context * ctx0,
-          struct ggml_tensor * w,   // struct ggml_tensor * as
-          struct ggml_tensor * cur, // struct ggml_tensor * b
-          struct ggml_tensor * ids) {
-    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
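-    // same low-rank LoRA update as in llm_build_lora_mm, but routed through the same expert ids as the base matmul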
-    for (auto & it : lctx.lora) {
-        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
-        if (lw == nullptr) {
-            continue;
-        }
-        const float alpha = it.first->alpha;
-        const float rank  = (float) lw->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
-        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
-            ctx0, lw->b,
-            ggml_mul_mat_id(ctx0, lw->a, cur, ids),
-            ids
-        );
-        ab_cur = ggml_scale(ctx0, ab_cur, scale);
-        res = ggml_add(ctx0, res, ab_cur);
-    }
-    return res;
-}
-
-static struct ggml_tensor * llm_build_norm(
-        struct ggml_context * ctx,
-         struct ggml_tensor * cur,
-        const llama_hparams & hparams,
-         struct ggml_tensor * mw,
-         struct ggml_tensor * mb,
-              llm_norm_type   type,
-         const llm_build_cb & cb,
-                        int   il) {
-    switch (type) {
-        case LLM_NORM:       cur = ggml_norm      (ctx, cur, hparams.f_norm_eps);     break;
-        case LLM_NORM_RMS:   cur = ggml_rms_norm  (ctx, cur, hparams.f_norm_rms_eps); break;
-        case LLM_NORM_GROUP:
-            {
-                cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
-                cur = ggml_group_norm(ctx, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
-                cur = ggml_reshape_2d(ctx, cur, cur->ne[0],    cur->ne[2]);
-            } break;
-    }
-
-    if (mw || mb) {
-        cb(cur, "norm", il);
-    }
-
-    if (mw) {
-        cur = ggml_mul(ctx, cur, mw);
-        if (mb) {
-            cb(cur, "norm_w", il);
-        }
-    }
-
-    if (mb) {
-        cur = ggml_add(ctx, cur, mb);
-    }
-
-    return cur;
-}
-
-static struct ggml_tensor * llm_build_ffn(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-         struct ggml_tensor * cur,
-         struct ggml_tensor * up,
-         struct ggml_tensor * up_b,
-         struct ggml_tensor * up_s,
-         struct ggml_tensor * gate,
-         struct ggml_tensor * gate_b,
-         struct ggml_tensor * gate_s,
-         struct ggml_tensor * down,
-         struct ggml_tensor * down_b,
-         struct ggml_tensor * down_s,
-         struct ggml_tensor * act_scales,
-            llm_ffn_op_type   type_op,
-          llm_ffn_gate_type   type_gate,
-         const llm_build_cb & cb,
-                        int   il) {
-    struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
-    cb(tmp, "ffn_up", il);
-
-    if (up_b) {
-        tmp = ggml_add(ctx, tmp, up_b);
-        cb(tmp, "ffn_up_b", il);
-    }
-
-    if (up_s) {
-        tmp = ggml_mul(ctx, tmp, up_s);
-        cb(tmp, "ffn_up_s", il);
-    }
-
-    if (gate) {
-        switch (type_gate) {
-            case LLM_FFN_SEQ:
-                {
-                    cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
-                    cb(cur, "ffn_gate", il);
-                } break;
-            case LLM_FFN_PAR:
-                {
-                    cur = llm_build_lora_mm(lctx, ctx, gate, cur);
-                    cb(cur, "ffn_gate", il);
-                } break;
-        }
-
-        if (gate_b) {
-            cur = ggml_add(ctx, cur, gate_b);
-            cb(cur, "ffn_gate_b", il);
-        }
-
-        if (gate_s) {
-            cur = ggml_mul(ctx, cur, gate_s);
-            cb(cur, "ffn_gate_s", il);
-        }
-
-    } else {
-        cur = tmp;
-    }
-
-    switch (type_op) {
-        case LLM_FFN_SILU:
-            {
-                cur = ggml_silu(ctx, cur);
-                cb(cur, "ffn_silu", il);
-            } break;
-        case LLM_FFN_GELU:
-            {
-                cur = ggml_gelu(ctx, cur);
-                cb(cur, "ffn_gelu", il);
-                if (act_scales != NULL) {
-                    cur = ggml_div(ctx, cur, act_scales);
-                    cb(cur, "ffn_act", il);
-                }
-            } break;
-        case LLM_FFN_RELU:
-            {
-                cur = ggml_relu(ctx, cur);
-                cb(cur, "ffn_relu", il);
-            } break;
-        case LLM_FFN_RELU_SQR:
-            {
-                cur = ggml_relu(ctx, cur);
-                cb(cur, "ffn_relu", il);
-
-                cur = ggml_sqr(ctx, cur);
-                cb(cur, "ffn_sqr(relu)", il);
-            } break;
-        case LLM_FFN_SWIGLU:
-            {
-                // Project to 4h. If using SwiGLU, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
-                int64_t split_point = cur->ne[0] / 2;
-                struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
-
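-                // SwiGLU: cur = silu(x0) * x1, where x0 and x1 are the two halves of the combined up/gate projection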
-                x0 = ggml_silu(ctx, x0);
-                cb(cur, "ffn_silu", il);
-
-                cur = ggml_mul(ctx, x0, x1);
-                cb(cur, "ffn_mul", il);
-            } break;
-    }
-
-    if (type_gate == LLM_FFN_PAR) {
-        cur = ggml_mul(ctx, cur, tmp);
-        cb(cur, "ffn_gate_par", il);
-    }
-
-    if (down) {
-        cur = llm_build_lora_mm(lctx, ctx, down, cur);
-    }
-
-    if (down_b) {
-        cb(cur, "ffn_down", il);
-        cur = ggml_add(ctx, cur, down_b);
-    }
-
-    if (down_s) {
-        cur = ggml_mul(ctx, cur, down_s);
-        cb(cur, "ffn_down_s", il);
-    }
-
-    return cur;
-}
-
-static struct ggml_tensor * llm_build_moe_ffn(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-         struct ggml_tensor * cur,
-         struct ggml_tensor * gate_inp,
-         struct ggml_tensor * up_exps,
-         struct ggml_tensor * gate_exps,
-         struct ggml_tensor * down_exps,
-         struct ggml_tensor * exp_probs_b,
-                    int64_t   n_expert,
-                    int64_t   n_expert_used,
-            llm_ffn_op_type   type_op,
-                       bool   norm_w,
-                       bool   scale_w,
-                      float   w_scale,
-llama_expert_gating_func_type gating_op,
-         const llm_build_cb & cb,
-                        int   il) {
-    int64_t n_embd = cur->ne[0];
-    int64_t n_tokens = cur->ne[1];
-
-    ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
-    cb(logits, "ffn_moe_logits", il);
-
-    ggml_tensor * probs = nullptr;
-    switch (gating_op) {
-        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
-            {
-                probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
-            } break;
-        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
-            {
-                probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
-    cb(probs, "ffn_moe_probs", il);
-
-    // add the expert selection bias - introduced in DeepSeek V3
-    // leave probs unbiased as they are later used to get the expert weights
-    ggml_tensor * selection_probs = probs;
-    if (exp_probs_b != nullptr) {
-        selection_probs = ggml_add(ctx, probs, exp_probs_b);
-        cb(selection_probs, "ffn_moe_probs_biased", il);
-    }
-
-    // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
-    cb(selected_experts->src[0], "ffn_moe_argsort", il);
-    cb(selected_experts, "ffn_moe_topk", il);
-
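-    // gather the routing probabilities of the selected experts - they are used below to weight the expert outputs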
-    ggml_tensor * weights = ggml_get_rows(ctx,
-            ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
-    cb(weights, "ffn_moe_weights", il);
-
-    if (norm_w) {
-        weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
-
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
-        cb(weights_sum, "ffn_moe_weights_sum", il);
-
-        weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
-        cb(weights, "ffn_moe_weights_norm", il);
-
-        weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
-    }
-    if (scale_w) {
-        weights = ggml_scale(ctx, weights, w_scale);
-        cb(weights, "ffn_moe_weights_scaled", il);
-    }
-
-    cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-    ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
-
-    switch (type_op) {
-        case LLM_FFN_SILU:
-            {
-                gate = ggml_silu(ctx, gate);
-                cb(gate, "ffn_moe_silu", il);
-            } break;
-        case LLM_FFN_GELU:
-            {
-                gate = ggml_gelu(ctx, gate);
-                cb(gate, "ffn_moe_gelu", il);
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
-
-    ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
-
-    ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
-    cb(experts, "ffn_moe_down", il);
-
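-    // scale each expert's output by its routing weight before aggregation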
-    experts = ggml_mul(ctx, experts, weights);
-
-    // aggregate experts
-    ggml_tensor * moe_out = nullptr;
-    for (int i = 0; i < n_expert_used; ++i) {
-        ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
-                experts->nb[2], i*experts->nb[1]);
-
-        if (i == 0) {
-            moe_out = cur_expert;
-        } else {
-            moe_out = ggml_add(ctx, moe_out, cur_expert);
-        }
-    }
-
-    if (n_expert_used == 1) {
-        // avoid returning a non-contiguous tensor
-        moe_out = ggml_cont(ctx, moe_out);
-    }
-
-    return moe_out;
-}
-
-static struct ggml_tensor * llm_build_kqv(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-       const llama_kv_cache & kv,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * wo,
-         struct ggml_tensor * wo_b,
-         struct ggml_tensor * q_cur,
-         struct ggml_tensor * kq_mask,
-                    int32_t   n_tokens,
-                    int32_t   n_kv,
-                    float     kq_scale,
-         const llm_build_cb & cb,
-                    int       il) {
-    const llama_model   & model   = lctx.model;
-    const llama_hparams & hparams = lctx.model.hparams;
-    const llama_cparams & cparams = lctx.cparams;
-
-    const int64_t n_ctx         = cparams.n_ctx;
-    const int64_t n_head        = hparams.n_head(il);
-    const int64_t n_head_kv     = hparams.n_head_kv(il);
-    const int64_t n_embd_head_k = hparams.n_embd_head_k;
-    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_head_v = hparams.n_embd_head_v;
-    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(il);
-
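-    // permute q so that the heads become the outer (batch) dimension for the per-head KQ matmul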
-    struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
-    cb(q, "q", il);
-
-    struct ggml_tensor * k =
-        ggml_view_3d(ctx, kv.k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
-                0);
-    cb(k, "k", il);
-
-    struct ggml_tensor * cur;
-
-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
-        // split cached v into n_head heads (not transposed)
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx, kv.v_l[il],
-                    n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
-                    0);
-        cb(v, "v", il);
-
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-
-        cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
-    } else {
-        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
-        cb(kq, "kq", il);
-
-        // note: this op tends to require high floating point range
-        //       while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
-        if (model.arch == LLM_ARCH_GROK) {
-            // need to do the following:
-            // multiply by the attn_output_multiplier of 0.08838834764831845
-            // and then:
-            // kq = 30 * tanh(kq / 30)
-            // before the softmax below
-
-            kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx, kq, 30);
-        }
-
-        if (hparams.attn_soft_cap) {
-            kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh(ctx, kq);
-            kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
-        }
-
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        cb(kq, "kq_soft_max_ext", il);
-
-        GGML_ASSERT(kv.size == n_ctx);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx, kv.v_l[il],
-                    n_kv, n_embd_head_v, n_head_kv,
-                    ggml_element_size(kv.v_l[il])*n_ctx,
-                    ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v,
-                    0);
-        cb(v, "v", il);
-
-        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
-        cb(kqv, "kqv", il);
-
-        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
-        cb(kqv_merged, "kqv_merged", il);
-
-        cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        cb(cur, "kqv_merged_cont", il);
-    }
-
-    ggml_build_forward_expand(graph, cur);
-
-    if (wo) {
-        cur = llm_build_lora_mm(lctx, ctx, wo, cur);
-    }
-
-    if (wo_b) {
-        cb(cur, "kqv_wo", il);
-        cur = ggml_add(ctx, cur, wo_b);
-    }
-
-    return cur;
-}
-
-static struct ggml_tensor * llm_build_kv(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-       const llama_kv_cache & kv,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * wo,
-         struct ggml_tensor * wo_b,
-         struct ggml_tensor * k_cur,
-         struct ggml_tensor * v_cur,
-         struct ggml_tensor * q_cur,
-         struct ggml_tensor * kq_mask,
-                    int32_t   n_tokens,
-                    int32_t   kv_head,
-                    int32_t   n_kv,
-                    float     kq_scale,
-         const llm_build_cb & cb,
-                    int       il) {
-    const llama_hparams & hparams = lctx.model.hparams;
-    const llama_cparams & cparams = lctx.cparams;
-
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(graph, q_cur);
-    ggml_build_forward_expand(graph, k_cur);
-    ggml_build_forward_expand(graph, v_cur);
-
-    llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
-
-    struct ggml_tensor * cur;
-
-    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
-    cb(cur, "kqv_out", il);
-
-    return cur;
-}
-
-static struct ggml_tensor * llm_build_copy_mask_state(
-        struct ggml_context * ctx,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * s,
-         struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
-                    int32_t   n_state,
-                    int32_t   kv_size,
-                    int32_t   kv_head,
-                    int32_t   n_kv,
-                    int32_t   n_seqs) {
-    struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size);
-
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensor's ne[1] to n_kv
-    states = ggml_get_rows(ctx, states, state_copy);
-
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx, states, state_mask);
-
-    // copy states which won't be changed further (between n_seqs and n_kv)
-    ggml_build_forward_expand(graph,
-        ggml_cpy(ctx,
-            ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
-
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0);
-}
-
-// TODO: split
-static struct ggml_tensor * llm_build_mamba(
-        struct ggml_context * ctx,
-       struct llama_context & lctx,
-         const llama_ubatch & ubatch,
-         struct ggml_cgraph * graph,
-         struct ggml_tensor * cur,
-         struct ggml_tensor * state_copy,
-         struct ggml_tensor * state_mask,
-                    int32_t   kv_head,
-                    int32_t   n_kv,
-         const llm_build_cb & cb,
-                    int       il) {
-    const llama_model    & model   = lctx.model;
-    const llama_hparams  & hparams = model.hparams;
-    const llama_kv_cache & kv      = lctx.kv_self;
-    const int64_t d_conv  = hparams.ssm_d_conv;
-    const int64_t d_inner = hparams.ssm_d_inner;
-    const int64_t d_state = hparams.ssm_d_state;
-    const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs  = ubatch.n_seqs;
-    // Some variants of the Mamba arch (e.g. FalconMamba) apply RMS norm on the B, C and Dt layers
-    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
-    // Use the same RMS norm as the final layer norm
-    const float norm_rms_eps = hparams.f_norm_rms_eps;
-
-    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
-    GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(ubatch.equal_seqs);
-    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-    struct ggml_tensor * conv_states_all = kv.k_l[il];
-    struct ggml_tensor * ssm_states_all  = kv.v_l[il];
-
-    // (ab)using the KV cache to store the states
-    struct ggml_tensor * conv = llm_build_copy_mask_state(ctx,
-            graph, conv_states_all, state_copy, state_mask,
-            hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
-    conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs);
-    struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx,
-            graph, ssm_states_all, state_copy, state_mask,
-            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs);
-    ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs);
-
-    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-    cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
-    struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur);
-    // split the above in two
-    // => {d_inner, n_seq_tokens, n_seqs}
-    struct ggml_tensor * x = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
-    struct ggml_tensor * z = ggml_view_3d(ctx, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner*ggml_element_size(xz));
-
-    // conv
-    {
-        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
-        struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, x), 0);
-
-        // copy last (d_conv - 1) columns back into the state cache
-        struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
-
-        ggml_build_forward_expand(graph,
-            ggml_cpy(ctx, last_conv,
-                ggml_view_1d(ctx, conv_states_all,
-                    (d_conv - 1)*(d_inner)*(n_seqs),
-                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
-
-        // 1D convolution
-        // The equivalent is to make a self-overlapping view of conv_x
-        // over d_conv columns at each stride in the 3rd dimension,
-        // then element-wise multiply that with the conv1d weight,
-        // then sum the elements of each row,
-        // (the last two steps are a dot product over rows (also doable with mul_mat))
-        // then permute away the ne[0] dimension,
-        // and then you're left with the resulting x tensor.
-        // For simultaneous sequences, all sequences need to have the same length.
-        x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
-
-        // bias
-        x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b);
-
-        x = ggml_silu(ctx, x);
-    }
-
-    // ssm
-    {
-        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
-        struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x);
-        // split
-        struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
-        struct ggml_tensor * B  = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank);
-        struct ggml_tensor * C  = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state));
-
-        // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers
-        if (ssm_dt_b_c_rms) {
-            dt = ggml_rms_norm(ctx, dt, norm_rms_eps);
-            B = ggml_rms_norm(ctx, B, norm_rms_eps);
-            C = ggml_rms_norm(ctx, C, norm_rms_eps);
-        }
-
-        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
-        dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt);
-        dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b);
-
-        // Custom operator to optimize the parallel associative scan
-        // as described in the Annex D of the Mamba paper.
-        // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-        struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
-
-        // store last states
-        ggml_build_forward_expand(graph,
-            ggml_cpy(ctx,
-                ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]),
-                ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
-
-        struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0);
-
-        // TODO: skip computing output earlier for unused tokens
-
-        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
-        y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d));
-        y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z)));
-
-        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-        cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y);
-    }
-
-    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-    cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs);
-    cb(cur, "mamba_out", il);
-
-    return cur;
-}
-
-static struct ggml_tensor * llm_build_rwkv6_time_mix(
-        struct llama_context & lctx,
-        struct ggml_context * ctx,
-        const struct llama_layer * layer,
-        struct ggml_tensor * cur,
-        struct ggml_tensor * x_prev,
-        struct ggml_tensor ** wkv_state,
-        size_t wkv_head_size,
-        size_t head_count_kv) {
-    size_t n_embd       = cur->ne[0];
-    size_t n_seq_tokens = cur->ne[1];
-    size_t n_seqs       = cur->ne[2];
-
-    size_t head_size  = wkv_head_size;
-    size_t head_count = n_embd / head_size;
-
-    size_t n_tokens = n_seqs * n_seq_tokens;
-
-    bool is_qrwkv = layer->time_mix_first == nullptr;
-
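-    // token shift: sx is the delta to the previous token, used to interpolate between the current and previous inputs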
-    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
-
-    sx  = ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
-    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-
-    struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
-
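-    // project through time_mix_w1/w2 to get 5 per-token interpolation vectors - one each for w, k, v, r and g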
-    xxx = ggml_reshape_4d(
-        ctx,
-        ggml_tanh(
-            ctx,
-            ggml_mul_mat(ctx, layer->time_mix_w1, xxx)
-        ),
-        layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
-    );
-
-    xxx = ggml_cont(ctx, ggml_permute(ctx, xxx, 0, 1, 3, 2));
-
-    xxx = ggml_mul_mat(
-        ctx,
-        ggml_reshape_4d(
-            ctx,
-            layer->time_mix_w2,
-            layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
-        ),
-        xxx
-    );
-
-    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
-    if (layer->time_mix_lerp_fused) {
-        // fusing these weights yields some performance improvement
-        sx  = ggml_reshape_3d(ctx, sx,  n_embd, 1, n_tokens);
-        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
-        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-    } else {
-        // for backward compatibility
-        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-
-        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
-        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
-        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
-        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
-        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
-    }
-
-    struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
-    struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk);
-    struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv);
-    if (layer->time_mix_receptance_b) {
-        r = ggml_add(ctx, r, layer->time_mix_receptance_b);
-    }
-    if (layer->time_mix_key_b) {
-        k = ggml_add(ctx, k, layer->time_mix_key_b);
-    }
-    if (layer->time_mix_value_b) {
-        v = ggml_add(ctx, v, layer->time_mix_value_b);
-    }
-
-    struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
-    if (is_qrwkv) {
-        g = ggml_sigmoid(ctx, g);
-    } else {
-        g = ggml_silu(ctx, g);
-    }
-
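-    // grouped KV heads: broadcast k and v so that they have the same number of heads as r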
-    if (head_count_kv != head_count) {
-        GGML_ASSERT(head_count % head_count_kv == 0);
-        k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
-        v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
-        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
-        k = ggml_repeat(ctx, k, tmp);
-        v = ggml_repeat(ctx, v, tmp);
-    }
-
-    k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
-    v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
-    r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
-
-    struct ggml_tensor * w = ggml_mul_mat(
-        ctx,
-        layer->time_mix_decay_w2,
-        ggml_tanh(
-            ctx,
-            ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
-        )
-    );
-
-    w = ggml_add(ctx, w, layer->time_mix_decay);
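-    // decay: w = exp(-exp(w)) keeps the per-channel decay in (0, 1)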
-    w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
-    w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
-
-    if (is_qrwkv) {
-        // k = k * (1 - w)
-        k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
-    }
-
-    struct ggml_tensor * wkv_output;
-    if (!layer->time_mix_first) {
-        wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
-    } else {
-        wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
-    }
-    cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
-    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
-
-    if (!is_qrwkv) {
-        // group norm with head_count groups
-        cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
-        cur = ggml_norm(ctx, cur, 64e-5f);
-
-        // Convert back to regular vectors.
-        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-        cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
-    } else {
-        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-    }
-
-    cur = ggml_mul(ctx, cur, g);
-    cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
-
-    return ggml_reshape_3d(ctx, cur, n_embd, n_seq_tokens, n_seqs);
-}
-
-static struct ggml_tensor * llm_build_rwkv6_channel_mix(
-        struct llama_context & lctx,
-        struct ggml_context * ctx,
-        const struct llama_layer * layer,
-        struct ggml_tensor * cur,
-        struct ggml_tensor * x_prev) {
-    struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
-    struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
-    struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
-
-    struct ggml_tensor * r = ggml_sigmoid(ctx, llm_build_lora_mm(lctx, ctx, layer->channel_mix_receptance, xr));
-    struct ggml_tensor * k = ggml_sqr(
-        ctx,
-        ggml_relu(
-            ctx,
-            llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
-        )
-    );
-
-    return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
-}
-
-struct llm_build_context {
-    const llama_model    & model;
-          llama_context  & lctx;
-    const llama_hparams  & hparams;
-    const llama_cparams  & cparams;
-    const llama_ubatch   & ubatch;
-    const llama_kv_cache & kv_self;
-
-    const int64_t n_embd;
-    const int64_t n_layer;
-    const int64_t n_rot;
-    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_head;
-    const int64_t n_head_kv;
-    const int64_t n_embd_head_k;
-    const int64_t n_embd_k_gqa;
-    const int64_t n_embd_head_v;
-    const int64_t n_embd_v_gqa;
-    const int64_t n_expert;
-    const int64_t n_expert_used;
-
-    const float freq_base;
-    const float freq_scale;
-    const float ext_factor;
-    const float attn_factor;
-    const float beta_fast;
-    const float beta_slow;
-    const float norm_eps;
-    const float norm_rms_eps;
-
-    const int32_t n_tokens;
-    const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
-    const int32_t n_outputs;
-    const int32_t n_outputs_enc;
-    const int32_t kv_head;  // index of where we store new KV data in the cache
-    const int32_t n_ctx_orig;
-
-    const bool flash_attn;
-
-    const enum llama_pooling_type pooling_type;
-    const enum llama_rope_type    rope_type;
-
-    const llm_build_cb & cb;
-
-    std::vector<uint8_t> & buf_compute_meta;
-
-    struct ggml_context * ctx0 = nullptr;
-
-    // TODO: consider making the entire interface noexcept
-    llm_build_context(
-        llama_context  & lctx,
-    const llama_ubatch & ubatch,
-    const llm_build_cb & cb,
-                  bool   worst_case) :
-        model            (lctx.model),
-        lctx             (lctx),
-        hparams          (model.hparams),
-        cparams          (lctx.cparams),
-        ubatch           (ubatch),
-        kv_self          (lctx.kv_self),
-        n_embd           (hparams.n_embd),
-        n_layer          (hparams.n_layer),
-        n_rot            (hparams.n_rot),
-        n_ctx            (cparams.n_ctx),
-        n_head           (hparams.n_head()),
-        n_head_kv        (hparams.n_head_kv()),
-        n_embd_head_k    (hparams.n_embd_head_k),
-        n_embd_k_gqa     (hparams.n_embd_k_gqa()),
-        n_embd_head_v    (hparams.n_embd_head_v),
-        n_embd_v_gqa     (hparams.n_embd_v_gqa()),
-        n_expert         (hparams.n_expert),
-        n_expert_used    (hparams.n_expert_used),
-        freq_base        (cparams.rope_freq_base),
-        freq_scale       (cparams.rope_freq_scale),
-        ext_factor       (cparams.yarn_ext_factor),
-        attn_factor      (cparams.yarn_attn_factor),
-        beta_fast        (cparams.yarn_beta_fast),
-        beta_slow        (cparams.yarn_beta_slow),
-        norm_eps         (hparams.f_norm_eps),
-        norm_rms_eps     (hparams.f_norm_rms_eps),
-        n_tokens         (ubatch.n_tokens),
-        n_kv             (worst_case ? kv_self.size : kv_self.n),
-        n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
-        n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
-        kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
-        n_ctx_orig       (cparams.n_ctx_orig_yarn),
-        flash_attn       (cparams.flash_attn),
-        pooling_type     (cparams.pooling_type),
-        rope_type        (hparams.rope_type),
-        cb               (cb),
-        buf_compute_meta (lctx.buf_compute_meta) {
-            // all initializations should be done in init()
-        }
-
-    void init() {
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ buf_compute_meta.size(),
-            /*.mem_buffer =*/ buf_compute_meta.data(),
-            /*.no_alloc   =*/ true,
-        };
-
-        ctx0 = ggml_init(params);
-
-        lctx.inp_tokens      = nullptr;
-        lctx.inp_embd        = nullptr;
-        lctx.inp_pos         = nullptr;
-        lctx.inp_out_ids     = nullptr;
-        lctx.inp_KQ_mask     = nullptr;
-        lctx.inp_KQ_mask_swa = nullptr;
-        lctx.inp_K_shift     = nullptr;
-        lctx.inp_mean        = nullptr;
-        lctx.inp_cls         = nullptr;
-        lctx.inp_s_copy      = nullptr;
-        lctx.inp_s_mask      = nullptr;
-        lctx.inp_s_seq       = nullptr;
-        lctx.inp_pos_bucket    = nullptr;
-        lctx.inp_embd_enc      = nullptr;
-        lctx.inp_KQ_mask_cross = nullptr;
-    }
-
-    void free() {
-        ggml_free(ctx0);
-        ctx0 = nullptr;
-    }
-
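-    // build a graph that re-rotates the cached K tensors by the per-cell offsets in inp_K_shift (used for KV cache shifting)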
-    struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        GGML_ASSERT(kv_self.size == n_ctx);
-
-        lctx.inp_K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(lctx.inp_K_shift, "K_shift", -1);
-        ggml_set_input(lctx.inp_K_shift);
-
-        for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_head_kv = hparams.n_head_kv(il);
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            struct ggml_tensor * rope_factors = build_rope_factors(il);
-            struct ggml_tensor * k =
-                ggml_view_3d(ctx0, kv_self.k_l[il],
-                    n_embd_head_k, n_head_kv, n_ctx,
-                    ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                    0);
-
-            struct ggml_tensor * tmp;
-            if (ggml_is_quantized(k->type)) {
-                // dequantize to f32 -> RoPE -> quantize back
-                tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
-                cb(tmp, "K_f32", il);
-                for (auto & backend : lctx.backends) {
-                    // Figure out which backend the KV cache belongs to
-                    if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
-                        ggml_backend_sched_set_tensor_backend(lctx.sched.get(), tmp, backend.get());
-                        break;
-                    }
-                }
-                tmp = ggml_rope_ext_inplace(ctx0, tmp,
-                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(tmp, "K_shifted_f32", il);
-                tmp = ggml_cpy(ctx0, tmp, k);
-            } else {
-                // we rotate only the first n_rot dimensions
-                tmp = ggml_rope_ext_inplace(ctx0, k,
-                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-            }
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
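-        // ids[i] is the destination cell of KV cell i; cells mapped to themselves or to ids.size() are not moved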
-        for (uint32_t i = 0; i < ids.size(); ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == ids.size()) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            for (int il = 0; il < n_layer; ++il) {
-                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-                ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
-
-                ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
-
-                ggml_tensor * view_v_src;
-                ggml_tensor * view_v_dst;
-
-                if (flash_attn) {
-                    // NOTE: the V cache is not transposed when using flash attention
-                    view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
-
-                    view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
-                } else {
-                    view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
-                            ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, i));
-
-                    view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
-                            ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, id));
-                }
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
-            }
-
-            i += nm - 1;
-        }
-
-        //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-
-        return gf;
-    }
-
-    struct ggml_tensor * build_inp_pos() {
-        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(lctx.inp_pos, "inp_pos", -1);
-        ggml_set_input(lctx.inp_pos);
-        return lctx.inp_pos;
-    }
-
-    struct ggml_tensor * build_rope_factors(int il) {
-        // choose long/short freq factors based on the context size
-        const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
-
-        if (model.layers[il].rope_freqs != nullptr) {
-            return model.layers[il].rope_freqs;
-        }
-
-        if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
-            return model.layers[il].rope_long;
-        }
-
-        return model.layers[il].rope_short;
-    }
-
-    struct ggml_tensor * build_inp_out_ids() {
-        lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
-        cb(lctx.inp_out_ids, "inp_out_ids", -1);
-        ggml_set_input(lctx.inp_out_ids);
-        return lctx.inp_out_ids;
-    }
-
-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        lctx.inp_KQ_mask = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
-        ggml_set_input(lctx.inp_KQ_mask);
-
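-        // the flash attention kernel expects an F16 mask, so cast it when flash_attn is enabled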
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
-    }
-
-    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
-        GGML_ASSERT(hparams.n_swa > 0);
-
-        lctx.inp_KQ_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(lctx.inp_KQ_mask_swa);
-
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
-    }
-
-    struct ggml_tensor * build_inp_mean() {
-        lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
-        cb(lctx.inp_mean, "inp_mean", -1);
-        ggml_set_input(lctx.inp_mean);
-        return lctx.inp_mean;
-    }
-
-    struct ggml_tensor * build_inp_cls() {
-        lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(lctx.inp_cls, "inp_cls", -1);
-        ggml_set_input(lctx.inp_cls);
-        return lctx.inp_cls;
-    }
-
-    struct ggml_tensor * build_inp_s_copy() {
-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-        cb(lctx.inp_s_copy, "inp_s_copy", -1);
-        ggml_set_input(lctx.inp_s_copy);
-        return lctx.inp_s_copy;
-    }
-
-    struct ggml_tensor * build_inp_s_mask() {
-        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        cb(lctx.inp_s_mask, "inp_s_mask", -1);
-        ggml_set_input(lctx.inp_s_mask);
-        return lctx.inp_s_mask;
-    }
-
-    struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
-        // find result_norm tensor for input
-        struct ggml_tensor * inp = nullptr;
-        for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
-            inp = ggml_graph_node(gf, i);
-            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
-                break;
-            } else {
-                inp = nullptr;
-            }
-        }
-        GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
-
-        struct ggml_tensor * cur;
-
-        switch (pooling_type) {
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    cur = inp;
-                } break;
-            case LLAMA_POOLING_TYPE_MEAN:
-                {
-                    struct ggml_tensor * inp_mean = build_inp_mean();
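-                    // inp_mean holds the per-sequence mean-pooling weights, so this matmul yields the mean embedding of each sequence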
-                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
-                } break;
-            case LLAMA_POOLING_TYPE_CLS:
-            case LLAMA_POOLING_TYPE_LAST:
-                {
-                    struct ggml_tensor * inp_cls = build_inp_cls();
-                    cur = ggml_get_rows(ctx0, inp, inp_cls);
-                } break;
-            case LLAMA_POOLING_TYPE_RANK:
-                {
-                    struct ggml_tensor * inp_cls = build_inp_cls();
-                    inp = ggml_get_rows(ctx0, inp, inp_cls);
-
-                    // classification head
-                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                    GGML_ASSERT(model.cls       != nullptr);
-                    GGML_ASSERT(model.cls_b     != nullptr);
-
-                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
-                    cur = ggml_tanh(ctx0, cur);
-
-                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                    if (model.cls_out) {
-                        GGML_ASSERT(model.cls_out_b != nullptr);
-
-                        cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
-                    }
-                } break;
-            default:
-                {
-                    GGML_ABORT("unknown pooling type");
-                }
-        }
-
-        cb(cur, "result_embd_pooled", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_tensor * llm_build_pos_bucket(bool causal) {
-        if (causal) {
-            lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv,     n_tokens);
-        } else {
-            lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
-        }
-
-        ggml_set_input(lctx.inp_pos_bucket);
-        cb(lctx.inp_pos_bucket, "pos_bucket", -1);
-
-        return lctx.inp_pos_bucket;
-    }
-
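-    // turn the relative position buckets into an additive attention bias by looking up rows of attn_rel_b (T5-style relative attention)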
-    struct ggml_tensor * llm_build_pos_bias(struct ggml_tensor * pos_bucket, struct ggml_tensor * attn_rel_b) {
-        struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0);
-        cb(pos_bucket_1d, "pos_bucket_1d", -1);
-
-        struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
-        cb(pos_bias, "pos_bias", -1);
-
-        pos_bias = ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0],  0);
-        cb(pos_bias, "pos_bias", -1);
-
-        pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
-        cb(pos_bias, "pos_bias", -1);
-
-        pos_bias = ggml_cont(ctx0, pos_bias);
-        cb(pos_bias, "pos_bias", -1);
-
-        return pos_bias;
-    }
-
-    struct ggml_tensor * llm_build_inp_embd_enc() {
-        const int64_t n_embd = hparams.n_embd;
-        lctx.inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
-        ggml_set_input(lctx.inp_embd_enc);
-        cb(lctx.inp_embd_enc, "embd_enc", -1);
-        return lctx.inp_embd_enc;
-    }
-
-    struct ggml_tensor * llm_build_inp_KQ_mask_cross() {
-        lctx.inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        ggml_set_input(lctx.inp_KQ_mask_cross);
-        cb(lctx.inp_KQ_mask_cross, "KQ_mask_cross", -1);
-        return lctx.inp_KQ_mask_cross;
-    }
-
-    struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            if (model.layers[il].ffn_gate_inp == nullptr) {
-
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, true,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        cb, il);
-                cb(cur, "ffn_moe_out", il);
-            }
-
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
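
In the builder above (and again in build_deci below), the Granite-specific hparams are folded in as plain scalings: f_residual_scale multiplies the attention/FFN branch output before it is added back to the residual stream, and the lm_head output is divided by f_logit_scale. A minimal standalone sketch of that arithmetic, with made-up scale values:

    #include <cstdio>
    #include <vector>

    int main() {
        const float f_residual_scale = 0.22f; // hypothetical value of hparams.f_residual_scale
        const float f_logit_scale    = 16.0f; // hypothetical value of hparams.f_logit_scale

        std::vector<float> residual = {1.0f, 2.0f};  // the residual stream (inpSA / ffn_inp)
        std::vector<float> branch   = {0.5f, -0.5f}; // attention or FFN branch output

        // cur = ggml_scale(cur, f_residual_scale); cur = ggml_add(cur, inpSA);
        for (size_t i = 0; i < residual.size(); ++i) {
            residual[i] += f_residual_scale * branch[i];
        }

        // cur = ggml_scale(cur, 1.0f / f_logit_scale); on the lm_head output
        std::vector<float> logits = residual;
        for (float & l : logits) {
            l /= f_logit_scale;
        }

        std::printf("logits = (%f, %f)\n", logits[0], logits[1]);
        return 0;
    }
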
-
-    struct ggml_cgraph * build_deci() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-            const int64_t n_head_kv = hparams.n_head_kv(il);
-            const int64_t n_head    = hparams.n_head(il);
-
-            if (n_head == 0) {
-                // attention-free layer of Llama-3_1-Nemotron-51B
-                cur = inpL;
-            } else {
-                // norm
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                        model.layers[il].attn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm", il);
-            }
-
-            if (n_head > 0 && n_head_kv == 0) {
-                // "linear attention" of Llama-3_1-Nemotron-51B
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                cb(cur, "wo", il);
-            } else if (n_head > 0) {
-                // self-attention
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
-            // modified to support attention-free layer of Llama-3_1-Nemotron-51B
-            struct ggml_tensor * ffn_inp = cur;
-            if (n_head > 0) {
-                ffn_inp = ggml_add(ctx0, cur, inpSA);
-                cb(ffn_inp, "ffn_inp", il);
-            }
-
-            // feed-forward network
-            if (model.layers[il].ffn_gate_inp == nullptr) {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            // For Granite architecture
-            if (hparams.f_residual_scale) {
-                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        // For Granite architecture
-        if (hparams.f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
-        }
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                switch (model.type) {
-                    case LLM_TYPE_7B:
-                        Qcur = ggml_rope_ext(
-                            ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                            ext_factor, attn_factor, beta_fast, beta_slow
-                        );
-                        Kcur = ggml_rope_ext(
-                            ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                            ext_factor, attn_factor, beta_fast, beta_slow
-                        );
-                        break;
-                    case LLM_TYPE_13B:
-                        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
-                        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
-                        break;
-                    default:
-                        GGML_ABORT("fatal error");
-                }
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
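
build_baichuan applies RoPE only for the 7B variant; the 13B variant leaves Q/K unrotated, since Baichuan-13B encodes position with an ALiBi bias in the attention rather than RoPE. For reference, a toy sketch of what the pair-wise (non-NeoX) RoPE rotation conceptually does to one head at position pos; an illustration, not the ggml_rope_ext kernel:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // rotate consecutive pairs (x[2i], x[2i+1]) by an angle that decays with the pair index
    static void rope_inplace(std::vector<float> & head, int pos, float freq_base) {
        const int n = (int) head.size();
        for (int i = 0; i < n; i += 2) {
            const float theta = pos * std::pow(freq_base, -(float) i / n);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            const float x0 = head[i];
            const float x1 = head[i + 1];
            head[i]     = x0 * c - x1 * s;
            head[i + 1] = x0 * s + x1 * c;
        }
    }

    int main() {
        std::vector<float> q = {1.0f, 0.0f, 1.0f, 0.0f}; // one toy 4-dim head
        rope_inplace(q, /*pos =*/ 3, /*freq_base =*/ 10000.0f);
        std::printf("%f %f %f %f\n", q[0], q[1], q[2], q[3]);
        return 0;
    }
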
-
-    struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * attn_norm;
-
-            attn_norm = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(attn_norm, "attn_norm", il);
-
-            // self-attention
-            {
-                if (model.layers[il].attn_norm_2) {
-                    // Falcon-40B
-                    cur = llm_build_norm(ctx0, inpL, hparams,
-                            model.layers[il].attn_norm_2,
-                            model.layers[il].attn_norm_2_b,
-                            LLM_NORM, cb, il);
-                    cb(cur, "attn_norm_2", il);
-                } else {
-                    cur = attn_norm;
-                }
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                // using mode = 2 for neox mode
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur       = ggml_get_rows(ctx0,       cur, inp_out_ids);
-                inpL      = ggml_get_rows(ctx0,      inpL, inp_out_ids);
-                attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = cur;
-
-            // feed forward
-            {
-                cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        NULL,                      NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        // norm
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
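
The Falcon block above uses a parallel residual layout: the FFN reads attn_norm rather than the attention output, and both branch outputs are added onto the layer input (cur + ffn_inp + inpL). A tiny sketch of that structure with stand-in branch functions:

    #include <cstdio>

    static float attention(float x) { return 0.5f * x; }  // stand-in for the attention branch
    static float ffn(float x)       { return 0.25f * x; } // stand-in for the MLP branch

    int main() {
        const float inpL      = 1.0f; // layer input
        const float attn_norm = inpL; // pretend the norm is the identity for this sketch
        const float attn_out  = attention(attn_norm);
        const float ffn_out   = ffn(attn_norm); // !! uses the attn norm, not the attention output

        // cur = ggml_add(cur, ffn_inp); cur = ggml_add(cur, inpL); in the graph above
        const float out = ffn_out + attn_out + inpL;
        std::printf("layer output = %f\n", out);
        return 0;
    }
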
-
-    struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // multiply by embedding_multiplier_scale of 78.38367176906169
-        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // Grok
-            // if attn_out_norm is present then apply it before adding the input
-            if (model.layers[il].attn_out_norm) {
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].attn_out_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_out_norm", il);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            // MoE branch
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_GELU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            // Grok
-            // if layer_out_norm is present then apply it before adding the input
-            // Idea: maybe ffn_out_norm is a better name
-            if (model.layers[il].layer_out_norm) {
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].layer_out_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "layer_out_norm", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        // Grok
-        // multiply logits by output_multiplier_scale of 0.5773502691896257
-
-        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
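
The two magic constants in build_grok appear to be closed forms: 78.38367176906169 is sqrt(6144), Grok-1's embedding width, and 0.5773502691896257 is 1/sqrt(3). That reading is an observation, not something the code states; the check is trivial:

    #include <cmath>
    #include <cstdio>

    int main() {
        // constants used above in build_grok
        const double emb_mult = 78.38367176906169;
        const double out_mult = 0.5773502691896257;

        std::printf("sqrt(6144) = %.14f (emb_mult = %.14f)\n", std::sqrt(6144.0), emb_mult);
        std::printf("1/sqrt(3)  = %.16f (out_mult = %.16f)\n", 1.0 / std::sqrt(3.0), out_mult);
        return 0;
    }
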
-
-    struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                                 model.layers[il].attn_norm, NULL,
-                                 LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = nullptr;
-                struct ggml_tensor * Kcur = nullptr;
-                struct ggml_tensor * Vcur = nullptr;
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                cb(cur, "wqkv_clamped", il);
-
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            // MoE branch
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                                 model.layers[il].attn_out_norm, NULL,
-                                 LLM_NORM, cb, il);
-            cb(cur, "attn_out_norm", il);
-
-            cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                             model.output_norm, NULL,
-                             LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
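
Both the Grok and DBRX builders route the FFN through llm_build_moe_ffn with softmax gating and n_expert_used active experts per token. A standalone sketch of that routing step with toy router logits; renormalizing the weights over the selected experts mirrors the norm_w flag passed as true above (an assumption about that argument's meaning):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
        std::vector<float> logits = {0.1f, 2.0f, -1.0f, 0.7f}; // router output, n_expert = 4
        const int n_expert_used = 2;

        // softmax over the router logits
        std::vector<float> p(logits.size());
        const float mx = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) { p[i] = std::exp(logits[i] - mx); sum += p[i]; }
        for (float & v : p) v /= sum;

        // take the top-k experts and renormalize their weights
        std::vector<int> idx(p.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                          [&](int a, int b) { return p[a] > p[b]; });

        float wsum = 0.0f;
        for (int k = 0; k < n_expert_used; ++k) wsum += p[idx[k]];
        for (int k = 0; k < n_expert_used; ++k) {
            std::printf("expert %d weight %.3f\n", idx[k], p[idx[k]] / wsum);
        }
        return 0;
    }
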
-
-    struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        struct ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
-        cb(pos, "pos_embd", -1);
-
-        inpL = ggml_add(ctx0, inpL, pos);
-        cb(inpL, "inpL", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                cb(Kcur, "Kcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                cb(Qcur, "Qcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
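
build_refact applies no RoPE at all; like the Baichuan-13B path, position is expected to come from an ALiBi-style bias applied inside the attention via hparams.f_max_alibi_bias. A sketch of the classic ALiBi formulation from the paper, assuming n_head is a power of two; the exact way ggml derives the per-head slopes from f_max_alibi_bias is not reproduced here:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int n_head = 8;
        const int i = 10; // query position
        const int j = 7;  // key position (j <= i)

        for (int h = 1; h <= n_head; ++h) {
            // slope m_h = 2^(-8h/n_head); the bias added to the attention score is -m_h * (i - j)
            const float m    = std::pow(2.0f, -8.0f * h / n_head);
            const float bias = -m * (i - j);
            std::printf("head %d: slope %.6f bias %.6f\n", h, m, bias);
        }
        return 0;
    }
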
-
-    struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        struct ggml_tensor * inp_pos = nullptr;
-
-        if (model.arch != LLM_ARCH_JINA_BERT_V2) {
-            inp_pos = build_inp_pos();
-        }
-
-        // construct input embeddings (token, type, position)
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // token types are hardcoded to zero ("Sentence A")
-        struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
-        inpL = ggml_add(ctx0, inpL, type_row0);
-        if (model.arch == LLM_ARCH_BERT) {
-            inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL);
-        }
-        cb(inpL, "inp_embd", -1);
-
-        // embed layer norm
-        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
-        cb(inpL, "inp_norm", -1);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask(false);
-
-        // iterate layers
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * cur = inpL;
-
-            struct ggml_tensor * Qcur;
-            struct ggml_tensor * Kcur;
-            struct ggml_tensor * Vcur;
-
-            // self-attention
-            if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                            model.layers[il].attn_q_norm,
-                            model.layers[il].attn_q_norm_b,
-                            LLM_NORM, cb, il);
-                }
-
-                Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur), model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                if (model.layers[il].attn_k_norm) {
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                            model.layers[il].attn_k_norm,
-                            model.layers[il].attn_k_norm_b,
-                            LLM_NORM, cb, il);
-                }
-                Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur), model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-            } else {
-                // compute Q and K and RoPE them
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-            }
-
-            struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-            struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
-
-            struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-            cb(kq, "kq", il);
-
-            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
-            cb(kq, "kq_soft_max_ext", il);
-
-            struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
-            cb(v, "v", il);
-
-            struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
-            cb(kqv, "kqv", il);
-
-            struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-            cb(kqv_merged, "kqv_merged", il);
-
-            cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-            cb(cur, "kqv_merged_cont", il);
-
-            ggml_build_forward_expand(gf, cur);
-
-            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-            if (model.layers[il].bo) {
-                cb(cur, "kqv_wo", il);
-                cur = ggml_add(ctx0, cur, model.layers[il].bo);
-            }
-            cb(cur, "kqv_out", il);
-
-            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // re-add the layer input
-            cur = ggml_add(ctx0, cur, inpL);
-
-            // attention layer norm
-            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, cb, il);
-
-            if (model.layers[il].attn_norm_2 != nullptr) {
-                cur = ggml_add(ctx0, cur, inpL); // re-add the layer input
-                cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, cb, il);
-            }
-
-            struct ggml_tensor * ffn_inp = cur;
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            if (model.arch == LLM_ARCH_BERT) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-            } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL,                        NULL,
-                        model.layers[il].ffn_gate, NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
-            } else {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            }
-            cb(cur, "ffn_out", il);
-
-            // residual connection: the attention output (ffn_inp) bypasses the FFN
-            cur = ggml_add(ctx0, cur, ffn_inp);
-
-            // output layer norm
-            cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].layer_out_norm, model.layers[il].layer_out_norm_b, LLM_NORM, cb, il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cb(cur, "result_embd", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
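
Unlike the decoder builders, the BERT branch above does not call llm_build_kv; it spells out non-causal attention directly in the graph: scale the K·Q scores, soft-max them with an additive mask (ggml_soft_max_ext), then take the V-weighted sum. A toy single-query sketch of that math in plain C++:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int d = 2, n_kv = 3;
        const std::vector<float> q = {1.0f, 0.0f};
        const std::vector<std::vector<float>> K = {{1.0f, 0.0f}, {0.0f, 1.0f}, {1.0f, 1.0f}};
        const std::vector<std::vector<float>> V = {{1.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, 2.0f}};
        const std::vector<float> mask = {0.0f, 0.0f, -INFINITY}; // last key masked out

        // scores = K·q * 1/sqrt(d) + mask
        std::vector<float> scores(n_kv);
        float mx = -INFINITY;
        for (int j = 0; j < n_kv; ++j) {
            float s = 0.0f;
            for (int i = 0; i < d; ++i) s += K[j][i] * q[i];
            scores[j] = s / std::sqrt((float) d) + mask[j];
            mx = std::max(mx, scores[j]);
        }

        // softmax, then the V-weighted sum
        float sum = 0.0f;
        for (float & s : scores) { s = std::exp(s - mx); sum += s; }

        std::vector<float> out(d, 0.0f);
        for (int j = 0; j < n_kv; ++j)
            for (int i = 0; i < d; ++i) out[i] += (scores[j] / sum) * V[j][i];

        std::printf("out = (%f, %f)\n", out[0], out[1]);
        return 0;
    }
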
-
-    struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        inpL = llm_build_norm(ctx0, inpL, hparams,
-                model.tok_norm,
-                model.tok_norm_b,
-                LLM_NORM, cb, -1);
-        cb(inpL, "inp_norm", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // Add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * pos;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        if (model.pos_embd) {
-            // inp_pos - contains the positions
-            struct ggml_tensor * inp_pos = build_inp_pos();
-            pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
-            cb(pos, "pos_embd", -1);
-
-            inpL = ggml_add(ctx0, inpL, pos);
-            cb(inpL, "inpL", -1);
-        }
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * attn_norm;
-
-            attn_norm = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(attn_norm, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = attn_norm;
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                if (model.layers[il].bqkv) {
-                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                    cb(cur, "bqkv", il);
-                }
-
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(cur, "wqkv_clamped", il);
-                }
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                // Q/K Layernorm
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                            model.layers[il].attn_q_norm,
-                            model.layers[il].attn_q_norm_b,
-                            LLM_NORM, cb, il);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                            model.layers[il].attn_k_norm,
-                            model.layers[il].attn_k_norm_b,
-                            LLM_NORM, cb, il);
-                    cb(Kcur, "Kcur", il);
-
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                            model.layers[il].wo, model.layers[il].bo,
-                            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                } else {
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                            model.layers[il].wo, model.layers[il].bo,
-                            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-                }
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        model.layers[il].ffn_act,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
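
// A minimal standalone sketch of the fused-QKV split used above: each token's wqkv row
// holds [Q | K | V] with widths n_embd, n_embd_gqa, n_embd_gqa, so K starts at float
// offset n_embd and V at n_embd + n_embd_gqa -- the same offsets passed to the
// ggml_view_2d calls. The sizes and values below are example assumptions.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int n_embd     = 8; // hypothetical query width (n_head * head size)
    const int n_embd_gqa = 4; // hypothetical K/V width (n_head_kv * head size)
    const int n_tokens   = 2;

    const int row = n_embd + 2*n_embd_gqa;             // one fused QKV row per token
    std::vector<float> qkv(row * n_tokens);
    for (size_t i = 0; i < qkv.size(); ++i) qkv[i] = (float) i;

    for (int t = 0; t < n_tokens; ++t) {
        const float * q = qkv.data() + t*row;                       // offset 0
        const float * k = qkv.data() + t*row + n_embd;              // offset n_embd
        const float * v = qkv.data() + t*row + n_embd + n_embd_gqa; // offset n_embd + n_embd_gqa
        printf("token %d: q[0]=%g k[0]=%g v[0]=%g\n", t, q[0], k[0], v[0]);
        assert(k - q == n_embd && v - k == n_embd_gqa);
    }
    return 0;
}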
-
-    struct ggml_cgraph * build_stablelm() {
-        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            struct ggml_tensor * inpSA = cur;
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                cb(Qcur, "Qcur", il);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                cb(Kcur, "Kcur", il);
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                            model.layers[il].attn_q_norm,
-                            NULL,
-                            LLM_NORM, cb, il);
-                    cb(Qcur, "Qcur", il);
-                }
-                if (model.layers[il].attn_k_norm) {
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                            model.layers[il].attn_k_norm,
-                            NULL,
-                            LLM_NORM, cb, il);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpL  = ggml_get_rows(ctx0,  inpL, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                if (model.layers[il].ffn_norm) {
-                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                            model.layers[il].ffn_norm,
-                            model.layers[il].ffn_norm_b,
-                            LLM_NORM, cb, il);
-                    cb(cur, "ffn_norm", il);
-                } else {
-                    // parallel residual
-                    cur = inpSA;
-                }
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
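
// A minimal sketch of the optional Q/K LayerNorm applied above before RoPE: each head's
// query/key vector (length n_embd_head) is normalized with a weight and no bias.
// eps and the example values are assumptions for illustration.
#include <cmath>
#include <cstdio>
#include <vector>

static void layer_norm(std::vector<float> & x, const std::vector<float> & gamma, float eps = 1e-5f) {
    float mean = 0.0f;
    for (float v : x) mean += v;
    mean /= x.size();
    float var = 0.0f;
    for (float v : x) var += (v - mean)*(v - mean);
    var /= x.size();
    const float inv = 1.0f/std::sqrt(var + eps);
    for (size_t i = 0; i < x.size(); ++i) x[i] = (x[i] - mean)*inv*gamma[i]; // no bias, as above
}

int main() {
    std::vector<float> q = {0.5f, -1.0f, 2.0f, 0.0f}; // one head's query, n_embd_head = 4
    std::vector<float> g(q.size(), 1.0f);             // attn_q_norm weight
    layer_norm(q, g);
    for (float v : q) printf("%.4f ", v);
    printf("\n");
    return 0;
}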
-
-    struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                // using mode = 2 for neox-style rope
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
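
// A minimal sketch of the "neox" RoPE mode noted above: dimension pairs (i, i + n_rot/2)
// are rotated together, rather than adjacent pairs (2i, 2i+1). The base frequency and the
// example values are assumptions for illustration.
#include <cmath>
#include <cstdio>
#include <vector>

static void rope_neox(std::vector<float> & x, int pos, float base = 10000.0f) {
    const int half = (int) x.size()/2;
    for (int i = 0; i < half; ++i) {
        const float theta = pos * std::pow(base, -2.0f*i/(float) x.size());
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i], x1 = x[i + half]; // neox pairing
        x[i]        = x0*c - x1*s;
        x[i + half] = x0*s + x1*c;
    }
}

int main() {
    std::vector<float> q = {1.0f, 0.0f, 0.0f, 1.0f}; // one head, n_rot = 4
    rope_neox(q, /*pos=*/3);
    for (float v : q) printf("%.4f ", v);
    printf("\n");
    return 0;
}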
-
-    struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
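
// A minimal sketch of the "skip computing output for unused tokens" step above: at the
// last layer only the rows listed in inp_out_ids are gathered (the ggml_get_rows analogue
// below), so the final norm and lm_head run on n_outputs rows instead of n_tokens.
// Sizes and indices are example assumptions.
#include <cstdio>
#include <vector>

int main() {
    const int n_tokens = 4, n_embd = 3;
    std::vector<float> x(n_tokens*n_embd);
    for (size_t i = 0; i < x.size(); ++i) x[i] = (float) i;

    const std::vector<int> out_ids = {3};        // e.g. only the last token's logits are needed
    std::vector<float> y(out_ids.size()*n_embd);
    for (size_t r = 0; r < out_ids.size(); ++r)  // gather the selected rows
        for (int c = 0; c < n_embd; ++c)
            y[r*n_embd + c] = x[out_ids[r]*n_embd + c];

    printf("kept %zu of %d rows, first value %g\n", out_ids.size(), n_tokens, y[0]);
    return 0;
}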
-
-    struct ggml_cgraph * build_qwen2vl() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
-        cb(lctx.inp_pos, "inp_pos", -1);
-        ggml_set_input(lctx.inp_pos);
-        struct ggml_tensor * inp_pos = lctx.inp_pos;
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-        int sections[4];
-        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_multi(
-                    ctx0,
-                    ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_multi(
-                    ctx0,
-                    ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable copy of n_tokens, reduced to n_outputs at the last layer when unused tokens are skipped
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // MoE branch
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            ggml_tensor * moe_out =
-                    llm_build_moe_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, false,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            // FFN shared expert
-            {
-                ggml_tensor * cur_gate_inp = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
-                cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
-
-                // sigmoid
-                ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
-                cb(cur_gate, "ffn_shexp_gate", il);
-
-                ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up_shexp,   NULL, NULL,
-                        model.layers[il].ffn_gate_shexp, NULL, NULL,
-                        model.layers[il].ffn_down_shexp, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur_ffn, "ffn_shexp", il);
-
-                ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate);
-                cb(ffn_shexp_out, "ffn_shexp_out", il);
-
-                moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out);
-                cb(moe_out, "ffn_out", il);
-
-                cur = moe_out;
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
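
// A minimal sketch of the shared-expert gate above: sigmoid(x) is obtained as silu(x)/x
// (valid for x != 0, since silu(x) = x*sigmoid(x)), and the gated shared-expert output is
// added on top of the routed-experts output. The values are example assumptions.
#include <cmath>
#include <cstdio>

static float silu(float x)    { return x/(1.0f + std::exp(-x)); }
static float sigmoid(float x) { return 1.0f/(1.0f + std::exp(-x)); }

int main() {
    const float gate_inp = 0.7f;                 // one element of ffn_shexp_gate_inp (non-zero)
    const float via_silu = silu(gate_inp)/gate_inp;
    printf("sigmoid=%f  silu(x)/x=%f\n", sigmoid(gate_inp), via_silu);

    const float moe_out = 0.25f;                 // routed experts output
    const float shexp   = 0.40f;                 // shared expert FFN output
    printf("ffn_out=%f\n", moe_out + via_silu*shexp);
    return 0;
}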
-
-    struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * attn_norm_output;
-        struct ggml_tensor * ffn_output;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(attn_norm_output, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = nullptr;
-                struct ggml_tensor * Kcur = nullptr;
-                struct ggml_tensor * Vcur = nullptr;
-
-                if (model.layers[il].wqkv) {
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
-                    cb(cur, "wqkv", il);
-
-                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                    cb(cur, "bqkv", il);
-
-                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-                } else {
-                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
-                }
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                // with phi2, we scale the Q to avoid precision issues
-                // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur              = ggml_get_rows(ctx0,              cur, inp_out_ids);
-                inpL             = ggml_get_rows(ctx0,             inpL, inp_out_ids);
-                attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
-            }
-
-            // FF
-            {
-                ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(ffn_output, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_output);
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output_no_bias", -1);
-
-        cur = ggml_add(ctx0, cur, model.output_b);
-        cb(cur, "result_output", -1);
-        ggml_build_forward_expand(gf, cur);
-        return gf;
-    }
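
// A minimal sketch of the Q pre-scaling above: scaling Q by 1/sqrt(n_embd_head) and then
// using kq_scale = 1.0f gives the same attention scores as scaling Q*K^T afterwards, while
// keeping the intermediate values smaller. Head size and values are example assumptions.
#include <cmath>
#include <cstdio>

int main() {
    const int d = 4;                             // example head size (n_embd_head)
    float q[d] = {1.0f, 2.0f, -1.0f, 0.5f};
    float k[d] = {0.5f, -1.0f, 1.5f, 2.0f};

    float dot = 0.0f;
    for (int i = 0; i < d; ++i) dot += q[i]*k[i];

    const float scale = 1.0f/std::sqrt((float) d);
    float dot_prescaled = 0.0f;
    for (int i = 0; i < d; ++i) dot_prescaled += (q[i]*scale)*k[i];

    printf("scale-after: %f  scale-Q-first: %f\n", dot*scale, dot_prescaled);
    return 0;
}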
-
-    struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = nullptr;
-        if (hparams.n_swa == 0) {
-            // Phi-4 doesn't use sliding window attention
-            KQ_mask = build_inp_KQ_mask();
-        } else {
-            KQ_mask = build_inp_KQ_mask_swa();
-        }
-
-        for (int il = 0; il < n_layer; ++il) {
-            auto residual = inpL;
-
-            // self-attention
-            {
-                // rope freq factors for 128k context
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                struct ggml_tensor * attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM_RMS, cb, il);
-                cb(attn_norm_output, "attn_norm", il);
-
-                struct ggml_tensor * Qcur = nullptr;
-                struct ggml_tensor * Kcur = nullptr;
-                struct ggml_tensor * Vcur = nullptr;
-
-                if (model.layers[il].wqkv) {
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
-                    cb(cur, "wqkv", il);
-
-                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
-                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
-                } else {
-                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
-                }
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
-            }
-
-            cur = ggml_add(ctx0, cur, residual);
-            residual = cur;
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
-                LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            // feed-forward network
-            if (model.layers[il].ffn_gate_inp == nullptr) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        NULL,                      NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, true,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        cb, il);
-                cb(cur, "ffn_moe_out", il);
-            }
-
-            cur = ggml_add(ctx0, residual, cur);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-            model.output_norm,
-            model.output_norm_b,
-            LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        if (model.output_b != nullptr) {
-            cb(cur, "result_output_no_bias", -1);
-            cur = ggml_add(ctx0, cur, model.output_b);
-        }
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
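
// A minimal sketch of the KQ_mask choice above: with n_swa == 0 a plain causal mask is
// used, otherwise a sliding-window mask additionally hides keys far behind the query.
// The exact boundary convention and the sizes here are illustrative assumptions.
#include <cstdio>

int main() {
    const int n_kv  = 6;  // positions 0..5, example size
    const int n_swa = 3;  // window size; 0 would mean no sliding window
    for (int q = 0; q < n_kv; ++q) {
        for (int k = 0; k < n_kv; ++k) {
            const bool causal = k <= q;                          // no attending to the future
            const bool in_win = (n_swa == 0) || (q - k < n_swa); // boundary convention is illustrative
            printf("%c", (causal && in_win) ? '1' : '.');
        }
        printf("\n");
    }
    return 0;
}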
-
-    struct ggml_cgraph * build_plamo() {
-        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            struct ggml_tensor * attention_norm = cur;
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-            struct ggml_tensor * sa_out = cur;
-
-            cur = attention_norm;
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur    = ggml_get_rows(ctx0,    cur, inp_out_ids);
-                sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
-                inpL   = ggml_get_rows(ctx0,   inpL, inp_out_ids);
-            }
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, sa_out);
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
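
// A minimal sketch of the parallel-residual pattern above: attention and the FFN both read
// the same normalized input and both are added back to the layer input,
// y = x + attn(norm(x)) + ffn(norm(x)). The toy 1-D "norm"/"attn"/"ffn" functions below are
// stand-ins for the real tensor ops, used only to contrast with the sequential form.
#include <cstdio>

static float norm(float x) { return x;      } // stand-in
static float attn(float x) { return 0.5f*x; } // stand-in
static float ffn (float x) { return x*x;    } // stand-in

int main() {
    const float x = 2.0f;
    const float n = norm(x);
    const float y_parallel   = x + attn(n) + ffn(n); // plamo-style
    const float y_sequential = [&]{ float h = x + attn(norm(x)); return h + ffn(norm(h)); }();
    printf("parallel=%f sequential=%f\n", y_parallel, y_sequential);
    return 0;
}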
-
-    struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * pos;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
-        cb(pos, "pos_embd", -1);
-
-        inpL = ggml_add(ctx0, inpL, pos);
-        cb(inpL, "inpL", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
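
// A minimal sketch of the learned absolute position embedding above: a row of
// model.pos_embd is looked up per position (the ggml_get_rows analogue below) and added to
// the token embedding before the first layer. Table sizes and values are example assumptions.
#include <cstdio>
#include <vector>

int main() {
    const int n_embd = 2, n_ctx = 4;
    std::vector<float> pos_embd(n_ctx*n_embd);     // model.pos_embd analogue
    for (size_t i = 0; i < pos_embd.size(); ++i) pos_embd[i] = 0.01f*(float) i;

    std::vector<float> tok = {1.0f, -1.0f};        // one token's embedding
    const int pos = 2;                             // its position in the sequence
    for (int c = 0; c < n_embd; ++c)
        tok[c] += pos_embd[pos*n_embd + c];        // inpL = inpL + pos_embd[pos]

    printf("inpL = (%f, %f)\n", tok[0], tok[1]);
    return 0;
}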
-
-    struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(tmpq, "tmpq", il);
-                cb(tmpk, "tmpk", il);
-                cb(Vcur, "Vcur", il);
-
-                struct ggml_tensor * Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                // if (model.layers[il].bq) {
-                //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                //     cb(Qcur, "Qcur", il);
-                // }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                // if (model.layers[il].bk) {
-                //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                //     cb(Kcur, "Kcur", il);
-                // }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                // if (model.layers[il].bv) {
-                //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                //     cb(Vcur, "Vcur", il);
-                // }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
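Note (not part of the diff): the builder above normalizes with LLM_NORM, i.e. a classic LayerNorm that takes both a weight and a bias tensor (attn_norm / attn_norm_b), while build_internlm2 below switches to LLM_NORM_RMS with no bias. As a reference for what those two modes compute, here is a minimal standalone sketch in plain C++ over a std::vector<float>; it is illustrative only, and the epsilon default is a placeholder rather than the value used by ggml.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Classic LayerNorm (LLM_NORM): subtract the mean, divide by the standard
    // deviation, then apply a learned per-channel weight and bias.
    std::vector<float> layer_norm(const std::vector<float> & x,
                                  const std::vector<float> & w,
                                  const std::vector<float> & b,
                                  float eps = 1e-5f) {
        float mean = 0.0f;
        for (float v : x) mean += v;
        mean /= x.size();

        float var = 0.0f;
        for (float v : x) var += (v - mean) * (v - mean);
        var /= x.size();

        std::vector<float> y(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) {
            y[i] = (x[i] - mean) / std::sqrt(var + eps) * w[i] + b[i];
        }
        return y;
    }

    // RMSNorm (LLM_NORM_RMS): no mean subtraction and no bias, only a division
    // by the root-mean-square followed by a learned weight.
    std::vector<float> rms_norm(const std::vector<float> & x,
                                const std::vector<float> & w,
                                float eps = 1e-5f) {
        float ms = 0.0f;
        for (float v : x) ms += v * v;
        ms /= x.size();

        std::vector<float> y(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) {
            y[i] = x[i] / std::sqrt(ms + eps) * w[i];
        }
        return y;
    }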
-    struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
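Note (not part of the diff): both builders above hand Q/K/V to llm_build_kv with kq_scale = 1.0f/sqrtf(float(n_embd_head)). The toy below (plain C++, a single query against a handful of cached keys, hypothetical sizes) shows where that scale enters before the softmax; it is a sketch of scaled dot-product attention, not the ggml implementation.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Toy single-query attention: scores = softmax((q . k_i) / sqrt(d)), out = sum_i scores_i * v_i.
    std::vector<float> attend(const std::vector<float> & q,
                              const std::vector<std::vector<float>> & K,
                              const std::vector<std::vector<float>> & V) {
        const float kq_scale = 1.0f / std::sqrt((float) q.size()); // 1/sqrt(n_embd_head)

        std::vector<float> scores(K.size());
        float max_s = -INFINITY;
        for (std::size_t i = 0; i < K.size(); ++i) {
            float s = 0.0f;
            for (std::size_t d = 0; d < q.size(); ++d) s += q[d] * K[i][d];
            scores[i] = s * kq_scale;
            max_s = std::max(max_s, scores[i]);
        }

        // numerically stable softmax over the scaled scores
        float sum = 0.0f;
        for (float & s : scores) { s = std::exp(s - max_s); sum += s; }
        for (float & s : scores) s /= sum;

        // weighted sum of the cached values
        std::vector<float> out(V[0].size(), 0.0f);
        for (std::size_t i = 0; i < V.size(); ++i) {
            for (std::size_t d = 0; d < out.size(); ++d) out[d] += scores[i] * V[i][d];
        }
        return out;
    }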
-    struct ggml_cgraph * build_minicpm3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        //TODO: if the model varies, these parameters need to be read from the model
-        const int64_t n_embd_base = 256;
-        const float scale_embd  = 12.0f;
-        const float scale_depth = 1.4f;
-        const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
-
-        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-        const uint32_t kv_lora_rank = hparams.n_lora_kv;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // scale the input embeddings
-        inpL = ggml_scale(ctx0, inpL, scale_embd);
-        cb(inpL, "inp_scaled", -1);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            struct ggml_tensor * rope_factors = build_rope_factors(il);
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self_attention
-            {
-                struct ggml_tensor * q = NULL;
-                // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
-                q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
-                cb(q, "q", il);
-
-                q = llm_build_norm(ctx0, q, hparams,
-                        model.layers[il].attn_q_a_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(q, "q", il);
-
-                // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
-                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
-                cb(q, "q", il);
-
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                        0);
-                cb(q_nope, "q_nope", il);
-
-                // and {n_head * n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                        ggml_row_size(q->type, n_embd_head_qk_nope));
-                cb(q_pe, "q_pe", il);
-
-                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
-
-                // split into {kv_lora_rank, n_tokens}
-                struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        0);
-                cb(kv_compressed, "kv_compressed", il);
-
-                // and {n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        kv_pe_compresseed->nb[1],
-                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
-                cb(k_pe, "k_pe", il);
-
-                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
-                kv_compressed = ggml_cont(ctx0, kv_compressed);
-                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
-                        model.layers[il].attn_kv_a_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(kv_compressed, "kv_compressed", il);
-
-                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
-                cb(kv, "kv", il);
-
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        0);
-                cb(k_nope, "k_nope", il);
-
-                // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                    0);
-                cb(v_states, "v_states", il);
-
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                q_pe = ggml_rope_ext(
-                    ctx0, q_pe, inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(q_pe, "q_pe", il);
-
-                // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                k_pe = ggml_rope_ext(
-                    ctx0, k_pe, inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(k_pe, "k_pe", il);
-
-                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-                cb(q_states, "q_states", il);
-
-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-                cb(k_states, "k_states", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // scale_res - scale the hidden states for residual connection
-            const float scale_res = scale_depth/sqrtf(float(n_layer));
-            cur = ggml_scale(ctx0, cur, scale_res);
-            cb(cur, "hidden_scaled", il);
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            // scale the hidden states for residual connection
-            cur = ggml_scale(ctx0, cur, scale_res);
-            cb(cur, "hidden_scaled_ffn", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head scaling
-        const float scale_lmhead = float(n_embd_base)/float(n_embd);
-        cur = ggml_scale(ctx0, cur, scale_lmhead);
-        cb(cur, "lmhead_scaling", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
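Note (not part of the diff): build_minicpm3 applies three fixed scalings that are easy to miss in the graph code: the token embeddings are multiplied by scale_embd = 12, each residual branch by scale_depth / sqrt(n_layer) with scale_depth = 1.4, and the pre-lm_head activations by n_embd_base / n_embd with n_embd_base = 256. A minimal sketch of those formulas; the n_layer and n_embd values here are hypothetical, chosen only for illustration.

    #include <cmath>
    #include <cstdio>

    int main() {
        // constants hard-coded in the removed builder
        const float scale_embd  = 12.0f;
        const float scale_depth = 1.4f;
        const float n_embd_base = 256.0f;

        // hypothetical model dimensions, for illustration only
        const int   n_layer = 62;
        const float n_embd  = 2560.0f;

        const float scale_res    = scale_depth / std::sqrt((float) n_layer); // per-branch residual scale
        const float scale_lmhead = n_embd_base / n_embd;                     // pre-lm_head scale

        std::printf("embd scale : %.3f\n", scale_embd);
        std::printf("residual   : %.3f\n", scale_res);
        std::printf("lm_head    : %.3f\n", scale_lmhead);
        return 0;
    }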
-    struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
-
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
-                cb(Qcur, "Qcur_scaled", il);
-
-                Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
-            cb(sa_out, "sa_out", il);
-
-            cur = llm_build_norm(ctx0, sa_out, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, sa_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
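Note (not part of the diff): build_gemma moves the usual attention scaling out of llm_build_kv. The embeddings are multiplied once by sqrt(n_embd) at the input, the queries by 1/sqrt(n_embd_head_k) after RoPE, and llm_build_kv is then called with kq_scale = 1.0f. A small standalone sketch of those two factors, using hypothetical dimensions for illustration only:

    #include <cmath>
    #include <cstdio>

    int main() {
        // hypothetical Gemma-like dimensions, for illustration only
        const float n_embd        = 2048.0f;
        const float n_embd_head_k = 256.0f;

        const float inp_scale = std::sqrt(n_embd);               // applied once to the token embeddings
        const float q_scale   = 1.0f / std::sqrt(n_embd_head_k); // applied to Qcur, so llm_build_kv uses 1.0f

        std::printf("input scale: %.3f, query scale: %.5f\n", inp_scale, q_scale);
        return 0;
    }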
-    struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-        cb(inpL, "inp_scaled", -1);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        // gemma 2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
-        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
-
-        for (int il = 0; il < n_layer; ++il) {
-            // even-indexed layers (il % 2 == 0) use SWA
-            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
-
-                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
-                switch (model.type) {
-                    case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
-                    default: GGML_ABORT("fatal error");
-                };
-                cb(Qcur, "Qcur_scaled", il);
-
-                Kcur = ggml_rope_ext(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
-            }
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_post_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_post_norm", il);
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
-            cb(sa_out, "sa_out", il);
-
-            cur = llm_build_norm(ctx0, sa_out, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_post_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-            cb(cur, "ffn_post_norm", -1);
-
-            cur = ggml_add(ctx0, cur, sa_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        // final logit soft-capping
-        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
-        cur = ggml_tanh(ctx0, cur);
-        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
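Note (not part of the diff): the final logit soft-capping in build_gemma2 is the three ops scale → tanh → scale, i.e. logits are squashed into (-cap, +cap) via cap * tanh(x / cap). A minimal sketch of that formula; the cap value in the comment is a placeholder, the real one comes from hparams.f_final_logit_softcapping.

    #include <cmath>

    // Soft-capping as applied to the final logits: values are squashed into (-cap, +cap).
    float soft_cap(float x, float cap) {
        return cap * std::tanh(x / cap);
    }
    // e.g. with cap = 30.0f: soft_cap(100.0f, 30.0f) ≈ 29.9f, while soft_cap(5.0f, 30.0f) ≈ 4.95f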
-    struct ggml_cgraph * build_gemma3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
-        if (ubatch.token) {
-            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
-            cb(inpL, "inp_scaled", -1);
-        }
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        // gemma3 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
-        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
-
-        // "5-to-1 interleaved attention"
-        // 5 layers of local attention followed by 1 layer of global attention
-        static const int sliding_window_pattern = 6;
-
-        for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding          = (il + 1) % sliding_window_pattern;
-            const float freq_base_l        = is_sliding ? 10000.0f    : freq_base;
-            const float freq_scale_l       = is_sliding ? 1.0f        : freq_scale;
-            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens);
-                Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                    model.layers[il].attn_q_norm,
-                    NULL,
-                    LLM_NORM_RMS, cb, il);
-                cb(Qcur, "Qcur_normed", il);
-
-                Qcur = ggml_rope_ext(
-                        ctx0, Qcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
-                Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                    model.layers[il].attn_k_norm,
-                    NULL,
-                    LLM_NORM_RMS, cb, il);
-                cb(Kcur, "Kcur_normed", il);
-
-                Kcur = ggml_rope_ext(
-                        ctx0, Kcur, inp_pos, nullptr,
-                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
-            }
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_post_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_post_norm", il);
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
-            cb(sa_out, "sa_out", il);
-
-            cur = llm_build_norm(ctx0, sa_out, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_post_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-            cb(cur, "ffn_post_norm", -1);
-
-            cur = ggml_add(ctx0, cur, sa_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
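Note (not part of the diff): the "5-to-1 interleaved attention" in build_gemma3 comes from is_sliding = (il + 1) % 6: five consecutive local (sliding-window) layers, then one global layer, with the local layers also using a fixed RoPE base of 10000 instead of the model-wide freq_base. A small sketch that prints the pattern for the first few layers; the layer count here is hypothetical.

    #include <cstdio>

    int main() {
        const int sliding_window_pattern = 6; // 5 local layers followed by 1 global layer
        const int n_layer = 12;               // hypothetical, for illustration only

        for (int il = 0; il < n_layer; ++il) {
            const bool is_sliding = (il + 1) % sliding_window_pattern != 0;
            std::printf("layer %2d: %s\n", il, is_sliding ? "local (SWA, rope base 10000)" : "global");
        }
        return 0;
    }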
-    struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            cur = llm_build_mamba(ctx0, lctx, ubatch, gf, cur,
-                    state_copy, state_mask,
-                    kv_head, n_kv, cb, il);
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // residual
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        // final rmsnorm
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
-    struct ggml_cgraph * build_command_r() {
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        const float f_logit_scale = hparams.f_logit_scale;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-            struct ggml_tensor * ffn_inp = cur;
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
-                                ggml_element_size(Qcur) * n_embd_head,
-                                ggml_element_size(Qcur) * n_embd_head * n_head,
-                                0);
-                    cb(Qcur, "Qcur", il);
-                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
-                                ggml_element_size(Kcur) * n_embd_head,
-                                ggml_element_size(Kcur) * n_embd_head * n_head_kv,
-                                0);
-                    cb(Kcur, "Kcur", il);
-
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                                model.layers[il].attn_q_norm,
-                                NULL,
-                                LLM_NORM, cb, il);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                            model.layers[il].attn_k_norm,
-                            NULL,
-                            LLM_NORM, cb, il);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur     = ggml_get_rows(ctx0,     cur, inp_out_ids);
-                inpL    = ggml_get_rows(ctx0,    inpL, inp_out_ids);
-                ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
-            }
-
-            struct ggml_tensor * attn_out = cur;
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, ffn_inp,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            // add together residual + FFN + self-attention
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = ggml_add(ctx0, cur, attn_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        if (f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, f_logit_scale);
-        }
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-
-    }
-
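Note (not part of the diff): unlike the sequential pre-norm blocks above, build_command_r computes attention and the FFN in parallel from the same normalized input and adds both to the raw residual, l_out = inpL + Attn(norm(inpL)) + FFN(norm(inpL)); the optional f_logit_scale is then applied to the final logits. A toy sketch of that parallel residual structure, with the sub-blocks passed in as placeholder callables:

    #include <cstddef>
    #include <functional>
    #include <vector>

    using vec = std::vector<float>;

    // Parallel residual block: both sub-layers read the same normalized input,
    // and their outputs are added to the unnormalized residual stream.
    vec parallel_block(const vec & inpL,
                       const std::function<vec(const vec &)> & norm,
                       const std::function<vec(const vec &)> & attn,
                       const std::function<vec(const vec &)> & ffn) {
        const vec x = norm(inpL);
        const vec a = attn(x);
        const vec f = ffn(x);

        vec out(inpL.size());
        for (std::size_t i = 0; i < out.size(); ++i) {
            out[i] = inpL[i] + a[i] + f[i];
        }
        return out;
    }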
-    struct ggml_cgraph * build_cohere2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        const float f_logit_scale = hparams.f_logit_scale;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        // cohere2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
-        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
-
-        // sliding window switch pattern
-        const int32_t sliding_window_pattern = 4;
-
-        for (int il = 0; il < n_layer; ++il) {
-            // three layers sliding window attention (window size 4096) and ROPE
-            // fourth layer uses global attention without positional embeddings
-            const bool           is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
-            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-            struct ggml_tensor * ffn_inp = cur;
-
-            // self-attention
-            {
-                // rope freq factors for 128k context
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                if (is_sliding) {
-                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
-                                        beta_fast, beta_slow);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
-                                        rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                                        attn_factor, beta_fast, beta_slow);
-                    cb(Kcur, "Kcur", il);
-                } else {
-                    // For non-sliding layers, just reshape without applying RoPE
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                    cb(Qcur, "Qcur", il);
-
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
-                                   KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur                              = ggml_get_rows(ctx0, cur, inp_out_ids);
-                inpL                             = ggml_get_rows(ctx0, inpL, inp_out_ids);
-                ffn_inp                          = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
-            }
-
-            struct ggml_tensor * attn_out = cur;
-
-            // feed-forward network
-            {
-                cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
-                                    NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
-                                    cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            // add together residual + FFN + self-attention
-            cur = ggml_add(ctx0, cur, inpL);
-            cur = ggml_add(ctx0, cur, attn_out);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        if (f_logit_scale) {
-            cur = ggml_scale(ctx0, cur, f_logit_scale);
-        }
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
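Note (not part of the diff): build_cohere2 uses a 3-to-1 pattern: il % 4 < 3 selects three sliding-window layers that get RoPE, while every fourth layer attends globally and is fed plain (un-rotated) Q/K, i.e. it carries no positional embedding at all. A small sketch of the selector; the layer count is hypothetical.

    #include <cstdio>

    int main() {
        const int sliding_window_pattern = 4; // 3 sliding-window layers, then 1 global layer
        const int n_layer = 8;                // hypothetical, for illustration only

        for (int il = 0; il < n_layer; ++il) {
            const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
            std::printf("layer %d: %s\n", il, is_sliding ? "sliding window + RoPE" : "global, no RoPE");
        }
        return 0;
    }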
-    // ref: https://allenai.org/olmo
-    // based on the original build_llama() function, changes:
-    //   * non-parametric layer norm
-    //   * clamp qkv
-    //   * removed bias
-    //   * removed MoE
-    struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    NULL, NULL,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (hparams.f_clamp_kqv > 0.0f) {
-                    Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, nullptr,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    NULL, NULL,
-                    LLM_NORM, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                NULL, NULL,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
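Note (not part of the diff): two OLMo-specific details in build_olmo are that the layer norms are non-parametric (both weight and bias passed as NULL) and that the Q/K/V projections are clamped to [-f_clamp_kqv, +f_clamp_kqv] when that hyper-parameter is positive. A minimal sketch of the clamp on a single value, mirroring what ggml_clamp does element-wise; the threshold in the comment is a placeholder.

    #include <algorithm>

    // Element-wise clamp as applied to Qcur/Kcur/Vcur when f_clamp_kqv > 0.
    float clamp_kqv(float x, float f_clamp_kqv) {
        return std::min(std::max(x, -f_clamp_kqv), f_clamp_kqv);
    }
    // e.g. with f_clamp_kqv = 8.0f: clamp_kqv(11.5f, 8.0f) == 8.0f, clamp_kqv(-0.3f, 8.0f) == -0.3f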
-    struct ggml_cgraph * build_olmo2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = inpL;
-
-            // self_attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Qcur, "Qcur_normed", il);
-
-                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Kcur, "Kcur_normed", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur_rope", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur_rope", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_post_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_post_norm", il);
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_ffn(ctx0, lctx, ffn_inp,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_post_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-            cb(cur, "ffn_post_norm", -1);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
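OLMo-2 applies RMSNorm after the attention and FFN blocks and also to Q and K. A minimal scalar sketch of what LLM_NORM_RMS computes per row (illustrative helper; the actual eps is taken from hparams):

    // Scalar RMSNorm: scale each row by the inverse RMS, then by the learned weight
    // (illustrative only, not the ggml kernel)
    #include <cmath>
    #include <vector>

    std::vector<float> rms_norm(const std::vector<float> & x, const std::vector<float> & w, float eps = 1e-6f) {
        double sum_sq = 0.0;
        for (float v : x) sum_sq += (double) v * v;
        const float scale = 1.0f / std::sqrt((float) (sum_sq / x.size()) + eps);
        std::vector<float> out(x.size());
        for (size_t i = 0; i < x.size(); ++i) {
            out[i] = x[i] * scale * w[i]; // learned per-channel weight, no bias
        }
        return out;
    }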
-
-    // based on the build_qwen2moe() function, changes:
-    //   * removed shared experts
-    //   * removed bias
-    //   * added q, k norm
-    struct ggml_cgraph * build_olmoe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self_attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Qcur, "Qcur_normed", il);
-
-                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Kcur, "Kcur_normed", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur_rope", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur_rope", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // MoE branch
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, false,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
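The MoE branch routes each token through n_expert_used of n_expert experts with softmax gating. A minimal sketch of that selection step, under the assumption of plain softmax top-k routing as configured above (illustrative, not the llm_build_moe_ffn implementation):

    // Softmax top-k expert routing sketch: returns (expert id, routing weight) pairs
    // for one token; assumes k <= logits.size() (illustrative only)
    #include <algorithm>
    #include <cmath>
    #include <numeric>
    #include <vector>

    std::vector<std::pair<int, float>> route_top_k(const std::vector<float> & logits, int k) {
        std::vector<float> probs(logits.size());
        const float max_logit = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) {
            probs[i] = std::exp(logits[i] - max_logit);
            sum += probs[i];
        }
        for (float & p : probs) p /= sum;

        std::vector<int> idx(logits.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int a, int b) { return probs[a] > probs[b]; });

        std::vector<std::pair<int, float>> selected;
        for (int i = 0; i < k; ++i) {
            selected.emplace_back(idx[i], probs[idx[i]]); // expert id and its gate weight
        }
        return selected;
    }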
-
-    struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_head    = hparams.n_head(il);
-            const int64_t n_head_kv = hparams.n_head_kv(il);
-            const int64_t n_head_qkv = 2*n_head_kv + n_head;
-
-            cur = inpL;
-            struct ggml_tensor * residual = cur;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
-                cb(Vcur, "Vcur", il);
-
-                Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                        model.layers[il].attn_q_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Qcur, "Qcur", il);
-
-                Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                        model.layers[il].attn_k_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(Kcur, "Kcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
-                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
-                cb(Qcur, "Vcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        // norm
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
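OpenELM uses a fused QKV projection with per-layer head counts; the three ggml_view_3d calls slice it into Q, K and V by offsetting along the head dimension. A plain-array sketch of the same offsets (hypothetical helper):

    // Splitting a fused QKV projection laid out per token as
    // [n_embd_head x (n_head + 2*n_head_kv)], mirroring the view offsets above
    #include <cstddef>

    struct QKVSlices {
        const float * q; // n_head    * n_embd_head values
        const float * k; // n_head_kv * n_embd_head values
        const float * v; // n_head_kv * n_embd_head values
    };

    QKVSlices split_qkv(const float * qkv_token, size_t n_embd_head, size_t n_head, size_t n_head_kv) {
        QKVSlices s;
        s.q = qkv_token;                                      // offset 0
        s.k = qkv_token + n_embd_head * n_head;               // after the query heads
        s.v = qkv_token + n_embd_head * (n_head + n_head_kv); // after query + key heads
        return s;
    }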
-
-    struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // ffn
-            if (hparams.use_par_res) {
-                // attention and ffn are computed in parallel
-                // x = x + attn(ln1(x)) + ffn(ln2(x))
-
-                struct ggml_tensor * attn_out = cur;
-
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-
-                cur = ggml_add(ctx0, cur, inpL);
-                cb(cur, "ffn_out", il);
-
-                cur = ggml_add(ctx0, cur, attn_out);
-                cur = lctx.cvec.apply_to(ctx0, cur, il);
-                cb(cur, "l_out", il);
-
-                // input for next layer
-                inpL = cur;
-            } else {
-                // attention and ffn are computed sequentially
-                // x = x + attn(ln1(x))
-                // x = x + ffn(ln2(x))
-
-                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-                cb(ffn_inp, "ffn_inp", il);
-
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        NULL,                      NULL,                        NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-
-                cur = ggml_add(ctx0, cur, ffn_inp);
-                cur = lctx.cvec.apply_to(ctx0, cur, il);
-                cb(cur, "l_out", il);
-
-                // input for next layer
-                inpL = cur;
-            }
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
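GPT-NeoX picks between a parallel and a sequential residual layout via hparams.use_par_res. A small sketch of the two wirings, with placeholder callables standing in for the norm/attention/FFN sub-graphs (illustrative only; the real code builds ggml ops):

    #include <functional>
    #include <vector>

    using Vec   = std::vector<float>;
    using Block = std::function<Vec(const Vec &)>; // stands in for attn/ffn/norm sub-graphs

    static Vec add(const Vec & a, const Vec & b) {
        Vec r(a.size());
        for (size_t i = 0; i < a.size(); ++i) r[i] = a[i] + b[i];
        return r;
    }

    Vec gptneox_layer(const Vec & x, bool use_par_res,
                      Block ln1, Block attn, Block ln2, Block ffn) {
        if (use_par_res) {
            // x + attn(ln1(x)) + ffn(ln2(x)) : both branches read the same layer input
            return add(add(x, attn(ln1(x))), ffn(ln2(x)));
        }
        // sequential: h = x + attn(ln1(x)); out = h + ffn(ln2(h))
        const Vec h = add(x, attn(ln1(x)));
        return add(h, ffn(ln2(h)));
    }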
-
-    struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            struct ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp);
-            cb(ffn_out, "ffn_out", il);
-
-            // MoE
-            cur = llm_build_norm(ctx0, inpSA, hparams,
-                    model.layers[il].ffn_norm_exps, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm_exps", il);
-
-            cur = llm_build_moe_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    nullptr,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, true,
-                    false, 0.0,
-                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    cb, il);
-            cb(cur, "ffn_moe_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_out);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
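Arctic computes a dense FFN on the post-attention residual and, in parallel, a routed MoE from the normalized layer input, then sums the two. A sketch of that tail with placeholder callables (illustrative only):

    #include <functional>
    #include <vector>

    using Vec   = std::vector<float>;
    using Block = std::function<Vec(const Vec &)>;

    static Vec add(const Vec & a, const Vec & b) {
        Vec r(a.size());
        for (size_t i = 0; i < a.size(); ++i) r[i] = a[i] + b[i];
        return r;
    }

    Vec arctic_layer_tail(const Vec & ffn_inp, const Vec & inpSA,
                          Block ffn_norm, Block dense_ffn,
                          Block ffn_norm_exps, Block moe_ffn) {
        const Vec ffn_out = add(dense_ffn(ffn_norm(ffn_inp)), ffn_inp); // dense branch + residual
        const Vec moe_out = moe_ffn(ffn_norm_exps(inpSA));              // MoE reads the layer input
        return add(moe_out, ffn_out);                                   // l_out before the control vector
    }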
-
-    struct ggml_cgraph * build_deepseek() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            if ((uint32_t) il < hparams.n_layer_dense_lead) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                ggml_tensor * moe_out =
-                        llm_build_moe_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_gate_inp,
-                            model.layers[il].ffn_up_exps,
-                            model.layers[il].ffn_gate_exps,
-                            model.layers[il].ffn_down_exps,
-                            nullptr,
-                            n_expert, n_expert_used,
-                            LLM_FFN_SILU, false,
-                            false, hparams.expert_weights_scale,
-                            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                            cb, il);
-                cb(moe_out, "ffn_moe_out", il);
-
-                // FFN shared expert
-                {
-                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_up_shexp,   NULL, NULL,
-                            model.layers[il].ffn_gate_shexp, NULL, NULL,
-                            model.layers[il].ffn_down_shexp, NULL, NULL,
-                            NULL,
-                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                    cb(ffn_shexp, "ffn_shexp", il);
-
-                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
-                    cb(cur, "ffn_out", il);
-                }
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
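DeepSeek keeps the first n_layer_dense_lead layers dense and, in the MoE layers, adds an always-on shared expert to the routed output. A sketch of that split with placeholder callables (illustrative only):

    #include <functional>
    #include <vector>

    using Vec   = std::vector<float>;
    using Block = std::function<Vec(const Vec &)>;

    static Vec add(const Vec & a, const Vec & b) {
        Vec r(a.size());
        for (size_t i = 0; i < a.size(); ++i) r[i] = a[i] + b[i];
        return r;
    }

    // First n_layer_dense_lead layers use a plain FFN; later layers add the routed
    // experts to a shared expert that every token passes through.
    Vec deepseek_ffn(const Vec & x, int il, int n_layer_dense_lead,
                     Block dense_ffn, Block moe_ffn, Block shared_expert) {
        if (il < n_layer_dense_lead) {
            return dense_ffn(x);
        }
        return add(moe_ffn(x), shared_expert(x)); // cur = moe_out + ffn_shexp
    }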
-
-    struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        bool is_lite = (hparams.n_layer == 27);
-
-        // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
-        // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-        const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
-        const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
-        const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
-
-        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-        const uint32_t kv_lora_rank = hparams.n_lora_kv;
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self_attention
-            {
-                struct ggml_tensor * q = NULL;
-                if (!is_lite) {
-                    // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
-                    q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
-                    cb(q, "q", il);
-
-                    q = llm_build_norm(ctx0, q, hparams,
-                            model.layers[il].attn_q_a_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
-                    cb(q, "q", il);
-
-                    // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
-                    q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
-                    cb(q, "q", il);
-                } else {
-                    q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-                    cb(q, "q", il);
-                }
-
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                        0);
-                cb(q_nope, "q_nope", il);
-
-                // and {n_head * n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
-                        ggml_row_size(q->type, hparams.n_embd_head_k),
-                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
-                        ggml_row_size(q->type, n_embd_head_qk_nope));
-                cb(q_pe, "q_pe", il);
-
-                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
-
-                // split into {kv_lora_rank, n_tokens}
-                struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        0);
-                cb(kv_compressed, "kv_compressed", il);
-
-                // and {n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
-                        kv_pe_compresseed->nb[1],
-                        kv_pe_compresseed->nb[1],
-                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
-                cb(k_pe, "k_pe", il);
-
-                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
-                kv_compressed = ggml_cont(ctx0, kv_compressed);
-                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
-                        model.layers[il].attn_kv_a_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(kv_compressed, "kv_compressed", il);
-
-                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
-                cb(kv, "kv", il);
-
-                // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        0);
-                cb(k_nope, "k_nope", il);
-
-                // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                    0);
-                cb(v_states, "v_states", il);
-
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                q_pe = ggml_rope_ext(
-                    ctx0, q_pe, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                );
-                cb(q_pe, "q_pe", il);
-
-                // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
-                k_pe = ggml_rope_ext(
-                    ctx0, k_pe, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor_scaled, beta_fast, beta_slow
-                );
-                cb(k_pe, "k_pe", il);
-
-                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-                cb(q_states, "q_states", il);
-
-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-                cb(k_states, "k_states", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            if ((uint32_t) il < hparams.n_layer_dense_lead) {
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // MoE branch
-                ggml_tensor * moe_out =
-                        llm_build_moe_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_gate_inp,
-                            model.layers[il].ffn_up_exps,
-                            model.layers[il].ffn_gate_exps,
-                            model.layers[il].ffn_down_exps,
-                            model.layers[il].ffn_exp_probs_b,
-                            n_expert, n_expert_used,
-                            LLM_FFN_SILU, hparams.expert_weights_norm,
-                            true, hparams.expert_weights_scale,
-                            (enum llama_expert_gating_func_type) hparams.expert_gating_func,
-                            cb, il);
-                cb(moe_out, "ffn_moe_out", il);
-
-                // FFN shared expert
-                {
-                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
-                            model.layers[il].ffn_up_shexp,   NULL, NULL,
-                            model.layers[il].ffn_gate_shexp, NULL, NULL,
-                            model.layers[il].ffn_down_shexp, NULL, NULL,
-                            NULL,
-                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                    cb(ffn_shexp, "ffn_shexp", il);
-
-                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
-                    cb(cur, "ffn_out", il);
-                }
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
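For reference, the YaRN pre-scaling computed at the top of build_deepseek2() (see the linked discussion), restated as math using the same quantities as the code:

    \begin{aligned}
    m &= \mathrm{attn\_factor}\,\bigl(1 + \mathrm{rope\_yarn\_log\_mul}\cdot\ln(1/\mathrm{freq\_scale})\bigr) \\
    \mathrm{kq\_scale} &= \frac{m^2}{\sqrt{\mathrm{n\_embd\_head\_k}}} \\
    \mathrm{attn\_factor\_scaled} &= \frac{1}{1 + 0.1\,\ln(1/\mathrm{freq\_scale})}
    \end{aligned}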
-
-    struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                if (model.layers[il].wq_scale) {
-                    Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
-                }
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                // B1.K
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                if (model.layers[il].wk_scale) {
-                    Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
-                }
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                // B1.V
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                if (model.layers[il].wv_scale) {
-                    Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
-                }
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        NULL, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].attn_sub_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_sub_norm", il);
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                if (model.layers[il].wo_scale) {
-                    cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
-                }
-                if (model.layers[il].bo) {
-                    cur = ggml_add(ctx0, cur, model.layers[il].bo);
-                }
-                cb(cur, "attn_o_out", il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_scale,
-                    model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
-                    NULL,                      NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_sub_out", il);
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                            model.layers[il].ffn_sub_norm, NULL,
-                            LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_sub_norm", il);
-
-            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
-            if (model.layers[il].ffn_down_scale) {
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
-            }
-            cb(cur, "ffn_down", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        // FIXME: do not use model.tok_embd directly, duplicate as model.output
-        cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-        return gf;
-    }
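BitNet stores a per-tensor scale next to each quantized projection (wq_scale, ffn_down_scale, ...), applied with ggml_mul right after the matmul when the scale tensor is present. A scalar sketch of that pattern (hypothetical helper, not the ggml path):

    // Matmul followed by an optional per-tensor scale, mirroring
    // "cur = ggml_mul(ctx0, cur, wq_scale)" when the scale exists
    #include <vector>

    std::vector<float> scaled_matvec(const std::vector<std::vector<float>> & W,
                                     const std::vector<float> & x,
                                     const float * scale /* nullptr if absent */) {
        std::vector<float> y(W.size(), 0.0f);
        for (size_t r = 0; r < W.size(); ++r) {
            for (size_t c = 0; c < x.size(); ++c) {
                y[r] += W[r][c] * x[c];
            }
            if (scale) {
                y[r] *= *scale; // single scalar shared by the whole tensor
            }
        }
        return y;
    }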
-
-    struct ggml_cgraph * build_t5_enc() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        GGML_ASSERT(lctx.is_encoding);
-        struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm_enc, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_enc, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_enc, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_enc, cur);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-                struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
-
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                cb(kq, "kq", il);
-
-                struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
-                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
-                struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
-                cb(kq_b, "kq_b", il);
-
-                kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
-                cb(kq, "kq_soft_max_ext", il);
-
-                struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
-                cb(v, "v", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                cb(cur, "kqv_merged_cont", il);
-
-                ggml_build_forward_expand(gf, cur);
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_enc, cur);
-                cb(cur, "kqv_out", il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm_enc, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                // T5 uses relu, flan-T5 uses gelu-gated
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up_enc,   NULL, NULL,
-                        model.layers[il].ffn_gate_enc, NULL, NULL,
-                        model.layers[il].ffn_down_enc, NULL, NULL,
-                        NULL,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
-                        cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        cb(cur, "result_embd", -1);
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm_enc, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
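The T5 encoder adds a learned relative-position bias to the raw QK^T scores and applies softmax without the usual 1/sqrt(d) scaling (note the 1.0f passed to ggml_soft_max_ext). A scalar sketch for one query row of one head (illustrative only):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> t5_attn_weights(const std::vector<float> & kq_row,     // raw Q.K for one query
                                       const std::vector<float> & pos_bias) { // bias for the same key positions
        std::vector<float> p(kq_row.size());
        float max_v = -INFINITY;
        for (size_t i = 0; i < p.size(); ++i) {
            p[i]  = kq_row[i] + pos_bias[i]; // kq_b = kq + pos_bias; no 1/sqrt(d) scaling in T5
            max_v = std::max(max_v, p[i]);
        }
        float sum = 0.0f;
        for (float & v : p) { v = std::exp(v - max_v); sum += v; }
        for (float & v : p) v /= sum;
        return p;
    }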
-
-    struct ggml_cgraph * build_t5_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        GGML_ASSERT(!lctx.is_encoding);
-        GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
-
-        struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
-        struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
-
-        struct ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
-        struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
-
-                struct ggml_tensor * k =
-                    ggml_view_3d(ctx0, kv_self.k_l[il],
-                            n_embd_head_k, n_kv, n_head_kv,
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                            0);
-                cb(k, "k", il);
-
-                struct ggml_tensor * v =
-                    ggml_view_3d(ctx0, kv_self.v_l[il],
-                            n_kv, n_embd_head_v, n_head_kv,
-                            ggml_element_size(kv_self.v_l[il])*n_ctx,
-                            ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                            0);
-                cb(v, "v", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                cb(kq, "kq", il);
-
-                struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
-                struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
-                struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
-                cb(kq_b, "kq_b", il);
-
-                kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
-                cb(kq, "kq_soft_max_ext", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                cb(cur, "kqv_merged_cont", il);
-
-                ggml_build_forward_expand(gf, cur);
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
-                cb(cur, "kqv_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, inpSA);
-            cb(cur, "cross_inp", il);
-
-            struct ggml_tensor * inpCA = cur;
-
-            // norm
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    model.layers[il].attn_norm_cross, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm_cross", il);
-
-            // cross-attention
-            {
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq_cross, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk_cross, embd_enc);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv_cross, embd_enc);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
-
-                struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
-
-                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                cb(kq, "kq", il);
-
-                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
-                cb(kq, "kq_soft_max_ext", il);
-
-                struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
-                cb(v, "v", il);
-
-                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
-                cb(kqv, "kqv", il);
-
-                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                cb(kqv_merged, "kqv_merged", il);
-
-                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
-                cb(cur, "kqv_merged_cont", il);
-
-                ggml_build_forward_expand(gf, cur);
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo_cross, cur);
-                cb(cur, "kqv_out", il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-                inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                // T5 uses relu, flan-T5 uses gelu-gated
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        model.layers[il].ffn_gate, NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
-                        model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
-                        cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        cb(cur, "result_embd", -1);
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
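-    // The decoder block above follows the T5 formulation (sketched here informally):
-    //   self-attention:  out = V * softmax(K^T Q + B)
-    // where B is the learned relative-position bias from attn_rel_b (shared from layer 0
-    // when a layer has none) and no 1/sqrt(d) scaling is applied, hence the 1.0f scale
-    // passed to ggml_soft_max_ext. Cross-attention uses the same pattern, but takes K/V
-    // from the encoder output (embd_enc) and is masked with KQ_mask_cross instead of
-    // carrying a position bias.
-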
-    struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
-                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
-            }
-
-            // add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        model.layers[il].ffn_norm_b,
-                        LLM_NORM, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-                cb(cur, "ffn_out", il);
-            }
-
-            inpL = ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
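-    // In build_jais() the QKV projection is fused: a single matmul against wqkv produces a
-    // [n_embd + 2*n_embd_gqa, n_tokens] tensor that is then split with strided views, e.g.
-    // (sketch, reusing the names from the code above):
-    //
-    //   Q = ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0);
-    //   K = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], cur->nb[0]* n_embd);
-    //   V = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], cur->nb[0]*(n_embd + n_embd_gqa));
-    //
-    // also note the attention scale: 1.0f/n_embd_head rather than the usual 1/sqrt(n_embd_head).
-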
-    struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                struct ggml_tensor * Qcur = nullptr;
-                struct ggml_tensor * Kcur = nullptr;
-                struct ggml_tensor * Vcur = nullptr;
-                if (model.layers[il].wqkv == nullptr) {
-                    Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                    if (model.layers[il].bq) {
-                        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    }
-                    Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                    if (model.layers[il].bk) {
-                        Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    }
-                    Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                    if (model.layers[il].bv) {
-                        Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    }
-                } else {
-                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                    cb(cur, "wqkv", il);
-                    if (model.layers[il].bqkv) {
-                        cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                        cb(cur, "bqkv", il);
-                    }
-                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-                }
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-                //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur_rope", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur_rope", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            // Add the input
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // FF
-            {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm,
-                        NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = llm_build_ffn(ctx0, lctx, cur,
-                        model.layers[il].ffn_up,   NULL, NULL,
-                        NULL,                      NULL, NULL,
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
-                cb(cur, "ffn_out", il);
-
-            }
-
-            inpL = ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
-        }
-
-        cur = llm_build_norm(ctx0, inpL, hparams,
-                model.output_norm,
-                NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
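-    // build_chatglm() supports both weight layouts: separate wq/wk/wv (with optional biases)
-    // or a fused wqkv that is sliced into Q/K/V views. Q and K are RoPE'd before the KV step.
-    // The FFN is gated SwiGLU in sequential form (LLM_FFN_SWIGLU, LLM_FFN_SEQ): no ffn_gate
-    // tensor is passed because the up projection carries both halves and is split inside
-    // llm_build_ffn, i.e. roughly silu(first half) * second half.
-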
-    struct ggml_cgraph * build_nemotron() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        //GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm,
-                    model.layers[il].attn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm,
-                    model.layers[il].ffn_norm_b,
-                    LLM_NORM, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                    NULL,                      NULL,                        NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                    NULL,
-                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, model.output_norm_b,
-                LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
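-    // Nemotron deviates from the llama template in two ways that matter above: it uses classic
-    // LayerNorm with biases (LLM_NORM) instead of RMSNorm, and its FFN has no gate and uses a
-    // squared-ReLU activation (LLM_FFN_RELU_SQR), i.e. roughly:
-    //
-    //   ffn(x) = W_down * relu(W_up * x + b_up)^2 + b_down
-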
-    struct ggml_cgraph * build_exaone() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "attn_norm", il);
-
-            // self-attention
-            {
-                // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = build_rope_factors(il);
-
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
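-    // build_exaone() is a llama-style block: RMSNorm -> self-attention with RoPE (the rope
-    // frequency factors from build_rope_factors(il) may be nullptr, in which case plain RoPE
-    // is applied) -> residual add -> RMSNorm -> SiLU-gated FFN -> residual add, with the
-    // control vector (lctx.cvec) applied to each layer output.
-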
-    ggml_cgraph * build_rwkv6() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // Token shift state dimensions should be 2 * n_embd
-        GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
-
-        const int64_t n_seqs = ubatch.n_seqs;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_tokens = ubatch.n_tokens;
-        GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
-        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-        inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            const llama_layer * layer = &model.layers[il];
-
-            // (ab)using the KV cache to store the states
-            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.k_l[il], state_copy, state_mask,
-                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
-            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.v_l[il], state_copy, state_mask,
-                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
-
-            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
-            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs);
-
-            struct ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
-            struct ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
-
-            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM, cb, il);
-            struct ggml_tensor * x_prev = ggml_concat(
-                ctx0,
-                att_shift,
-                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
-                1
-            );
-
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
-            ggml_build_forward_expand(gf, cur);
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    wkv_states,
-                    ggml_view_1d(
-                        ctx0,
-                        kv_self.v_l[il],
-                        hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
-                    )
-                )
-            );
-
-            struct ggml_tensor * x_norm_ffn = llm_build_norm(ctx0, cur, hparams, layer->attn_norm_2, layer->attn_norm_2_b, LLM_NORM, cb, il);
-            x_prev = ggml_concat(
-                ctx0,
-                ffn_shift,
-                ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
-                1
-            );
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
-            ggml_build_forward_expand(gf, cur);
-
-            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
-            struct ggml_tensor * last_norm_ffn = ggml_view_3d(ctx0, x_norm_ffn, n_embd, 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_ffn));
-
-            token_shift = ggml_concat(ctx0, last_norm_att, last_norm_ffn, 1);
-
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0),
-                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
-                )
-            );
-
-            if (hparams.rescale_every_n_layers != 0 && (il + 1) % hparams.rescale_every_n_layers == 0) {
-                cur = ggml_scale(ctx0, cur, 0.5F);
-            }
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
-        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
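-    // In build_rwkv6() the KV cache is reused as recurrent state storage: k_l[il] holds the two
-    // token-shift rows per sequence (attention shift + FFN shift, which is why n_embd_k_s() is
-    // asserted to be 2*n_embd), and v_l[il] holds the WKV state. Both are gathered per sequence
-    // with llm_build_copy_mask_state and written back through ggml_cpy into 1d views of the
-    // cache at offset kv_head. When rescale_every_n_layers is set, hidden states are halved
-    // every N layers to keep activations in range.
-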
-    // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
-    ggml_cgraph * build_rwkv6qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
-
-        const int64_t n_seqs = ubatch.n_seqs;
-        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-        const int64_t n_tokens = ubatch.n_tokens;
-        GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(ubatch.equal_seqs);
-        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-        struct ggml_tensor * state_copy = build_inp_s_copy();
-        struct ggml_tensor * state_mask = build_inp_s_mask();
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        for (int il = 0; il < n_layer; ++il) {
-            const llama_layer * layer = &model.layers[il];
-
-            // (ab)using the KV cache to store the states
-            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.k_l[il], state_copy, state_mask,
-                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
-            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
-                    gf, kv_self.v_l[il], state_copy, state_mask,
-                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
-
-            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
-            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
-
-            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
-            struct ggml_tensor * x_prev = ggml_concat(
-                ctx0,
-                token_shift,
-                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
-                1
-            );
-
-            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
-                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
-                )
-            );
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
-            ggml_build_forward_expand(gf, ffn_inp);
-            ggml_build_forward_expand(
-                gf,
-                ggml_cpy(
-                    ctx0,
-                    wkv_states,
-                    ggml_view_1d(
-                        ctx0,
-                        kv_self.v_l[il],
-                        hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
-                    )
-                )
-            );
-
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
-        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-
-        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
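-    // build_rwkv6qwen2() is the hybrid variant: the time-mix (attention) part is RWKV6 with a
-    // single token-shift row per sequence (n_embd_k_s() == n_embd), while the FFN half of each
-    // block is Qwen2-style, i.e. RMSNorm followed by a SiLU-gated feed-forward.
-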
-    // ref: https://github.com/facebookresearch/chameleon
-    // based on the original build_llama() function, changes:
-    //   * qk-norm
-    //   * swin-norm
-    //   * removed bias
-    //   * removed MoE
-    struct ggml_cgraph * build_chameleon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        // mutable variable, needed during the last layer of the computation to skip unused tokens
-        int32_t n_tokens = this->n_tokens;
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
-
-        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * inpSA = inpL;
-
-            // norm
-            if (hparams.swin_norm) {
-                cur = inpL;
-            } else {
-                cur = llm_build_norm(ctx0, inpL, hparams,
-                    model.layers[il].attn_norm, NULL,
-                    LLM_NORM_RMS, cb, il);
-                cb(cur, "attn_norm", il);
-            }
-
-            // self-attention
-            {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-
-                if (model.layers[il].attn_q_norm) {
-                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
-                                ggml_element_size(Qcur) * n_embd_head,
-                                ggml_element_size(Qcur) * n_embd_head * n_head,
-                                0);
-                    cb(Qcur, "Qcur", il);
-
-                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
-                                model.layers[il].attn_q_norm,
-                                model.layers[il].attn_q_norm_b,
-                                LLM_NORM, cb, il);
-                    cb(Qcur, "Qcur", il);
-                }
-
-                if (model.layers[il].attn_k_norm) {
-                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
-                                ggml_element_size(Kcur) * n_embd_head,
-                                ggml_element_size(Kcur) * n_embd_head * n_head_kv,
-                                0);
-                    cb(Kcur, "Kcur", il);
-
-                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
-                               model.layers[il].attn_k_norm,
-                               model.layers[il].attn_k_norm_b,
-                               LLM_NORM, cb, il);
-                    cb(Kcur, "Kcur", il);
-                }
-
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Qcur, "Qcur", il);
-
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow
-                );
-                cb(Kcur, "Kcur", il);
-
-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, nullptr,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
-
-                if (hparams.swin_norm) {
-                    cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].attn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                }
-            }
-
-            if (il == n_layer - 1) {
-                // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                n_tokens = n_outputs;
-                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
-                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-            }
-
-            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-            cb(ffn_inp, "ffn_inp", il);
-
-            // feed-forward network
-            if (!hparams.swin_norm) {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-            }
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-
-            if (hparams.swin_norm) {
-                cur = llm_build_norm(ctx0, cur, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-            }
-
-            cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "ffn_out", il);
-
-            cur = lctx.cvec.apply_to(ctx0, cur, il);
-            cb(cur, "l_out", il);
-
-            // input for next layer
-            inpL = cur;
-        }
-
-        cur = inpL;
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm, NULL,
-                LLM_NORM_RMS, cb, -1);
-        cb(cur, "result_norm", -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-        cb(cur, "result_output_with_img_logits", -1);
-
-        // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.
-        // Needs to be removed once image outputs are supported.
-        int img_token_end_idx = 8196;
-        int img_token_start_idx = 4;
-        int num_img_tokens = img_token_end_idx - img_token_start_idx;
-        // creates a 1d tensor of size num_img_tokens with all values set to -FLT_MAX,
-        // which ensures that text token logits are always larger than image token logits
-        struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens);
-        img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX);
-        cb(img_logits, "img_logits", -1);
-        cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
-        cb(cur, "result_output", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-
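-    // Two chameleon-specific details above: (1) with hparams.swin_norm the RMSNorm is applied
-    // after the attention and FFN blocks instead of before them, and Q/K optionally get their
-    // own norms (attn_q_norm / attn_k_norm); (2) image-token logits are suppressed by writing a
-    // tensor of -FLT_MAX values over the id range [img_token_start_idx, img_token_end_idx) with
-    // ggml_set_1d, so text tokens always win at sampling time (see the TODO above).
-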
-    struct ggml_cgraph * build_wavtokenizer_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
-
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
-
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
-
-        cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
-
-        cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1);
-        cur = ggml_add(ctx0, cur, model.conv1d_b);
-
-        // posnet
-        for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
-            const auto & layer = model.layers[il].posnet;
-
-            inpL = cur;
-
-            switch (il) {
-                case 0:
-                case 1:
-                case 3:
-                case 4:
-                    {
-                        cur = llm_build_norm(ctx0, cur, hparams,
-                                layer.norm1,
-                                layer.norm1_b,
-                                LLM_NORM_GROUP, cb, 0);
-
-                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
-
-                        cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1);
-                        cur = ggml_add(ctx0, cur, layer.conv1_b);
-
-                        cur = llm_build_norm(ctx0, cur, hparams,
-                                layer.norm2,
-                                layer.norm2_b,
-                                LLM_NORM_GROUP, cb, 0);
-
-                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
-
-                        cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1);
-                        cur = ggml_add(ctx0, cur, layer.conv2_b);
-
-                        cur = ggml_add(ctx0, cur, inpL);
-                    } break;
-                case 2:
-                    {
-                        cur = llm_build_norm(ctx0, cur, hparams,
-                                layer.attn_norm,
-                                layer.attn_norm_b,
-                                LLM_NORM_GROUP, cb, 0);
-
-                        struct ggml_tensor * q;
-                        struct ggml_tensor * k;
-                        struct ggml_tensor * v;
-
-                        q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1);
-                        k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1);
-                        v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1);
-
-                        q = ggml_add(ctx0, q, layer.attn_q_b);
-                        k = ggml_add(ctx0, k, layer.attn_k_b);
-                        v = ggml_add(ctx0, v, layer.attn_v_b);
-
-                        q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
-                        k = ggml_cont(ctx0, ggml_transpose(ctx0, k));
-
-                        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-
-                        kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);
-
-                        cur = ggml_mul_mat(ctx0, kq, v);
-
-                        cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
-                        cur = ggml_add(ctx0, cur, layer.attn_o_b);
-
-                        cur = ggml_add(ctx0, cur, inpL);
-                    } break;
-                case 5:
-                    {
-                        cur = llm_build_norm(ctx0, cur, hparams,
-                                layer.norm,
-                                layer.norm_b,
-                                LLM_NORM_GROUP, cb, 0);
-                    } break;
-                default: GGML_ABORT("unknown posnet layer");
-            }
-        }
-
-        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.tok_norm,
-                model.tok_norm_b,
-                LLM_NORM, cb, -1);
-
-        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-        inpL = cur;
-
-        // convnext
-        for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
-            const auto & layer = model.layers[il].convnext;
-
-            cur = inpL;
-
-            cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
-            cur = ggml_add(ctx0, cur, layer.dw_b);
-
-            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-            cur = llm_build_norm(ctx0, cur, hparams,
-                    layer.norm,
-                    layer.norm_b,
-                    LLM_NORM, cb, -1);
-
-            cur = llm_build_ffn(ctx0, lctx, cur,
-                    layer.pw1, layer.pw1_b, NULL,
-                    NULL,      NULL,        NULL,
-                    layer.pw2, layer.pw2_b, NULL,
-                    NULL,
-                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-
-            cur = ggml_mul(ctx0, cur, layer.gamma);
-
-            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-            inpL = ggml_add(ctx0, cur, inpL);
-        }
-
-        cur = inpL;
-
-        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
-
-        cur = llm_build_norm(ctx0, cur, hparams,
-                model.output_norm,
-                model.output_norm_b,
-                LLM_NORM, cb, -1);
-
-        // lm_head
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
-
-        cur = ggml_add(ctx0, cur, model.output_b);
-        cb(cur, "result_embd", -1);
-
-        ggml_build_forward_expand(gf, cur);
-
-        return gf;
-    }
-};
-
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
-    llama_ubatch dummy = {};
-    dummy.equal_seqs = true;
-
-    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct ggml_cgraph * result = llm.build_defrag(ids);
-
-    llm.free();
-
-    return result;
-}
-
-static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
-    llama_ubatch dummy = {};
-    dummy.equal_seqs = true;
-
-    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
-
-    struct llm_build_context llm(lctx, dummy, cb, false);
-
-    llm.init();
-
-    struct ggml_cgraph * result = llm.build_k_shift();
-
-    llm.free();
-
-    return result;
-}
-
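-// Both helpers above build a single-purpose graph (defrag / K-shift) outside of decoding: they
-// construct a temporary llm_build_context over an empty ubatch with equal_seqs set and a no-op
-// callback, call init(), build the one graph they need, and free the context again.
-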
-static struct ggml_cgraph * llama_build_graph(
-         llama_context & lctx,
-    const llama_ubatch & ubatch,
-                  bool   worst_case) {
-    const auto & model = lctx.model;
-
-    // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
-    llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
-        if (il >= 0) {
-            ggml_format_name(cur, "%s-%d", name, il);
-        } else {
-            ggml_set_name(cur, name);
-        }
-
-        if (!lctx.cparams.offload_kqv) {
-            if (strcmp(name, "kqv_merged_cont") == 0) {
-                // all nodes between the KV store and the attention output are run on the CPU
-                ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, lctx.backend_cpu);
-            }
-        }
-
-        // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
-        // FIXME: fix in ggml_backend_sched
-        const bool full_offload = lctx.model.params.n_gpu_layers > (int) lctx.model.hparams.n_layer;
-        if (ubatch.n_tokens < 32 || full_offload) {
-            if (il != -1 && strcmp(name, "norm") == 0) {
-                const auto & dev_layer = lctx.model.dev_layer(il);
-                for (auto & backend : lctx.backends) {
-                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
-                        if (ggml_backend_supports_op(backend.get(), cur)) {
-                            ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get());
-                        }
-                    }
-                }
-            }
-        }
-    };
-
-    struct ggml_cgraph * result = NULL;
-
-    struct llm_build_context llm(lctx, ubatch, cb, worst_case);
-
-    llm.init();
-
-    switch (model.arch) {
-        case LLM_ARCH_LLAMA:
-        case LLM_ARCH_MINICPM:
-        case LLM_ARCH_GRANITE:
-        case LLM_ARCH_GRANITE_MOE:
-            {
-                result = llm.build_llama();
-            } break;
-        case LLM_ARCH_DECI:
-            {
-                result = llm.build_deci();
-            } break;
-        case LLM_ARCH_BAICHUAN:
-            {
-                result = llm.build_baichuan();
-            } break;
-        case LLM_ARCH_FALCON:
-            {
-                result = llm.build_falcon();
-            } break;
-        case LLM_ARCH_GROK:
-            {
-                result = llm.build_grok();
-            } break;
-        case LLM_ARCH_STARCODER:
-            {
-                result = llm.build_starcoder();
-            } break;
-        case LLM_ARCH_REFACT:
-            {
-                result = llm.build_refact();
-            } break;
-        case LLM_ARCH_BERT:
-        case LLM_ARCH_JINA_BERT_V2:
-        case LLM_ARCH_NOMIC_BERT:
-            {
-                result = llm.build_bert();
-            } break;
-        case LLM_ARCH_BLOOM:
-            {
-                result = llm.build_bloom();
-            } break;
-        case LLM_ARCH_MPT:
-            {
-                result = llm.build_mpt();
-            } break;
-        case LLM_ARCH_STABLELM:
-            {
-                result = llm.build_stablelm();
-            } break;
-        case LLM_ARCH_QWEN:
-            {
-                result = llm.build_qwen();
-            } break;
-        case LLM_ARCH_QWEN2:
-            {
-                result = llm.build_qwen2();
-            } break;
-        case LLM_ARCH_QWEN2VL:
-            {
-                lctx.n_pos_per_token = 4;
-                result = llm.build_qwen2vl();
-            } break;
-        case LLM_ARCH_QWEN2MOE:
-            {
-                result = llm.build_qwen2moe();
-            } break;
-        case LLM_ARCH_PHI2:
-            {
-                result = llm.build_phi2();
-            } break;
-        case LLM_ARCH_PHI3:
-        case LLM_ARCH_PHIMOE:
-            {
-                result = llm.build_phi3();
-            } break;
-        case LLM_ARCH_PLAMO:
-            {
-                result = llm.build_plamo();
-            } break;
-        case LLM_ARCH_GPT2:
-            {
-                result = llm.build_gpt2();
-            } break;
-        case LLM_ARCH_CODESHELL:
-            {
-                result = llm.build_codeshell();
-            } break;
-        case LLM_ARCH_ORION:
-            {
-                result = llm.build_orion();
-            } break;
-        case LLM_ARCH_INTERNLM2:
-            {
-                result = llm.build_internlm2();
-            } break;
-        case LLM_ARCH_MINICPM3:
-            {
-                result = llm.build_minicpm3();
-            } break;
-        case LLM_ARCH_GEMMA:
-            {
-                result = llm.build_gemma();
-            } break;
-        case LLM_ARCH_GEMMA2:
-            {
-                result = llm.build_gemma2();
-            } break;
-        case LLM_ARCH_GEMMA3:
-            {
-                result = llm.build_gemma3();
-            } break;
-        case LLM_ARCH_STARCODER2:
-            {
-                result = llm.build_starcoder2();
-            } break;
-        case LLM_ARCH_MAMBA:
-            {
-                result = llm.build_mamba();
-            } break;
-        case LLM_ARCH_XVERSE:
-            {
-                result = llm.build_xverse();
-            } break;
-        case LLM_ARCH_COMMAND_R:
-            {
-                result = llm.build_command_r();
-            } break;
-        case LLM_ARCH_COHERE2:
-            {
-                result = llm.build_cohere2();
-            } break;
-        case LLM_ARCH_DBRX:
-            {
-                result = llm.build_dbrx();
-            } break;
-        case LLM_ARCH_OLMO:
-            {
-                result = llm.build_olmo();
-            } break;
-        case LLM_ARCH_OLMO2:
-            {
-                result = llm.build_olmo2();
-            } break;
-        case LLM_ARCH_OLMOE:
-            {
-                result = llm.build_olmoe();
-            } break;
-        case LLM_ARCH_OPENELM:
-            {
-                result = llm.build_openelm();
-            } break;
-        case LLM_ARCH_GPTNEOX:
-            {
-                result = llm.build_gptneox();
-            } break;
-        case LLM_ARCH_ARCTIC:
-            {
-                result = llm.build_arctic();
-            } break;
-        case LLM_ARCH_DEEPSEEK:
-            {
-                result = llm.build_deepseek();
-            } break;
-        case LLM_ARCH_DEEPSEEK2:
-            {
-                result = llm.build_deepseek2();
-            } break;
-        case LLM_ARCH_CHATGLM:
-            {
-                result = llm.build_chatglm();
-            } break;
-        case LLM_ARCH_BITNET:
-            {
-                result = llm.build_bitnet();
-            } break;
-        case LLM_ARCH_T5:
-            {
-                if (lctx.is_encoding) {
-                    result = llm.build_t5_enc();
-                } else {
-                    result = llm.build_t5_dec();
-                }
-            } break;
-        case LLM_ARCH_T5ENCODER:
-            {
-                result = llm.build_t5_enc();
-            } break;
-        case LLM_ARCH_JAIS:
-            {
-                result = llm.build_jais();
-            } break;
-        case LLM_ARCH_NEMOTRON:
-            {
-                result = llm.build_nemotron();
-            } break;
-        case LLM_ARCH_EXAONE:
-            {
-                result = llm.build_exaone();
-            } break;
-        case LLM_ARCH_RWKV6:
-            {
-                result = llm.build_rwkv6();
-            } break;
-        case LLM_ARCH_RWKV6QWEN2:
-            {
-                result = llm.build_rwkv6qwen2();
-            } break;
-        case LLM_ARCH_CHAMELEON:
-            {
-                result = llm.build_chameleon();
-            } break;
-        case LLM_ARCH_WAVTOKENIZER_DEC:
-            {
-                result = llm.build_wavtokenizer_dec();
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
-
-    // add on pooling layer
-    if (lctx.cparams.embeddings) {
-        result = llm.append_pooling(result);
-    }
-
-    llm.free();
-
-    return result;
-}
-
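-// Notes on llama_build_graph(): the callback names every tensor "<name>-<il>" (or just "<name>"
-// for il == -1); when offload_kqv is disabled, "kqv_merged_cont" is pinned to the CPU backend so
-// the nodes between the KV store and the attention output stay on the CPU; for small batches
-// (fewer than 32 tokens) or full offload, per-layer "norm" tensors are pinned to that layer's
-// device (when a backend for it supports the op) to avoid extra transfers between backends.
-// After the per-architecture switch, a pooling head is appended when cparams.embeddings is set.
-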
-// returns the result of ggml_backend_sched_graph_compute_async execution
-static enum ggml_status llama_graph_compute(
-          llama_context & lctx,
-            ggml_cgraph * gf,
-                    int   n_threads,
-        ggml_threadpool * threadpool) {
-    if (lctx.backend_cpu != nullptr) {
-        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(lctx.backend_cpu));
-        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
-        set_threadpool_fn(lctx.backend_cpu, threadpool);
-    }
-
-    // set the number of threads for all the backends
-    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
-        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
-    }
-
-    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
-    if (status != GGML_STATUS_SUCCESS) {
-        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
-    }
-
-    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
-
-    return status;
-}
-
-static int llama_prepare_sbatch(
-        llama_context     & lctx,
-        const llama_batch & batch,
-        uint32_t          & n_outputs) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const uint32_t n_tokens_all = batch.n_tokens;
-    const  int64_t n_embd       = hparams.n_embd;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-    lctx.n_queued_tokens += n_tokens_all;
-    lctx.embd_seq.clear();
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch.logits[i] != 0;
-        }
-    } else if (lctx.logits_all || embd_pooled) {
-        n_outputs = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs = 1;
-    }
-
-    lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !lctx.kv_self.recurrent,
-        /* logits_all   */ n_outputs == n_tokens_all);
-
-    // reserve output buffer
-    if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
-        return -2;
-    }
-
-    return 0;
-}
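
Illustrative note (not part of this diff): the output-counting rule above can be restated as a tiny standalone helper; the name count_outputs and its parameters are invented for the example.

    #include <cstdint>

    // mirrors the rule in llama_prepare_sbatch: explicit per-token flags win, unless
    // pooled embeddings or logits_all force one output per token; otherwise keep only
    // the last token's output
    static uint32_t count_outputs(uint32_t n_tokens, const int8_t * logits, bool logits_all, bool embd_pooled) {
        if (logits && !embd_pooled) {
            uint32_t n = 0;
            for (uint32_t i = 0; i < n_tokens; ++i) {
                n += logits[i] != 0;
            }
            return n;
        }
        return (logits_all || embd_pooled) ? n_tokens : 1;
    }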
-
-static int llama_prepare_ubatch(
-        llama_context          & lctx,
-        llama_kv_slot_restorer & kv_slot_restorer,
-        llama_ubatch           & ubatch,
-        const uint32_t           n_outputs,
-        const uint32_t           n_tokens_all) {
-    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
-
-    auto       & kv_self = lctx.kv_self;
-    const auto & cparams = lctx.cparams;
-    const auto & hparams = lctx.model.hparams;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-    if (lctx.kv_self.recurrent) {
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
-            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
-        } else {
-            // recurrent model architectures are easier to implement
-            // with equal-length sequences
-            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
-        }
-    } else {
-        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
-    }
-
-    // count the outputs in this u_batch
-    {
-        int32_t n_outputs_new = 0;
-
-        if (n_outputs == n_tokens_all) {
-            n_outputs_new = ubatch.n_tokens;
-        } else {
-            GGML_ASSERT(ubatch.output);
-            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-                n_outputs_new += int32_t(ubatch.output[i] != 0);
-            }
-        }
-
-        // needs to happen before the graph is built
-        lctx.n_outputs = n_outputs_new;
-    }
-
-    // non-causal masks do not use the KV cache
-    if (hparams.causal_attn) {
-        llama_kv_cache_update(&lctx);
-
-        // if we have enough unused cells before the current head ->
-        //   better to start searching from the beginning of the cache, hoping to fill it
-        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
-            kv_self.head = 0;
-        }
-
-        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-        if (!slot) {
-            return 1;
-        }
-        kv_slot_restorer.save(slot);
-
-        if (!kv_self.recurrent) {
-            // a heuristic, to avoid attending to the full cache if it is not yet utilized
-            // after enough generations, the benefit from this heuristic disappears
-            // if we start defragmenting the cache, the benefit from this will be more important
-            const uint32_t pad = llama_kv_cache_get_padding(cparams);
-            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-            //kv_self.n = llama_kv_cache_cell_max(kv_self);
-        }
-    }
-
-    return 0;
-}
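
Illustrative note (not part of this diff): a small self-contained program showing how the kv_self.n heuristic above behaves, assuming a padding of 32 and a cache of 4096 cells (both values chosen only for the example).

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <initializer_list>

    #define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) // same formula as GGML_PAD

    int main() {
        const uint32_t pad = 32, kv_size = 4096;
        for (uint32_t cell_max : {0u, 1u, 33u, 4090u}) {
            const uint32_t n = std::min(kv_size, std::max(pad, (uint32_t) PAD(cell_max, pad)));
            printf("cell_max = %4u -> kv_self.n = %4u\n", cell_max, n); // prints 32, 32, 64, 4096
        }
        return 0;
    }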
-
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - inp_batch: batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_impl(
-         llama_context & lctx,
-           llama_batch   inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporarily allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-    const llama_batch & batch = batch_allocr.batch;
-
-    const auto & model   = lctx.model;
-    const auto & vocab   = model.vocab;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    {
-        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
-        if (ret != 0) {
-            return ret;
-        }
-    }
-
-    while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        {
-            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
-            if (ret != 0) {
-                return ret;
-            }
-        }
-
-        const int         n_threads  = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool   : lctx.threadpool_batch;
-
-        GGML_ASSERT(n_threads > 0);
-
-        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
-
-        ggml_backend_sched_reset(lctx.sched.get());
-        ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
-
-        // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = ggml_graph_node(gf, -1);
-        struct ggml_tensor * embd = ggml_graph_node(gf, -2);
-
-        if (lctx.n_outputs == 0) {
-            // no output
-            res  = nullptr;
-            embd = nullptr;
-        } else if (cparams.embeddings) {
-            res  = nullptr; // do not extract logits for embedding case
-            embd = nullptr;
-            for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
-                if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-                    embd = ggml_graph_node(gf, i);
-                    break;
-                }
-            }
-            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
-        } else {
-            embd = nullptr; // do not extract embeddings when not needed
-            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
-        }
-
-        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
-
-        ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
-
-        llama_set_inputs(lctx, ubatch);
-
-        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-        if (compute_status != GGML_STATUS_SUCCESS) {
-            kv_slot_restorer.restore(kv_self);
-            switch (compute_status) {
-                case GGML_STATUS_ABORTED:
-                    return 2;
-                case GGML_STATUS_ALLOC_FAILED:
-                    return -2;
-                case GGML_STATUS_FAILED:
-                default:
-                    return -3;
-            }
-        }
-
-        // update the kv ring buffer
-        {
-            kv_self.head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self.head >= kv_self.size) {
-                kv_self.head = 0;
-            }
-        }
-
-        // plot the computation graph in dot format (for debugging purposes)
-        //if (n_past%100 == 0) {
-        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
-        //}
-
-        // extract logits
-        if (res) {
-            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), res);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(lctx.logits != nullptr);
-
-            float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
-            const int32_t n_outputs_new = lctx.n_outputs;
-
-            if (n_outputs_new) {
-                GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
-                GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
-                ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
-            }
-        }
-
-        // extract embeddings
-        if (embd) {
-            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
-            GGML_ASSERT(backend_embd != nullptr);
-
-            switch (cparams.pooling_type) {
-                case LLAMA_POOLING_TYPE_NONE:
-                    {
-                        // extract token embeddings
-                        GGML_ASSERT(lctx.embd != nullptr);
-                        float * embd_out = lctx.embd + n_outputs_prev*n_embd;
-                        const int32_t n_outputs_new = lctx.n_outputs;
-
-                        if (n_outputs_new) {
-                            GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
-                            GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_MEAN:
-                case LLAMA_POOLING_TYPE_CLS:
-                case LLAMA_POOLING_TYPE_LAST:
-                    {
-                        // extract sequence embeddings (cleared before processing each batch)
-                        auto & embd_seq_out = lctx.embd_seq;
-
-                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_RANK:
-                    {
-                        // extract the rerank score - a single float per sequence
-                        auto & embd_seq_out = lctx.embd_seq;
-
-                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
-                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(1);
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                    {
-                        GGML_ABORT("unknown pooling type");
-                    }
-            }
-        }
-        n_outputs_prev += lctx.n_outputs;
-    }
-
-    // set output mappings
-    {
-        bool sorted_output = true;
-
-        GGML_ASSERT(lctx.sbatch.out_ids.size() == n_outputs);
-
-        for (size_t i = 0; i < n_outputs; ++i) {
-            size_t out_id = lctx.sbatch.out_ids[i];
-            lctx.output_ids[out_id] = i;
-            if (out_id != i) {
-                sorted_output = false;
-            }
-        }
-
-        if (sorted_output) {
-            lctx.sbatch.out_ids.clear();
-        }
-    }
-
-    // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    lctx.n_outputs = n_outputs;
-
-    // wait for the computation to finish (automatically done when obtaining the model output)
-    //llama_synchronize(&lctx);
-
-    // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-        // - do not defrag small contexts (i.e. < 2048 tokens)
-        // - count the padding towards the number of used tokens
-        const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + llama_kv_cache_get_padding(cparams))/float(kv_self.n)) : 0.0f;
-
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-            llama_kv_cache_defrag(kv_self);
-        }
-    }
-
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(lctx.sched.get());
-
-    return 0;
-}
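
Illustrative note (not part of this diff): a hedged sketch of how a caller might act on llama_decode's return codes, following the contract documented above (0 = success, positive = warning with the KV cache restored or cleaned, negative = error); ctx and batch are assumed to already exist.

    #include "llama.h"
    #include <cstdio>

    static bool decode_checked(llama_context * ctx, llama_batch batch) {
        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return true;  // success
        }
        if (ret > 0) {
            // warning, e.g. no free KV cache slot for the batch: free some cache
            // space (or shrink the batch) and try again
            fprintf(stderr, "llama_decode warning: %d\n", ret);
            return false;
        }
        // negative: invalid input, allocation failure or compute failure
        fprintf(stderr, "llama_decode error: %d\n", ret);
        return false;
    }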
-
-// encode a batch of tokens by evaluating the encoder part of the transformer
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_encode_impl(
-         llama_context & lctx,
-           llama_batch   inp_batch) {
-
-    lctx.is_encoding = true;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporarily allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens = batch.n_tokens;
-
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-
-    lctx.n_queued_tokens += n_tokens;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-
-    const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
-
-    // reserve output buffer
-    if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
-        return -2;
-    }
-
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        lctx.output_ids[i] = i;
-    }
-
-    lctx.inp_embd_enc = NULL;
-    lctx.n_outputs = n_tokens;
-
-    int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-    ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
-
-    GGML_ASSERT(n_threads > 0);
-
-    ggml_backend_sched_reset(lctx.sched.get());
-    ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-
-    ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
-
-    // the output embeddings after the final encoder normalization
-    struct ggml_tensor * embd = nullptr;
-
-    // there are two cases here
-    if (llama_model_has_decoder(&lctx.model)) {
-        // first case is an encoder-decoder T5 model where embeddings are passed to decoder
-        embd = ggml_graph_node(gf, -1);
-        GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_norm tensor");
-    } else {
-        // second case is an encoder-only T5 model
-        if (cparams.embeddings) {
-            // only output embeddings if required
-            embd = ggml_graph_node(gf, -1);
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = ggml_graph_node(gf, -2);
-            }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
-        }
-    }
-
-    ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
-
-    llama_set_inputs(lctx, ubatch);
-
-    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
-    }
-
-    // extract embeddings
-    if (embd) {
-        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
-        GGML_ASSERT(backend_embd != nullptr);
-
-        if (llama_model_has_decoder(&lctx.model)) {
-            lctx.embd_enc.resize(n_tokens*n_embd);
-            float * embd_out = lctx.embd_enc.data();
-
-            ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
-            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
-
-            // remember the sequence ids used during the encoding - needed for cross attention later
-            lctx.seq_ids_enc.resize(n_tokens);
-            for (uint32_t i = 0; i < n_tokens; i++) {
-                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
-                    llama_seq_id seq_id = ubatch.seq_id[i][s];
-                    lctx.seq_ids_enc[i].insert(seq_id);
-                }
-            }
-        } else {
-            GGML_ASSERT(lctx.embd != nullptr);
-
-            switch (cparams.pooling_type) {
-                case LLAMA_POOLING_TYPE_NONE:
-                    {
-                        // extract token embeddings
-                        GGML_ASSERT(lctx.embd != nullptr);
-                        float * embd_out = lctx.embd;
-
-                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) lctx.embd_size);
-                        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
-                    } break;
-                case LLAMA_POOLING_TYPE_MEAN:
-                case LLAMA_POOLING_TYPE_CLS:
-                case LLAMA_POOLING_TYPE_LAST:
-                    {
-                        // extract sequence embeddings
-                        auto & embd_seq_out = lctx.embd_seq;
-                        embd_seq_out.clear();
-
-                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
-
-                        for (uint32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_RANK:
-                    {
-                        // TODO: this likely should be the same logic as in llama_decode_impl, but better to
-                        //       wait for an encoder model that requires this pooling type in order to test it
-                        //       https://github.com/ggerganov/llama.cpp/pull/9510
-                        GGML_ABORT("RANK pooling not implemented yet");
-                    }
-                case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                    {
-                        GGML_ABORT("unknown pooling type");
-                    }
-            }
-        }
-    }
-
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(lctx.sched.get());
-
-    return 0;
-}
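
Illustrative note (not part of this diff): a hedged usage sketch for the encoder path on an encoder-only model created with embeddings enabled and a pooling type set; llama_model_n_embd and llama_get_embeddings_seq are assumed from the public API, and model/ctx/batch are assumed to be set up already.

    #include "llama.h"
    #include <cstdio>

    static void print_pooled_embedding(const llama_model * model, llama_context * ctx, llama_batch batch) {
        if (llama_encode(ctx, batch) != 0) {
            fprintf(stderr, "llama_encode failed\n");
            return;
        }
        const int32_t n_embd = llama_model_n_embd(model);
        const float * embd   = llama_get_embeddings_seq(ctx, /*seq_id =*/ 0); // one pooled vector per sequence
        if (embd) {
            printf("n_embd = %d, embd[0] = %f\n", n_embd, embd[0]);
        }
    }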
-
-// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
-    auto & kv_self = lctx.kv_self;
-
-    const auto & hparams = lctx.model.hparams;
-
-    const uint32_t n_layer = hparams.n_layer;
-
-    const uint32_t n_kv   = llama_kv_cache_cell_max(kv_self);
-    const uint32_t n_used = kv_self.used;
-
-    assert(n_used <= n_kv);
-
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see build_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = model.max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (lctx.model.max_nodes() - 2*n_layer)/(6*n_layer);
-
-    // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    std::vector<uint32_t> ids(n_kv, n_kv);
-
-    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        const auto & cell0 = kv_self.cells[i0];
-
-        if (!cell0.is_empty()) {
-            ids[i0] = i0;
-
-            continue;
-        }
-
-        // found a hole - fill it with data from the end of the cache
-
-        uint32_t nh = 1;
-
-        // determine the size of the hole
-        while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
-            nh++;
-        }
-
-        uint32_t nf = 0;
-        uint32_t is = n_kv - 1;
-
-        // starting from the end, find nh non-empty cells
-        for (; is > i0; --is) {
-            const auto & cell1 = kv_self.cells[is];
-
-            if (cell1.is_empty() || ids[is] != n_kv) {
-                continue;
-            }
-
-            // non-empty cell which is not yet moved
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        // this can only happen if `n_used` is not accurate, which would be a bug
-        GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
-
-        nf = 0;
-
-        uint32_t i1 = is;
-
-        // are we moving a contiguous block of memory?
-        bool cont = false;
-
-        // should we stop searching for the next move?
-        bool stop = false;
-
-        // go back and move the nf cells to the hole
-        for (; i1 < n_kv; ++i1) {
-            auto & cell1 = kv_self.cells[i1];
-
-            if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
-                cont = false;
-                continue;
-            }
-
-            // this cell goes to (i0 + nf)
-            ids[i1] = i0 + nf;
-
-            // move the cell meta data
-            kv_self.cells[i0 + nf] = cell1;
-
-            // clear the old cell and move the head there
-            cell1 = llama_kv_cell();
-            kv_self.head = n_used;
-
-            if (!cont) {
-                n_moves++;
-                cont = true;
-            }
-
-            nf++;
-
-            if (nf == nh) {
-                break;
-            }
-        }
-
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
-        //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
-
-        i0 += nh - 1;
-    }
-
-    if (n_moves == 0) {
-        return;
-    }
-
-    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
-
-    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-    //
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = kv_self.size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size     = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-        const size_t v_size    = ggml_row_size (kv_self.v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os =  i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os =  i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    // ggml_graph defrag
-
-    ggml_backend_sched_reset(lctx.sched.get());
-
-    ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
-
-    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
-#endif
-
-    //const int64_t t_end = ggml_time_us();
-
-    //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
-}
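
Illustrative note (not part of this diff): the planning step above, stripped of the max_moves cap, the n_used bound and the contiguity bookkeeping, reduces to a simple compaction plan. The helper below (plan_defrag and occupied are both invented for the example) computes the same kind of ids mapping, where ids[i] == n_kv means "cell i is not moved".

    #include <cstdint>
    #include <vector>

    static std::vector<uint32_t> plan_defrag(const std::vector<bool> & occupied) {
        const uint32_t n_kv = (uint32_t) occupied.size();
        std::vector<uint32_t> ids(n_kv, n_kv);

        uint32_t dst = 0;                        // next hole to fill, scanning from the start
        for (uint32_t src = n_kv; src-- > 0; ) { // move candidates, scanning from the end
            if (!occupied[src]) {
                continue;
            }
            while (dst < src && occupied[dst]) {
                ids[dst] = dst;                  // already packed, stays in place
                dst++;
            }
            if (dst >= src) {
                ids[src] = src;                  // everything below is packed - done
                break;
            }
            ids[src] = dst++;                    // fill the hole with this cell
        }
        return ids;
    }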
-
-static void llama_kv_cache_update_impl(struct llama_context & lctx) {
-    bool need_reserve = false;
-
-    if (lctx.kv_self.has_shift) {
-        if (!llama_kv_cache_can_shift(&lctx)) {
-            GGML_ABORT("The current context does not support K-shift");
-        }
-
-        // apply K-shift if needed
-        if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(lctx.sched.get());
-
-            ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
-
-            ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
-
-            llama_set_k_shift(lctx);
-
-            llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
-
-            need_reserve = true;
-        }
-
-        {
-            auto & kv_self = lctx.kv_self;
-
-            kv_self.has_shift = false;
-
-            for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (lctx.kv_self.do_defrag) {
-        llama_kv_cache_defrag_impl(lctx);
-
-        need_reserve = true;
-
-        lctx.kv_self.do_defrag = false;
-    }
-
-    // reserve a worst case graph again
-    if (need_reserve) {
-        // TODO: extract to a function
-        // build worst-case graph
-        uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-        uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between the token and embedding input graphs
-        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
-
-        // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched.get());
-        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-        }
-    }
-}
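
Illustrative note (not part of this diff): a hedged sketch of the typical caller-side pattern that ends up in the update path above - a "context shift" that discards the oldest tokens of a sequence and shifts the rest back, using the public llama_kv_cache_* functions defined further down in this file (this PR also introduces llama_kv_self_* equivalents).

    #include "llama.h"

    static void context_shift(llama_context * ctx, llama_pos n_past, llama_pos n_discard) {
        if (!llama_kv_cache_can_shift(ctx)) {
            return; // some setups cannot apply a K-shift
        }
        llama_kv_cache_seq_rm (ctx, 0, 0,         n_discard);          // drop the oldest tokens of sequence 0
        llama_kv_cache_seq_add(ctx, 0, n_discard, n_past, -n_discard); // shift the remaining positions back
        llama_kv_cache_update (ctx);                                   // apply the queued K-shift now
    }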
-
-int32_t llama_set_adapter_lora(
-            struct llama_context * ctx,
-            struct llama_adapter_lora * adapter,
-            float scale) {
-    ctx->lora[adapter] = scale;
-    return 0;
-}
-
-int32_t llama_rm_adapter_lora(
-            struct llama_context * ctx,
-            struct llama_adapter_lora * adapter) {
-    auto pos = ctx->lora.find(adapter);
-    if (pos != ctx->lora.end()) {
-        ctx->lora.erase(pos);
-        return 0;
-    }
-
-    return -1;
-}
-
-void llama_clear_adapter_lora(struct llama_context * ctx) {
-    ctx->lora.clear();
-}
-
-int32_t llama_apply_adapter_cvec(
-        struct llama_context * ctx,
-                 const float * data,
-                      size_t   len,
-                     int32_t   n_embd,
-                     int32_t   il_start,
-                     int32_t   il_end) {
-    return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
-}
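
Illustrative note (not part of this diff): a hedged sketch of the adapter API above; llama_adapter_lora_init is assumed to be the public loader for a LoRA .gguf file (it is defined outside this hunk), and the path/scale values are placeholders.

    #include "llama.h"
    #include <cstdio>

    static llama_adapter_lora * enable_lora(llama_model * model, llama_context * ctx, const char * path_lora, float scale) {
        llama_adapter_lora * adapter = llama_adapter_lora_init(model, path_lora);
        if (!adapter) {
            fprintf(stderr, "failed to load LoRA adapter: %s\n", path_lora);
            return nullptr;
        }
        llama_set_adapter_lora(ctx, adapter, scale); // attach to this context with the given scale
        return adapter;                              // later: llama_rm_adapter_lora() or llama_clear_adapter_lora()
    }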
-
 //
 // interface implementation
 //
 
-struct llama_context_params llama_context_default_params() {
-    struct llama_context_params result = {
-        /*.n_ctx                       =*/ 512,
-        /*.n_batch                     =*/ 2048,
-        /*.n_ubatch                    =*/ 512,
-        /*.n_seq_max                   =*/ 1,
-        /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
-        /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
-        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
-        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
-        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
-        /*.rope_freq_base              =*/ 0.0f,
-        /*.rope_freq_scale             =*/ 0.0f,
-        /*.yarn_ext_factor             =*/ -1.0f,
-        /*.yarn_attn_factor            =*/ 1.0f,
-        /*.yarn_beta_fast              =*/ 32.0f,
-        /*.yarn_beta_slow              =*/ 1.0f,
-        /*.yarn_orig_ctx               =*/ 0,
-        /*.defrag_thold                =*/ -1.0f,
-        /*.cb_eval                     =*/ nullptr,
-        /*.cb_eval_user_data           =*/ nullptr,
-        /*.type_k                      =*/ GGML_TYPE_F16,
-        /*.type_v                      =*/ GGML_TYPE_F16,
-        /*.logits_all                  =*/ false,
-        /*.embeddings                  =*/ false,
-        /*.offload_kqv                 =*/ true,
-        /*.flash_attn                  =*/ false,
-        /*.no_perf                     =*/ true,
-        /*.abort_callback              =*/ nullptr,
-        /*.abort_callback_data         =*/ nullptr,
-    };
-
-    return result;
-}
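
Illustrative note (not part of this diff): the defaults above are meant to be taken as a starting point and selectively overridden; a minimal sketch (make_ctx_params is an invented helper name):

    #include "llama.h"

    static llama_context_params make_ctx_params(uint32_t n_ctx, int32_t n_threads) {
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx           = n_ctx;      // 0 would mean "use the model's training context"
        cparams.n_threads       = n_threads;  // single-token (generation) threads
        cparams.n_threads_batch = n_threads;  // batch (prompt processing) threads
        return cparams;
    }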
-
 struct llama_sampler_chain_params llama_sampler_chain_default_params() {
     struct llama_sampler_chain_params result = {
         /*.no_perf                     =*/ true,
@@ -9571,6 +82,57 @@ int64_t llama_time_us(void) {
     return ggml_time_us();
 }
 
+// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
+static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
+    // loading time will be recalculated after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = 0;
+    time_meas tm(model.t_load_us);
+
+    model.t_start_us = tm.t_start_us;
+
+    try {
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+
+        ml.print_info();
+
+        model.hparams.vocab_only = params.vocab_only;
+
+        try {
+            model.load_arch(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
+        }
+        try {
+            model.load_hparams(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
+        }
+        try {
+            model.load_vocab(ml);
+        } catch(const std::exception & e) {
+            throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
+        }
+
+        model.load_stats(ml);
+        model.print_info();
+
+        if (params.vocab_only) {
+            LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
+            return 0;
+        }
+
+        if (!model.load_tensors(ml)) {
+            return -2;
+        }
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
+        return -1;
+    }
+
+    return 0;
+}
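
Illustrative note (not part of this diff): a hedged sketch of how the 0 / -1 / -2 contract documented above might be consumed; the real mapping lives in llama_model_load_from_file_impl, which is outside this hunk, and try_load is an invented name.

    #include <cstdio>
    #include <string>
    #include <vector>

    static bool try_load(const std::string & fname, std::vector<std::string> & splits,
                         llama_model & model, llama_model_params & params) {
        const int status = llama_model_load(fname, splits, model, params);
        if (status == -2) {
            fprintf(stderr, "model load cancelled by the progress callback\n");
        } else if (status == -1) {
            fprintf(stderr, "error while loading the model\n");
        }
        return status == 0;
    }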
+
 static struct llama_model * llama_model_load_from_file_impl(
         const std::string & path_model,
         std::vector<std::string> & splits,
@@ -9691,460 +253,6 @@ struct llama_model * llama_model_load_from_splits(
     return llama_model_load_from_file_impl(splits.front(), splits, params);
 }
 
-struct llama_context * llama_init_from_model(
-                 struct llama_model * model,
-        struct llama_context_params   params) {
-
-    if (!model) {
-        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
-        return nullptr;
-    }
-
-    if (params.n_batch == 0 && params.n_ubatch == 0) {
-        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
-        return nullptr;
-    }
-
-    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
-        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
-        return nullptr;
-    }
-
-    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
-    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
-        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
-        return nullptr;
-    }
-
-    llama_context * ctx = new llama_context(*model);
-
-    const auto & hparams = model->hparams;
-    auto       & cparams = ctx->cparams;
-
-    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
-    cparams.n_threads        = params.n_threads;
-    cparams.n_threads_batch  = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
-    cparams.defrag_thold     = params.defrag_thold;
-    cparams.embeddings       = params.embeddings;
-    cparams.offload_kqv      = params.offload_kqv;
-    cparams.flash_attn       = params.flash_attn;
-    cparams.no_perf          = params.no_perf;
-    cparams.pooling_type     = params.pooling_type;
-
-    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
-    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
-    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
-
-    // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
-
-    // with causal attention, the batch size is limited by the context size
-    cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
-
-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
-
-    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
-
-    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
-                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
-                                                              hparams.n_ctx_train;
-
-    cparams.cb_eval           = params.cb_eval;
-    cparams.cb_eval_user_data = params.cb_eval_user_data;
-
-    auto rope_scaling_type = params.rope_scaling_type;
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
-        rope_scaling_type = hparams.rope_scaling_type_train;
-    }
-
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
-        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
-    }
-
-    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
-        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
-    }
-
-    cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
-        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
-            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-        } else {
-            cparams.pooling_type = hparams.pooling_type;
-        }
-    }
-
-    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
-    } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
-    }
-
-    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
-
-    LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
-    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
-    LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
-    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
-
-    if (n_ctx_per_seq < hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
-    }
-
-    if (n_ctx_per_seq > hparams.n_ctx_train) {
-        LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
-                __func__, n_ctx_per_seq, hparams.n_ctx_train);
-    }
-
-    ctx->logits_all = params.logits_all;
-
-    // build worst-case graph for encoder if a model contains encoder
-    ctx->is_encoding = llama_model_has_encoder(model);
-
-    uint32_t kv_size = cparams.n_ctx;
-    ggml_type type_k = params.type_k;
-    ggml_type type_v = params.type_v;
-
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(model)) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-    }
-
-    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
-    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
-
-    if (!hparams.vocab_only) {
-        // GPU backends
-        for (auto * dev : model->devices) {
-            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.emplace_back(backend);
-        }
-
-        // add ACCEL backends (such as BLAS)
-        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
-                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.emplace_back(backend);
-            }
-        }
-
-        // add CPU backend
-        ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
-        if (ctx->backend_cpu == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-        ctx->backends.emplace_back(ctx->backend_cpu);
-
-        // create a list of the set_n_threads functions in the backends
-        for (auto & backend : ctx->backends) {
-            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
-            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
-            if (reg) {
-                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
-                if (ggml_backend_set_n_threads_fn) {
-                    ctx->set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
-                }
-            }
-        }
-
-        llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
-
-        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, ctx->cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
-            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-
-        {
-            size_t memory_size_k = 0;
-            size_t memory_size_v = 0;
-
-            for (auto & k : ctx->kv_self.k_l) {
-                memory_size_k += ggml_nbytes(k);
-            }
-
-            for (auto & v : ctx->kv_self.v_l) {
-                memory_size_v += ggml_nbytes(v);
-            }
-
-            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                      (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
-
-        // graph outputs buffer
-        {
-            // resized during inference when a batch uses more outputs
-            if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
-                LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-
-            LLAMA_LOG_INFO("%s: %10s  output buffer size = %8.2f MiB\n", __func__,
-                    ggml_backend_buffer_name(ctx->buf_output.get()),
-                    ggml_backend_buffer_get_size(ctx->buf_output.get()) / 1024.0 / 1024.0);
-        }
-
-        // scheduler and compute buffers
-        {
-            // buffer types used for the compute buffer of each backend
-            std::vector<ggml_backend_buffer_type_t> backend_buft;
-            std::vector<ggml_backend_t> backend_ptrs;
-            for (auto & backend : ctx->backends) {
-                auto * buft = ggml_backend_get_default_buffer_type(backend.get());
-                auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
-                if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
-                    // use the host buffer of the first device CPU for faster transfer of the intermediate state
-                    auto * dev = model->devices[0];
-                    auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
-                    if (host_buft) {
-                        buft = host_buft;
-                    }
-                }
-                backend_buft.push_back(buft);
-                backend_ptrs.push_back(backend.get());
-            }
-
-            const size_t max_nodes = model->max_nodes();
-
-            // buffer used to store the computation graph and the tensor meta data
-            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
-
-            // TODO: move these checks to ggml_backend_sched
-            // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
-            bool pipeline_parallel =
-                model->n_devices() > 1 &&
-                model->params.n_gpu_layers > (int)model->hparams.n_layer &&
-                model->params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
-                params.offload_kqv;
-
-            // pipeline parallelism requires support for async compute and events in all devices
-            if (pipeline_parallel) {
-                for (auto & backend : ctx->backends) {
-                    auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
-                    if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
-                        // ignore CPU backend
-                        continue;
-                    }
-                    auto * dev = ggml_backend_get_device(backend.get());
-                    ggml_backend_dev_props props;
-                    ggml_backend_dev_get_props(dev, &props);
-                    if (!props.caps.async || !props.caps.events) {
-                        // device does not support async compute or events
-                        pipeline_parallel = false;
-                        break;
-                    }
-                }
-            }
-
-            ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
-
-            if (pipeline_parallel) {
-                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
-            }
-
-            // initialize scheduler with the worst-case graph
-            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-            llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between the token and embedding input graphs
-
-            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
-
-            // reserve pp graph first so that buffers are only allocated once
-            ggml_backend_sched_reserve(ctx->sched.get(), gf_pp);
-            int n_splits_pp = ggml_backend_sched_get_n_splits(ctx->sched.get());
-            int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
-
-            // reserve with tg graph to get the number of splits and nodes
-            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true);
-            ggml_backend_sched_reserve(ctx->sched.get(), gf_tg);
-            int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get());
-            int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
-
-            // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-            gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
-            if (!ggml_backend_sched_reserve(ctx->sched.get(), gf_pp)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-
-            for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-                ggml_backend_t backend = backend_ptrs[i];
-                ggml_backend_buffer_type_t buft = backend_buft[i];
-                size_t size = ggml_backend_sched_get_buffer_size(ctx->sched.get(), backend);
-                if (size > 1) {
-                    LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                            ggml_backend_buft_name(buft),
-                            size / 1024.0 / 1024.0);
-                }
-            }
-
-            if (n_nodes_pp == n_nodes_tg) {
-                LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
-            } else {
-                LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
-            }
-            if (n_splits_pp == n_splits_tg) {
-                LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
-            } else {
-                LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
-            }
-        }
-    }
-
-    return ctx;
-}
-
-struct llama_context * llama_new_context_with_model(
-                 struct llama_model * model,
-        struct llama_context_params   params) {
-    return llama_init_from_model(model, params);
-}
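
Illustrative note (not part of this diff): a hedged end-to-end sketch of the init entry points above - load a model, create a context, and tear both down; the path and n_ctx are placeholders, and llama_init_from_model is used in place of the older llama_new_context_with_model alias.

    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_model_params   mparams = llama_model_default_params();
        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx = 4096;

        llama_model * model = llama_model_load_from_file("model.gguf", mparams);
        if (!model) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        llama_context * ctx = llama_init_from_model(model, cparams);
        if (!ctx) {
            llama_model_free(model);
            return 1;
        }

        // ... llama_decode() / llama_encode() ...

        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }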
-
-//
-// kv cache
-//
-
-// TODO: tmp bridges below until `struct llama_kv_cache` is exposed through the public API
-
-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
-    return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
-}
-
-void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
-    llama_kv_cache_view_update(view, ctx->kv_self);
-}
-
-int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return llama_get_kv_cache_token_count(ctx->kv_self);
-}
-
-int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
-    return llama_get_kv_cache_used_cells(ctx->kv_self);
-}
-
-void llama_kv_cache_clear(struct llama_context * ctx) {
-    llama_kv_cache_clear(ctx->kv_self);
-}
-
-bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
-}
-
-void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    if (seq_id_src == seq_id_dst) {
-        return;
-    }
-    llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
-}
-
-void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
-        return;
-    }
-
-    llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
-}
-
-void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    if (d == 1) {
-        return;
-    }
-
-    llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
-}
-
-void llama_kv_cache_defrag(struct llama_context * ctx) {
-    llama_kv_cache_defrag(ctx->kv_self);
-}
-
-void llama_kv_cache_update(struct llama_context * ctx) {
-    llama_kv_cache_update_impl(*ctx);
-}
-
-bool llama_kv_cache_can_shift(struct llama_context * ctx) {
-    return llama_kv_cache_can_shift(ctx->kv_self);
-}
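
Illustrative note (not part of this diff): a hedged sketch of typical multi-sequence bookkeeping built on the bridges above - share one prompt across several sequences, then keep only the chosen one; n_parallel and winner are placeholders.

    #include "llama.h"

    static void fork_and_keep(llama_context * ctx, int32_t n_parallel, llama_seq_id winner) {
        // copy sequence 0 (the shared prompt) into sequences 1..n_parallel-1;
        // p0 = p1 = -1 selects the whole sequence
        for (llama_seq_id s = 1; s < n_parallel; ++s) {
            llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
        }

        // ... decode the parallel sequences ...

        llama_kv_cache_seq_keep(ctx, winner); // drop every other sequence from the cache
    }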
-
-///
-
-int32_t llama_encode(
-        struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_encode_impl(*ctx, batch);
-    if (ret != 0) {
-        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
-int32_t llama_decode(
-        struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_decode_impl(*ctx, batch);
-    if (ret != 0) {
-        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
 //
 // chat templates
 //
@@ -10212,7 +320,6 @@ const char * llama_print_system_info(void) {
     static std::string s;
     s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
 
-
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
         auto * reg = ggml_backend_reg_get(i);
         auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
@@ -10231,43 +338,3 @@ const char * llama_print_system_info(void) {
 
     return s.c_str();
 }
-
-//
-// perf
-//
-
-struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
-    struct llama_perf_context_data data = {};
-
-    if (ctx == nullptr) {
-        return data;
-    }
-
-    data.t_start_ms  = 1e-3 * ctx->t_start_us;
-    data.t_load_ms   = 1e-3 * ctx->t_load_us;
-    data.t_p_eval_ms = 1e-3 * ctx->t_p_eval_us;
-    data.t_eval_ms   = 1e-3 * ctx->t_eval_us;
-    data.n_p_eval    = std::max(1, ctx->n_p_eval);
-    data.n_eval      = std::max(1, ctx->n_eval);
-
-    return data;
-}
-
-void llama_perf_context_print(const struct llama_context * ctx) {
-    const auto data = llama_perf_context(ctx);
-
-    const double t_end_ms = 1e-3 * ggml_time_us();
-
-    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
-    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
-    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
-    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
-}
-
-void llama_perf_context_reset(struct llama_context * ctx) {
-    ctx->t_start_us  = ggml_time_us();
-    ctx->t_eval_us   = ctx->n_eval = 0;
-    ctx->t_p_eval_us = ctx->n_p_eval = 0;
-}
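
Illustrative note (not part of this diff): a hedged sketch of consuming the counters above without printing the full report - deriving prompt-processing and generation throughput from llama_perf_context_data (report_speed is an invented name).

    #include "llama.h"
    #include <cstdio>

    static void report_speed(const llama_context * ctx) {
        const llama_perf_context_data d = llama_perf_context(ctx);

        const double pp_tps = d.t_p_eval_ms > 0.0 ? 1e3 * d.n_p_eval / d.t_p_eval_ms : 0.0; // prompt processing
        const double tg_tps = d.t_eval_ms   > 0.0 ? 1e3 * d.n_eval   / d.t_eval_ms   : 0.0; // token generation

        printf("prompt: %.2f t/s, generation: %.2f t/s\n", pp_tps, tg_tps);
    }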