]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llm : add llm_build_context (#3881)
authorGeorgi Gerganov <redacted>
Wed, 1 Nov 2023 18:11:02 +0000 (20:11 +0200)
committerGitHub <redacted>
Wed, 1 Nov 2023 18:11:02 +0000 (20:11 +0200)
* llm : add llm_build_context

* llm : deduce norm eps based on type + explict max_alibi_bias, clamp_kqv

* llm : restore the non-graph llm_build_ functional API

ggml-ci

* llm : cleanup + comments

llama.cpp

index 42cedc7a1cd59275a7ad67a91b83c25715b86210..d0c4ef1015182881d98edbd818ed879c8ef8d637 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -3090,6 +3090,10 @@ static bool llama_model_load(
     return true;
 }
 
+//
+// llm_build
+//
+
 using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
 
 enum llm_rope_type {
@@ -3098,17 +3102,35 @@ enum llm_rope_type {
     LLM_ROPE_GLM,
 };
 
+enum llm_ffn_op_type {
+    LLM_FFN_SILU,
+    LLM_FFN_GELU,
+    LLM_FFN_RELU,
+    LLM_FFN_RELU_SQR,
+};
+
+enum llm_ffn_gate_type {
+    LLM_FFN_SEQ,
+    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
+};
+
+enum llm_norm_type {
+    LLM_NORM,
+    LLM_NORM_RMS,
+};
+
 static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
+        const llama_hparams & hparams,
           const llama_batch & batch,
          struct ggml_tensor * tok_embd,
-                    int64_t   n_embd,
-                    int32_t   n_tokens,
          const llm_build_cb & cb) {
+    const int64_t n_embd = hparams.n_embd;
+
     struct ggml_tensor * inpL;
 
     if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
         cb(inp_tokens, "inp_tokens", -1);
 
         inpL = ggml_get_rows(ctx, tok_embd, inp_tokens);
@@ -3117,7 +3139,7 @@ static struct ggml_tensor * llm_build_inp_embd(
         GGML_ASSERT(false && "not implemented");
 #endif
 
-        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
     }
 
     return inpL;
@@ -3126,28 +3148,21 @@ static struct ggml_tensor * llm_build_inp_embd(
 // Persimmon: n_rot = n_embd_head/2
 // Other:     n_rot = n_embd_head
 static void llm_build_k_shift(
-        const llama_context & lctx,
-        struct ggml_context * ctx,
-         struct ggml_cgraph * graph,
-                    int64_t   n_rot,
-              llm_rope_type   type,
-         const llm_build_cb & cb) {
-    const auto & model   = lctx.model;
-    const auto & kv_self = lctx.kv_self;
-    const auto & cparams = lctx.cparams;
-
-    const auto & hparams = model.hparams;
-
+      struct ggml_context * ctx,
+      const llama_hparams & hparams,
+     const llama_kv_cache & kv,
+       struct ggml_cgraph * graph,
+            llm_rope_type   type,
+                  int64_t   n_ctx,
+                  int64_t   n_rot,
+                  float     freq_base,
+                  float     freq_scale,
+       const llm_build_cb & cb) {
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
     const int64_t n_embd_head = hparams.n_embd_head();
 
-    const int64_t n_ctx = lctx.cparams.n_ctx;
-
-    const float freq_base  = cparams.rope_freq_base;
-    const float freq_scale = cparams.rope_freq_scale;
-
     GGML_ASSERT(n_embd_head % n_rot == 0);
 
     struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
@@ -3165,11 +3180,11 @@ static void llm_build_k_shift(
         struct ggml_tensor * tmp =
             // we rotate only the first n_rot dimensions
             ggml_rope_custom_inplace(ctx,
-                    ggml_view_3d(ctx, kv_self.k,
+                    ggml_view_3d(ctx, kv.k,
                         n_rot, n_head_kv, n_ctx,
-                        ggml_element_size(kv_self.k)*n_embd_head,
-                        ggml_element_size(kv_self.k)*n_embd_gqa,
-                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
+                        ggml_element_size(kv.k)*n_embd_head,
+                        ggml_element_size(kv.k)*n_embd_gqa,
+                        ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
                     K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
         cb(tmp, "K_shifted", il);
         ggml_build_forward_expand(graph, tmp);
@@ -3177,22 +3192,17 @@ static void llm_build_k_shift(
 }
 
 static void llm_build_kv_store(
-        const llama_context & lctx,
         struct ggml_context * ctx,
+        const llama_hparams & hparams,
+       const llama_kv_cache & kv,
          struct ggml_cgraph * graph,
          struct ggml_tensor * k_cur,
          struct ggml_tensor * v_cur,
+                    int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   kv_head,
          const llm_build_cb & cb,
                     int64_t   il) {
-    const auto & model   = lctx.model;
-    const auto & kv_self = lctx.kv_self;
-    const auto & cparams = lctx.cparams;
-
-    const auto & hparams = model.hparams;
-
-    const int64_t n_ctx      = cparams.n_ctx;
     const int64_t n_embd_gqa = hparams.n_embd_gqa();
 
     // compute the transposed [n_tokens, n_embd] V matrix
@@ -3200,13 +3210,13 @@ static void llm_build_kv_store(
     //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
     cb(v_cur_t, "v_cur_t", il);
 
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv_self.k, n_tokens*n_embd_gqa,
-                (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k, n_tokens*n_embd_gqa,
+            (ggml_element_size(kv.k)*n_embd_gqa)*(il*n_ctx + kv_head));
     cb(k_cache_view, "k_cache_view", il);
 
-    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv_self.v, n_tokens, n_embd_gqa,
-            (   n_ctx)*ggml_element_size(kv_self.v),
-            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v, n_tokens, n_embd_gqa,
+            (   n_ctx)*ggml_element_size(kv.v),
+            (il*n_ctx)*ggml_element_size(kv.v)*n_embd_gqa + kv_head*ggml_element_size(kv.v));
     cb(v_cache_view, "v_cache_view", il);
 
     // important: storing RoPE-ed version of K in the KV cache!
@@ -3214,23 +3224,18 @@ static void llm_build_kv_store(
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
 }
 
-enum llm_norm_type {
-    LLM_NORM,
-    LLM_NORM_RMS,
-};
-
 static struct ggml_tensor * llm_build_norm(
         struct ggml_context * ctx,
          struct ggml_tensor * cur,
+        const llama_hparams & hparams,
          struct ggml_tensor * mw,
          struct ggml_tensor * mb,
               llm_norm_type   type,
-                      float   eps,
          const llm_build_cb & cb,
                         int   il) {
     switch (type) {
-        case LLM_NORM:     cur = ggml_norm    (ctx, cur, eps); break;
-        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, eps); break;
+        case LLM_NORM:     cur = ggml_norm    (ctx, cur, hparams.f_norm_eps);     break;
+        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break;
     }
 
     if (mw || mb) {
@@ -3251,18 +3256,6 @@ static struct ggml_tensor * llm_build_norm(
     return cur;
 }
 
-enum llm_ffn_op_type {
-    LLM_FFN_SILU,
-    LLM_FFN_GELU,
-    LLM_FFN_RELU,
-    LLM_FFN_RELU_SQR,
-};
-
-enum llm_ffn_gate_type {
-    LLM_FFN_SEQ,
-    LLM_FFN_PAR, // ffn_gate is parallel to ffn_up
-};
-
 static struct ggml_tensor * llm_build_ffn(
         struct ggml_context * ctx,
          struct ggml_tensor * cur,
@@ -3351,26 +3344,21 @@ static struct ggml_tensor * llm_build_ffn(
 
 // if max_alibi_bias > 0 then apply ALiBi
 static struct ggml_tensor * llm_build_kqv(
-        const llama_context & lctx,
         struct ggml_context * ctx,
          struct ggml_tensor * cur,
+        const llama_hparams & hparams,
+       const llama_kv_cache & kv,
          struct ggml_tensor * wo,
          struct ggml_tensor * wo_b,
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_scale,
          struct ggml_tensor * kq_mask,
+                    int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   n_kv,
-                      float   alibi_bias_max,
+                    float     max_alibi_bias,
          const llm_build_cb & cb,
-         int   il) {
-    const auto & model   = lctx.model;
-    const auto & kv_self = lctx.kv_self;
-    const auto & cparams = lctx.cparams;
-
-    const auto & hparams = model.hparams;
-
-    const int64_t n_ctx       = cparams.n_ctx;
+                    int       il) {
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
@@ -3381,11 +3369,11 @@ static struct ggml_tensor * llm_build_kqv(
     cb(q, "q", il);
 
     struct ggml_tensor * k =
-        ggml_view_3d(ctx, kv_self.k,
+        ggml_view_3d(ctx, kv.k,
                 n_embd_head, n_kv, n_head_kv,
-                ggml_element_size(kv_self.k)*n_embd_gqa,
-                ggml_element_size(kv_self.k)*n_embd_head,
-                ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+                ggml_element_size(kv.k)*n_embd_gqa,
+                ggml_element_size(kv.k)*n_embd_head,
+                ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il);
     cb(k, "k", il);
 
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
@@ -3394,11 +3382,11 @@ static struct ggml_tensor * llm_build_kqv(
     kq = ggml_scale(ctx, kq, kq_scale);
     cb(kq, "kq_scaled", il);
 
-    if (alibi_bias_max > 0.0f) {
+    if (max_alibi_bias > 0.0f) {
         // TODO: n_head or n_head_kv
         // TODO: K-shift is likely not working
         // TODO: change to ggml_add
-        kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, alibi_bias_max);
+        kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
         cb(kq, "kq_scaled_alibi", il);
     }
 
@@ -3410,11 +3398,11 @@ static struct ggml_tensor * llm_build_kqv(
 
     // split cached v into n_head heads
     struct ggml_tensor * v =
-        ggml_view_3d(ctx, kv_self.v,
+        ggml_view_3d(ctx, kv.v,
                 n_kv, n_embd_head, n_head_kv,
-                ggml_element_size(kv_self.v)*n_ctx,
-                ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
-                ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+                ggml_element_size(kv.v)*n_ctx,
+                ggml_element_size(kv.v)*n_ctx*n_embd_head,
+                ggml_element_size(kv.v)*n_ctx*n_embd_gqa*il);
     cb(v, "v", il);
 
     struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
@@ -3438,1259 +3426,1011 @@ static struct ggml_tensor * llm_build_kqv(
     return cur;
 }
 
-static struct ggml_cgraph * llm_build_llama(
-        llama_context  & lctx,
-    const llama_batch  & batch,
-    const llm_build_cb & cb,
-                  bool   worst_case) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const auto & kv_self = lctx.kv_self;
+struct llm_build_context {
+    const llama_model    & model;
+    const llama_hparams  & hparams;
+    const llama_cparams  & cparams;
+    const llama_batch    & batch;
+    const llama_kv_cache & kv_self;
 
-    GGML_ASSERT(!!kv_self.ctx);
+    const int64_t n_embd;
+    const int64_t n_layer;
+    const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
+    const int64_t n_head;
+    const int64_t n_head_kv;
+    const int64_t n_embd_head;
+    const int64_t n_embd_gqa;
 
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = cparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
+    const float freq_base;
+    const float freq_scale;
+    const float norm_eps;
+    const float norm_rms_eps;
 
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    const int32_t n_tokens;
+    const int32_t n_kv;     // size of KV cache to consider (n_kv <= n_ctx)
+    const int32_t kv_head;  // index of where we store new KV data in the cache
 
-    const float freq_base    = cparams.rope_freq_base;
-    const float freq_scale   = cparams.rope_freq_scale;
-    const float norm_rms_eps = hparams.f_norm_rms_eps;
+    const bool do_rope_shift;
 
-    const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
+    const llm_build_cb & cb;
 
-    const bool do_rope_shift = worst_case || kv_self.has_shift;
+    llama_buffer & buf_compute;
 
-    //printf("n_kv = %d\n", n_kv);
+    struct ggml_context * ctx0 = nullptr;
 
-    auto & buf_compute = lctx.buf_compute;
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc   =*/ true,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-    struct ggml_tensor * cur;
-    struct ggml_tensor * inpL;
-
-    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
-    cb(inpL, "inp_embd", -1);
-
-    // inp_pos - contains the positions
-    struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-    cb(inp_pos, "inp_pos", -1);
-
-    // KQ_scale
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    cb(KQ_scale, "KQ_scale", -1);
-
-    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    cb(KQ_mask, "KQ_mask", -1);
+    // TODO: consider making the entire interface noexcept
+    llm_build_context(
+        llama_context  & lctx,
+    const llama_batch  & batch,
+    const llm_build_cb & cb,
+                  bool   worst_case) :
+        model         (lctx.model),
+        hparams       (model.hparams),
+        cparams       (lctx.cparams),
+        batch         (batch),
+        kv_self       (lctx.kv_self),
+        n_embd        (hparams.n_embd),
+        n_layer       (hparams.n_layer),
+        n_ctx         (cparams.n_ctx),
+        n_head        (hparams.n_head),
+        n_head_kv     (hparams.n_head_kv),
+        n_embd_head   (hparams.n_embd_head()),
+        n_embd_gqa    (hparams.n_embd_gqa()),
+        freq_base     (cparams.rope_freq_base),
+        freq_scale    (cparams.rope_freq_scale),
+        norm_eps      (hparams.f_norm_eps),
+        norm_rms_eps  (hparams.f_norm_rms_eps),
+        n_tokens      (batch.n_tokens),
+        n_kv          (worst_case ? n_ctx            : kv_self.n),
+        kv_head       (worst_case ? n_ctx - n_tokens : kv_self.head),
+        do_rope_shift (worst_case || kv_self.has_shift),
+        cb            (cb),
+        buf_compute   (lctx.buf_compute) {
+            GGML_ASSERT(!!kv_self.ctx);
+
+            // all initializations should be done in init()
+        }
+
+    void init() {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_compute.size,
+            /*.mem_buffer =*/ buf_compute.data,
+            /*.no_alloc   =*/ true,
+        };
 
-    // shift the entire K-cache if needed
-    if (do_rope_shift) {
-        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
+        ctx0 = ggml_init(params);
     }
 
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * inpSA = inpL;
-
-        // norm
-        cur = llm_build_norm(ctx0, inpL,
-                model.layers[il].attn_norm, NULL,
-                LLM_NORM_RMS, norm_rms_eps, cb, il);
-        cb(cur, "attn_norm", il);
-
-        // self-attention
-        {
-            // compute Q and K and RoPE them
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-            cb(Qcur, "Qcur", il);
-
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-            cb(Kcur, "Kcur", il);
-
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-            cb(Vcur, "Vcur", il);
-
-            Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-            cb(Qcur, "Qcur", il);
-
-            Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-            cb(Kcur, "Kcur", il);
-
-            llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
-
-            cur = llm_build_kqv(lctx, ctx0, cur,
-                    model.layers[il].wo, NULL,
-                    Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
-            cb(cur, "kqv_out", il);
-        }
-
-        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-        cb(ffn_inp, "ffn_inp", il);
-
-        // feed-forward network
-        {
-            cur = llm_build_norm(ctx0, ffn_inp,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, norm_rms_eps, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
+    void free() {
+        if (ctx0) {
+            ggml_free(ctx0);
+            ctx0 = nullptr;
         }
-
-        cur = ggml_add(ctx0, cur, ffn_inp);
-        cb(cur, "l_out", il);
-
-        // input for next layer
-        inpL = cur;
     }
 
-    cur = inpL;
-
-    cur = llm_build_norm(ctx0, cur,
-            model.output_norm, NULL,
-            LLM_NORM_RMS, norm_rms_eps, cb, -1);
-    cb(cur, "result_norm", -1);
-
-    // lm_head
-    cur = ggml_mul_mat(ctx0, model.output, cur);
-    cb(cur, "result_output", -1);
-
-    ggml_build_forward_expand(gf, cur);
-
-    ggml_free(ctx0);
-
-    return gf;
-}
-
-static struct ggml_cgraph * llm_build_baichaun(
-         llama_context & lctx,
-     const llama_batch & batch,
-    const llm_build_cb & cb,
-                  bool   worst_case) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const auto & kv_self = lctx.kv_self;
-
-    GGML_ASSERT(!!kv_self.ctx);
-
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = cparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
-
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-    const float freq_base    = cparams.rope_freq_base;
-    const float freq_scale   = cparams.rope_freq_scale;
-    const float norm_rms_eps = hparams.f_norm_rms_eps;
+    struct ggml_cgraph * build_llama() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
 
-    const bool do_rope_shift = worst_case || kv_self.has_shift;
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-    auto & buf_compute = lctx.buf_compute;
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
 
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc   =*/ true,
-    };
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
 
-    struct ggml_context * ctx0 = ggml_init(params);
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
 
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
 
-    struct ggml_tensor * cur;
-    struct ggml_tensor * inpL;
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+        }
 
-    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
-    cb(inpL, "inp_embd", -1);
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
 
-    // inp_pos - contains the positions
-    struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-    cb(inp_pos, "inp_pos", -1);
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
 
-    // KQ_scale
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    cb(KQ_scale, "KQ_scale", -1);
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
 
-    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    cb(KQ_mask, "KQ_mask", -1);
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
 
-    // shift the entire K-cache if needed
-    if (do_rope_shift) {
-        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
-    }
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
 
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * inpSA = inpL;
+                Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                cb(Qcur, "Qcur", il);
 
-        cur = llm_build_norm(ctx0, inpL,
-                model.layers[il].attn_norm, NULL,
-                LLM_NORM_RMS, norm_rms_eps, cb, il);
-        cb(cur, "attn_norm", il);
+                Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                cb(Kcur, "Kcur", il);
 
-        // self-attention
-        {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-            cb(Qcur, "Qcur", il);
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-            cb(Kcur, "Kcur", il);
+                cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
+                        model.layers[il].wo, NULL,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
 
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-            cb(Vcur, "Vcur", il);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
 
-            switch (model.type) {
-                case MODEL_7B:
-                    Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens),    inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-                    Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-                    break;
-                case MODEL_13B:
-                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
-                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
-                    break;
-                default:
-                    GGML_ASSERT(false);
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
             }
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
 
-            llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
 
-            // apply ALiBi for 13B model
-            const float alibi_bias_max = model.type == MODEL_13B ? 8.0f : -1.0f;
-
-            cur = llm_build_kqv(lctx, ctx0, cur,
-                    model.layers[il].wo, NULL,
-                    Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, alibi_bias_max, cb, il);
-            cb(cur, "kqv_out", il);
+            // input for next layer
+            inpL = cur;
         }
 
-        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-        cb(ffn_inp, "ffn_inp", il);
+        cur = inpL;
 
-        // feed-forward network
-        {
-            cur = llm_build_norm(ctx0, ffn_inp,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, norm_rms_eps, cb, il);
-            cb(cur, "ffn_norm", il);
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-        }
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
-        cur = ggml_add(ctx0, cur, ffn_inp);
-        cb(cur, "l_out", il);
+        ggml_build_forward_expand(gf, cur);
 
-        // input for next layer
-        inpL = cur;
+        return gf;
     }
 
-    cur = inpL;
+    struct ggml_cgraph * build_baichuan() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    cur = llm_build_norm(ctx0, cur,
-            model.output_norm, NULL,
-            LLM_NORM_RMS, norm_rms_eps, cb, -1);
-    cb(cur, "result_norm", -1);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-    // lm_head
-    cur = ggml_mul_mat(ctx0, model.output, cur);
-    cb(cur, "result_output", -1);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
 
-    ggml_build_forward_expand(gf, cur);
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
 
-    ggml_free(ctx0);
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
 
-    return gf;
-}
-
-static struct ggml_cgraph * llm_build_falcon(
-         llama_context & lctx,
-     const llama_batch & batch,
-    const llm_build_cb & cb,
-                  bool   worst_case) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const auto & kv_self = lctx.kv_self;
-
-    GGML_ASSERT(!!kv_self.ctx);
-
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = cparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
-
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
 
-    const float freq_base  = cparams.rope_freq_base;
-    const float freq_scale = cparams.rope_freq_scale;
-    const float norm_eps   = hparams.f_norm_eps;
-
-    const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
-
-    const bool do_rope_shift = worst_case || kv_self.has_shift;
-
-    //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
-    //        kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
-
-    auto & buf_compute = lctx.buf_compute;
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc   =*/ true,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+        }
 
-    struct ggml_tensor * cur;
-    struct ggml_tensor * inpL;
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
 
-    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
-    cb(inpL, "inp_embd", -1);
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
 
-    // inp_pos - contains the positions
-    struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-    cb(inp_pos, "inp_pos", -1);
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
 
-    // KQ_scale
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    cb(KQ_scale, "KQ_scale", -1);
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
 
-    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    cb(KQ_mask, "KQ_mask", -1);
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
 
-    // shift the entire K-cache if needed
-    if (do_rope_shift) {
-        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE_NEOX, cb);
-    }
+                switch (model.type) {
+                    case MODEL_7B:
+                        Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens),    inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                        Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                        break;
+                    case MODEL_13B:
+                        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
+                        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
+                        break;
+                    default:
+                        GGML_ASSERT(false);
+                }
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
 
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * attn_norm;
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-        attn_norm = llm_build_norm(ctx0, inpL,
-                model.layers[il].attn_norm,
-                model.layers[il].attn_norm_b,
-                LLM_NORM, norm_eps, cb, il);
-        cb(attn_norm, "attn_norm", il);
+                // apply ALiBi for 13B model
+                const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
 
-        // self-attention
-        {
-            if (model.layers[il].attn_norm_2) {
-                // Falcon-40B
-                cur = llm_build_norm(ctx0, attn_norm,
-                        model.layers[il].attn_norm_2,
-                        model.layers[il].attn_norm_2_b,
-                        LLM_NORM, norm_eps, cb, il);
-                cb(cur, "attn_norm_2", il);
-            } else {
-                cur = attn_norm;
+                cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
+                        model.layers[il].wo, NULL,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il);
+                cb(cur, "kqv_out", il);
             }
 
-            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
-            cb(cur, "wqkv", il);
-
-            struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-            struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-            struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
 
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-
-            // using mode = 2 for neox mode
-            Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
-            cb(Qcur, "Qcur", il);
-
-            Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
-            cb(Kcur, "Kcur", il);
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
 
-            llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
 
-            cur = llm_build_kqv(lctx, ctx0, attn_norm,
-                    model.layers[il].wo, NULL,
-                    Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
-            cb(cur, "kqv_out", il);
+            // input for next layer
+            inpL = cur;
         }
 
-        struct ggml_tensor * ffn_inp = cur;
+        cur = inpL;
 
-        // feed forward
-        {
-            cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result
-                    model.layers[il].ffn_up,   NULL,
-                    NULL,                      NULL,
-                    model.layers[il].ffn_down, NULL,
-                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-            cb(cur, "ffn_out", il);
-        }
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
 
-        cur = ggml_add(ctx0, cur, ffn_inp);
-        cb(cur, "l_out", il);
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
-        cur = ggml_add(ctx0, cur, inpL);
-        cb(cur, "l_out", il);
+        ggml_build_forward_expand(gf, cur);
 
-        // input for next layer
-        inpL = cur;
+        return gf;
     }
 
-    cur = inpL;
-
-    // norm
-    cur = llm_build_norm(ctx0, cur,
-            model.output_norm,
-            model.output_norm_b,
-            LLM_NORM, norm_eps, cb, -1);
-    cb(cur, "result_norm", -1);
-
-    cur = ggml_mul_mat(ctx0, model.output, cur);
-    cb(cur, "result_output", -1);
+    struct ggml_cgraph * build_falcon() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    ggml_build_forward_expand(gf, cur);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-    ggml_free(ctx0);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
 
-    return gf;
-}
-
-static struct ggml_cgraph * llm_build_starcoder(
-         llama_context & lctx,
-     const llama_batch & batch,
-    const llm_build_cb & cb,
-                  bool   worst_case) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const auto & kv_self = lctx.kv_self;
-
-    GGML_ASSERT(!!kv_self.ctx);
-
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = cparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
-
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-    const float norm_eps = hparams.f_norm_eps;
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
 
-    const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
 
-    auto & buf_compute = lctx.buf_compute;
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
 
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc   =*/ true,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+        }
 
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * attn_norm;
 
-    struct ggml_tensor * cur;
-    struct ggml_tensor * pos;
-    struct ggml_tensor * inpL;
+            attn_norm = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(attn_norm, "attn_norm", il);
 
-    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
-    cb(inpL, "inp_embd", -1);
+            // self-attention
+            {
+                if (model.layers[il].attn_norm_2) {
+                    // Falcon-40B
+                    cur = llm_build_norm(ctx0, attn_norm, hparams,
+                            model.layers[il].attn_norm_2,
+                            model.layers[il].attn_norm_2_b,
+                            LLM_NORM, cb, il);
+                    cb(cur, "attn_norm_2", il);
+                } else {
+                    cur = attn_norm;
+                }
 
-    // inp_pos - contains the positions
-    struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-    cb(inp_pos, "inp_pos", -1);
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
 
-    // KQ_scale
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    cb(KQ_scale, "KQ_scale", -1);
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    cb(KQ_mask, "KQ_mask", -1);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
-    pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
-    cb(pos, "pos_embd", -1);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-    inpL = ggml_add(ctx0, inpL, pos);
-    cb(inpL, "inpL", -1);
+                // using mode = 2 for neox mode
+                Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+                cb(Qcur, "Qcur", il);
 
-    for (int il = 0; il < n_layer; ++il) {
-        cur = llm_build_norm(ctx0, inpL,
-                model.layers[il].attn_norm,
-                model.layers[il].attn_norm_b,
-                LLM_NORM, norm_eps, cb, il);
-        cb(cur, "attn_norm", il);
+                Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+                cb(Kcur, "Kcur", il);
 
-        // self-attention
-        {
-            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
-            cb(cur, "wqkv", il);
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-            cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-            cb(cur, "bqkv", il);
+                cur = llm_build_kqv(ctx0, attn_norm, hparams, kv_self,
+                        model.layers[il].wo, NULL,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
 
-            struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-            struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-            struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+            struct ggml_tensor * ffn_inp = cur;
 
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
+            // feed forward
+            {
+                cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result
+                        model.layers[il].ffn_up,   NULL,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+            }
 
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
 
-            llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+            cur = ggml_add(ctx0, cur, inpL);
+            cb(cur, "l_out", il);
 
-            cur = llm_build_kqv(lctx, ctx0, cur,
-                    model.layers[il].wo, model.layers[il].bo,
-                    Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
-            cb(cur, "kqv_out", il);
+            // input for next layer
+            inpL = cur;
         }
 
-        // add the input
-        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-        cb(ffn_inp, "ffn_inp", il);
+        cur = inpL;
 
-        // FF
-        {
-            cur = llm_build_norm(ctx0, ffn_inp,
-                    model.layers[il].ffn_norm,
-                    model.layers[il].ffn_norm_b,
-                    LLM_NORM, norm_eps, cb, il);
-            cb(cur, "ffn_norm", il);
-
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                    NULL,                      NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b,
-                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-            cb(cur, "ffn_out", il);
-        }
+        // norm
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
 
-        inpL = ggml_add(ctx0, cur, ffn_inp);
-        cb(inpL, "l_out", il);
-    }
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
-    cur = llm_build_norm(ctx0, inpL,
-            model.output_norm,
-            model.output_norm_b,
-            LLM_NORM, norm_eps, cb, -1);
-    cb(cur, "result_norm", -1);
+        ggml_build_forward_expand(gf, cur);
 
-    cur = ggml_mul_mat(ctx0, model.output, cur);
-    cb(cur, "result_output", -1);
+        return gf;
+    }
 
-    ggml_build_forward_expand(gf, cur);
-    ggml_free(ctx0);
+    struct ggml_cgraph * build_starcoder() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    return gf;
-}
+        struct ggml_tensor * cur;
+        struct ggml_tensor * pos;
+        struct ggml_tensor * inpL;
 
-static struct ggml_cgraph * llm_build_persimmon(
-         llama_context & lctx,
-     const llama_batch & batch,
-    const llm_build_cb & cb,
-                  bool   worst_case) {
-    const auto & model = lctx.model;
-    const auto & hparams = model.hparams;
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
 
-    const auto & kv_self = lctx.kv_self;
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
 
-    GGML_ASSERT(!!kv_self.ctx);
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
 
-    const auto & cparams = lctx.cparams;
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
 
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = cparams.n_ctx;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_rot       = n_embd_head / 2;
+        pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
+        cb(pos, "pos_embd", -1);
 
-    const float freq_base  = cparams.rope_freq_base;
-    const float freq_scale = cparams.rope_freq_scale;
-    const float norm_eps   = hparams.f_norm_eps;
+        inpL = ggml_add(ctx0, inpL, pos);
+        cb(inpL, "inpL", -1);
 
-    const int32_t n_tokens    = batch.n_tokens;
-    const int32_t n_kv        = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head     = worst_case ? n_ctx - n_tokens : kv_self.head;
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
 
-    const bool do_rope_shift  = worst_case || kv_self.has_shift;
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
 
-    auto & buf_compute = lctx.buf_compute;
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
 
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc   =*/ true,
-    };
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-    struct ggml_context * ctx0 = ggml_init(params);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-    struct ggml_tensor * cur;
-    struct ggml_tensor * inpL;
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
-    cb(inpL, "imp_embd", -1);
+                cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
 
-    struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-    cb(inp_pos, "inp_pos", -1);
+            // add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
 
-    // KQ_scale
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    cb(KQ_scale, "KQ_scale", -1);
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+            }
 
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    cb(KQ_mask, "KQ_mask", -1);
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
 
-    if (do_rope_shift) {
-        llm_build_k_shift(lctx, ctx0, gf, n_rot, LLM_ROPE_NEOX, cb);
-    }
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
 
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * residual = inpL;
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
-        cur = llm_build_norm(ctx0, inpL,
-                model.layers[il].attn_norm,
-                model.layers[il].attn_norm_b,
-                LLM_NORM, norm_eps, cb, il);
-        cb(cur, "attn_norm", il);
+        ggml_build_forward_expand(gf, cur);
 
-        // self attention
-        {
-            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
-            cb(cur, "wqkv", il);
-
-            cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-            cb(cur, "bqkv", il);
-
-            // split qkv
-            GGML_ASSERT(n_head_kv == n_head);
-
-            struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens);
-            cb(tmpqkv, "tmpqkv", il);
-
-            struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
-            cb(tmpqkv_perm, "tmpqkv", il);
-
-            struct ggml_tensor * tmpq = ggml_view_3d(
-                    ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
-                    ggml_element_size(tmpqkv_perm) * n_embd_head,
-                    ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
-                    0
-                );
-            cb(tmpq, "tmpq", il);
-
-            struct ggml_tensor * tmpk = ggml_view_3d(
-                    ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
-                    ggml_element_size(tmpqkv_perm) * n_embd_head,
-                    ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
-                    ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens
-                );
-            cb(tmpk, "tmpk", il);
-
-            // Q/K Layernorm
-            tmpq = llm_build_norm(ctx0, tmpq,
-                    model.layers[il].attn_q_norm,
-                    model.layers[il].attn_q_norm_b,
-                    LLM_NORM, norm_eps, cb, il);
-            cb(tmpq, "tmpq", il);
-
-            tmpk = llm_build_norm(ctx0, tmpk,
-                    model.layers[il].attn_k_norm,
-                    model.layers[il].attn_k_norm_b,
-                    LLM_NORM, norm_eps, cb, il);
-            cb(tmpk, "tmpk", il);
-
-            // RoPE the first n_rot of q/k, pass the other half, and concat.
-            struct ggml_tensor * qrot = ggml_view_3d(
-                ctx0, tmpq, n_rot, n_head, n_tokens,
-                ggml_element_size(tmpq) * n_embd_head,
-                ggml_element_size(tmpq) * n_embd_head * n_head,
-                0
-            );
-            cb(qrot, "qrot", il);
+        return gf;
+    }
 
-            struct ggml_tensor * krot = ggml_view_3d(
-                ctx0, tmpk, n_rot, n_head, n_tokens,
-                ggml_element_size(tmpk) * n_embd_head,
-                ggml_element_size(tmpk) * n_embd_head * n_head,
-                0
-            );
-            cb(krot, "krot", il);
-
-            // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
-            struct ggml_tensor * qpass = ggml_view_3d(
-                ctx0, tmpq, n_rot, n_head, n_tokens,
-                ggml_element_size(tmpq) * n_embd_head,
-                ggml_element_size(tmpq) * n_embd_head * n_head,
-                ggml_element_size(tmpq) * n_rot
-            );
-            cb(qpass, "qpass", il);
+    struct ggml_cgraph * build_persimmon() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-            struct ggml_tensor * kpass = ggml_view_3d(
-                ctx0, tmpk, n_rot, n_head, n_tokens,
-                ggml_element_size(tmpk) * n_embd_head,
-                ggml_element_size(tmpk) * n_embd_head * n_head,
-                ggml_element_size(tmpk) * n_rot
-            );
-            cb(kpass, "kpass", il);
+        const int64_t n_rot = n_embd_head / 2;
 
-            struct ggml_tensor * qrotated = ggml_rope_custom(
-                    ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
-            );
-            cb(qrotated, "qrotated", il);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-            struct ggml_tensor * krotated = ggml_rope_custom(
-                    ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
-            );
-            cb(krotated, "krotated", il);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "imp_embd", -1);
 
-            // ggml currently only supports concatenation on dim=2
-            // so we need to permute qrot, qpass, concat, then permute back.
-            qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
-            cb(qrotated, "qrotated", il);
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
 
-            krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
-            cb(krotated, "krotated", il);
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
 
-            qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
-            cb(qpass, "qpass", il);
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
 
-            kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
-            cb(kpass, "kpass", il);
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+        }
 
-            struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
-            cb(Qcur, "Qcur", il);
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * residual = inpL;
 
-            struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
-            cb(Kcur, "Kcur", il);
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
 
-            struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3));
-            cb(Q, "Q", il);
+            // self attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                // split qkv
+                GGML_ASSERT(n_head_kv == n_head);
+
+                struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens);
+                cb(tmpqkv, "tmpqkv", il);
+
+                struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
+                cb(tmpqkv_perm, "tmpqkv", il);
+
+                struct ggml_tensor * tmpq = ggml_view_3d(
+                        ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+                        ggml_element_size(tmpqkv_perm) * n_embd_head,
+                        ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+                        0
+                        );
+                cb(tmpq, "tmpq", il);
+
+                struct ggml_tensor * tmpk = ggml_view_3d(
+                        ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+                        ggml_element_size(tmpqkv_perm) * n_embd_head,
+                        ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+                        ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens
+                        );
+                cb(tmpk, "tmpk", il);
+
+                // Q/K Layernorm
+                tmpq = llm_build_norm(ctx0, tmpq, hparams,
+                        model.layers[il].attn_q_norm,
+                        model.layers[il].attn_q_norm_b,
+                        LLM_NORM, cb, il);
+                cb(tmpq, "tmpq", il);
+
+                tmpk = llm_build_norm(ctx0, tmpk, hparams,
+                        model.layers[il].attn_k_norm,
+                        model.layers[il].attn_k_norm_b,
+                        LLM_NORM, cb, il);
+                cb(tmpk, "tmpk", il);
+
+                // RoPE the first n_rot of q/k, pass the other half, and concat.
+                struct ggml_tensor * qrot = ggml_view_3d(
+                        ctx0, tmpq, n_rot, n_head, n_tokens,
+                        ggml_element_size(tmpq) * n_embd_head,
+                        ggml_element_size(tmpq) * n_embd_head * n_head,
+                        0
+                        );
+                cb(qrot, "qrot", il);
+
+                struct ggml_tensor * krot = ggml_view_3d(
+                        ctx0, tmpk, n_rot, n_head, n_tokens,
+                        ggml_element_size(tmpk) * n_embd_head,
+                        ggml_element_size(tmpk) * n_embd_head * n_head,
+                        0
+                        );
+                cb(krot, "krot", il);
+
+                // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
+                struct ggml_tensor * qpass = ggml_view_3d(
+                        ctx0, tmpq, n_rot, n_head, n_tokens,
+                        ggml_element_size(tmpq) * n_embd_head,
+                        ggml_element_size(tmpq) * n_embd_head * n_head,
+                        ggml_element_size(tmpq) * n_rot
+                        );
+                cb(qpass, "qpass", il);
+
+                struct ggml_tensor * kpass = ggml_view_3d(
+                        ctx0, tmpk, n_rot, n_head, n_tokens,
+                        ggml_element_size(tmpk) * n_embd_head,
+                        ggml_element_size(tmpk) * n_embd_head * n_head,
+                        ggml_element_size(tmpk) * n_rot
+                        );
+                cb(kpass, "kpass", il);
+
+                struct ggml_tensor * qrotated = ggml_rope_custom(
+                        ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
+                        );
+                cb(qrotated, "qrotated", il);
+
+                struct ggml_tensor * krotated = ggml_rope_custom(
+                        ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
+                        );
+                cb(krotated, "krotated", il);
+
+                // ggml currently only supports concatenation on dim=2
+                // so we need to permute qrot, qpass, concat, then permute back.
+                qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
+                cb(qrotated, "qrotated", il);
+
+                krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
+                cb(krotated, "krotated", il);
+
+                qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
+                cb(qpass, "qpass", il);
+
+                kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
+                cb(kpass, "kpass", il);
+
+                struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3));
+                cb(Q, "Q", il);
+
+                Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_view_3d(
+                        ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+                        ggml_element_size(tmpqkv_perm) * n_embd_head,
+                        ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+                        ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2
+                        );
+                cb(Vcur, "Vcur", il);
+
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+                // TODO: not tested, could be broken
+                cur = llm_build_kqv(ctx0, Q, hparams, kv_self,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
 
-            Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
-            cb(Kcur, "Kcur", il);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
+            cb(ffn_inp, "ffn_inp", il);
 
-            struct ggml_tensor * Vcur = ggml_view_3d(
-                    ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
-                    ggml_element_size(tmpqkv_perm) * n_embd_head,
-                    ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
-                    ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2
-                    );
-            cb(Vcur, "Vcur", il);
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+            }
 
-            llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
 
-            // TODO: not tested, could be broken
-            cur = llm_build_kqv(lctx, ctx0, Q,
-                    model.layers[il].wo, model.layers[il].bo,
-                    Q, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
-            cb(cur, "kqv_out", il);
+            inpL = cur;
         }
 
-        struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
-        cb(ffn_inp, "ffn_inp", il);
+        cur = inpL;
 
-        // feed-forward network
-        {
-            cur = llm_build_norm(ctx0, ffn_inp,
-                    model.layers[il].ffn_norm,
-                    model.layers[il].ffn_norm_b,
-                    LLM_NORM, norm_eps, cb, il);
-            cb(cur, "ffn_norm", il);
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                    NULL,                      NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b,
-                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
-            cb(cur, "ffn_out", il);
-        }
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
-        cur = ggml_add(ctx0, cur, ffn_inp);
-        cb(cur, "l_out", il);
+        ggml_build_forward_expand(gf, cur);
 
-        inpL = cur;
+        return gf;
     }
 
-    cur = inpL;
-
-    cur = llm_build_norm(ctx0, cur,
-            model.output_norm,
-            model.output_norm_b,
-            LLM_NORM, norm_eps, cb, -1);
-    cb(cur, "result_norm", -1);
-
-    cur = ggml_mul_mat(ctx0, model.output, cur);
-    cb(cur, "result_output", -1);
-
-    ggml_build_forward_expand(gf, cur);
-
-    ggml_free(ctx0);
-
-    return gf;
-}
-
-static struct ggml_cgraph * llm_build_refact(
-         llama_context & lctx,
-     const llama_batch & batch,
-    const llm_build_cb & cb,
-                  bool   worst_case) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const auto & kv_self = lctx.kv_self;
-
-    GGML_ASSERT(!!kv_self.ctx);
-
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = cparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
-
-    const float norm_rms_eps = hparams.f_norm_rms_eps;
-
-    const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
+    struct ggml_cgraph * build_refact() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    auto & buf_compute = lctx.buf_compute;
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc   =*/ true,
-    };
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
 
-    struct ggml_context * ctx0 = ggml_init(params);
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
 
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
 
-    struct ggml_tensor * cur;
-    struct ggml_tensor * inpL;
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
 
-    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
-    cb(inpL, "inp_embd", -1);
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
 
-    // KQ_scale
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    cb(KQ_scale, "KQ_scale", -1);
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
 
-    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    cb(KQ_mask, "KQ_mask", -1);
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
 
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * inpSA = inpL;
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
 
-        cur = llm_build_norm(ctx0, inpL,
-                model.layers[il].attn_norm, NULL,
-                LLM_NORM_RMS, norm_rms_eps, cb, il);
-        cb(cur, "attn_norm", il);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                cb(Kcur, "Kcur", il);
 
-        // self-attention
-        {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-            cb(Qcur, "Qcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                cb(Qcur, "Qcur", il);
 
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
-            cb(Kcur, "Kcur", il);
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-            cb(Vcur, "Vcur", il);
+                cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
+                        model.layers[il].wo, NULL,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
 
-            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-            cb(Kcur, "Kcur", il);
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
 
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
-            cb(Qcur, "Qcur", il);
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
 
-            llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
 
-            cur = llm_build_kqv(lctx, ctx0, Qcur,
-                    model.layers[il].wo, NULL,
-                    Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, 8.0f, cb, il);
-            cb(cur, "kqv_out", il);
+            // input for next layer
+            inpL = cur;
         }
 
-        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-        cb(ffn_inp, "ffn_inp", il);
+        cur = inpL;
 
-        // feed-forward network
-        {
-            cur = llm_build_norm(ctx0, ffn_inp,
-                    model.layers[il].ffn_norm, NULL,
-                    LLM_NORM_RMS, norm_rms_eps, cb, il);
-            cb(cur, "ffn_norm", il);
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
-            cb(cur, "ffn_out", il);
-        }
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
-        cur = ggml_add(ctx0, cur, ffn_inp);
-        cb(cur, "l_out", il);
+        ggml_build_forward_expand(gf, cur);
 
-        // input for next layer
-        inpL = cur;
+        return gf;
     }
 
-    cur = inpL;
-
-    cur = llm_build_norm(ctx0, cur,
-            model.output_norm, NULL,
-            LLM_NORM_RMS, norm_rms_eps, cb, -1);
-    cb(cur, "result_norm", -1);
-
-    // lm_head
-    cur = ggml_mul_mat(ctx0, model.output, cur);
-    cb(cur, "result_output", -1);
-
-    ggml_build_forward_expand(gf, cur);
-
-    ggml_free(ctx0);
-
-    return gf;
-}
-
-static struct ggml_cgraph * llm_build_bloom(
-         llama_context & lctx,
-     const llama_batch & batch,
-    const llm_build_cb & cb,
-                  bool   worst_case) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const auto & kv_self = lctx.kv_self;
-
-    GGML_ASSERT(!!kv_self.ctx);
-
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = cparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
-
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-    const float norm_eps = hparams.f_norm_eps;
-
-    const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
-
-    auto & buf_compute = lctx.buf_compute;
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc   =*/ false,
-    };
+    struct ggml_cgraph * build_bloom() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    params.no_alloc = true;
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-    struct ggml_context * ctx0 = ggml_init(params);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
 
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
 
-    struct ggml_tensor * cur;
-    struct ggml_tensor * inpL;
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
 
-    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
-    cb(inpL, "inp_embd", -1);
+        inpL = llm_build_norm(ctx0, inpL, hparams,
+                model.tok_norm,
+                model.tok_norm_b,
+                LLM_NORM, cb, -1);
+        cb(inpL, "inp_norm", -1);
 
-    // KQ_scale
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    cb(KQ_scale, "KQ_scale", -1);
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
 
-    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    cb(KQ_mask, "KQ_mask", -1);
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
 
-    inpL = llm_build_norm(ctx0, inpL,
-            model.tok_norm,
-            model.tok_norm_b,
-            LLM_NORM, norm_eps, cb, -1);
-    cb(inpL, "inp_norm", -1);
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
 
-    for (int il = 0; il < n_layer; ++il) {
-        cur = llm_build_norm(ctx0, inpL,
-                model.layers[il].attn_norm,
-                model.layers[il].attn_norm_b,
-                LLM_NORM, norm_eps, cb, il);
-        cb(cur, "attn_norm", il);
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-        // self-attention
-        {
-            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
-            cb(cur, "wqkv", il);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
-            cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-            cb(cur, "bqkv", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-            struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-            struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-            struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
+                cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
+                cb(cur, "kqv_out", il);
+            }
 
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            // Add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
 
-            llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+            }
 
-            cur = llm_build_kqv(lctx, ctx0, Qcur,
-                    model.layers[il].wo, model.layers[il].bo,
-                    Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, 8.0f, cb, il);
-            cb(cur, "kqv_out", il);
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
         }
 
-        // Add the input
-        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-        cb(ffn_inp, "ffn_inp", il);
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
 
-        // FF
-        {
-            cur = llm_build_norm(ctx0, ffn_inp,
-                    model.layers[il].ffn_norm,
-                    model.layers[il].ffn_norm_b,
-                    LLM_NORM, norm_eps, cb, il);
-            cb(cur, "ffn_norm", il);
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                    NULL,                      NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b,
-                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-            cb(cur, "ffn_out", il);
-        }
+        ggml_build_forward_expand(gf, cur);
 
-        inpL = ggml_add(ctx0, cur, ffn_inp);
-        cb(inpL, "l_out", il);
+        return gf;
     }
 
-    cur = llm_build_norm(ctx0, inpL,
-            model.output_norm,
-            model.output_norm_b,
-            LLM_NORM, norm_eps, cb, -1);
-    cb(cur, "result_norm", -1);
+    struct ggml_cgraph * build_mpt() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    cur = ggml_mul_mat(ctx0, model.output, cur);
-    cb(cur, "result_output", -1);
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
 
-    ggml_build_forward_expand(gf, cur);
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
 
-    ggml_free(ctx0);
+        // KQ_scale
+        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        cb(KQ_scale, "KQ_scale", -1);
 
-    return gf;
-}
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
 
-static struct ggml_cgraph * llm_build_mpt(
-         llama_context & lctx,
-     const llama_batch & batch,
-    const llm_build_cb & cb,
-                  bool   worst_case) {
-    const auto & model   = lctx.model;
-    const auto & hparams = model.hparams;
-    const auto & cparams = lctx.cparams;
-
-    const auto & kv_self = lctx.kv_self;
-
-    GGML_ASSERT(!!kv_self.ctx);
-
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = cparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
-
-    const float norm_eps       = hparams.f_norm_eps;
-    const float clamp_kqv      = hparams.f_clamp_kqv;
-    const float max_alibi_bias = hparams.f_max_alibi_bias;
-
-    const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = worst_case ? n_ctx            : kv_self.n;
-    const int32_t kv_head  = worst_case ? n_ctx - n_tokens : kv_self.head;
-
-    auto & buf_compute = lctx.buf_compute;
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute.size,
-        /*.mem_buffer =*/ buf_compute.data,
-        /*.no_alloc   =*/ false,
-    };
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * attn_norm;
 
-    params.no_alloc = true;
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    ggml_cgraph * gf = ggml_new_graph(ctx0);
-
-    struct ggml_tensor * cur;
-    struct ggml_tensor * inpL;
+            attn_norm = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    NULL,
+                    LLM_NORM, cb, il);
+            cb(attn_norm, "attn_norm", il);
 
-    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
-    cb(inpL, "inp_embd", -1);
+            // self-attention
+            {
+                cur = attn_norm;
 
-    // KQ_scale
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    cb(KQ_scale, "KQ_scale", -1);
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
 
-    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
-    cb(KQ_mask, "KQ_mask", -1);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(cur, "wqkv_clamped", il);
+                }
 
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * attn_norm;
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
 
-        attn_norm = llm_build_norm(ctx0, inpL,
-                model.layers[il].attn_norm,
-                NULL,
-                LLM_NORM, norm_eps, cb, il);
-        cb(attn_norm, "attn_norm", il);
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
-        // self-attention
-        {
-            cur = attn_norm;
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
-            cb(cur, "wqkv", il);
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
 
-            if (clamp_kqv > 0.0f) {
-                cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv);
-                cb(cur, "wqkv_clamped", il);
+                cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
+                        model.layers[il].wo, NULL,
+                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il);
+                cb(cur, "kqv_out", il);
             }
 
-            struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-            struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-            struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
+            // Add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
 
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            // feed forward
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        NULL,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+            }
 
-            llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
 
-            cur = llm_build_kqv(lctx, ctx0, Qcur,
-                    model.layers[il].wo, NULL,
-                    Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, max_alibi_bias, cb, il);
-            cb(cur, "kqv_out", il);
+            // input for next layer
+            inpL = cur;
         }
 
-        // Add the input
-        struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
-        cb(ffn_inp, "ffn_inp", il);
+        cur = inpL;
 
-        // feed forward
-        {
-            cur = llm_build_norm(ctx0, ffn_inp,
-                    model.layers[il].ffn_norm,
-                    NULL,
-                    LLM_NORM, norm_eps, cb, il);
-            cb(cur, "ffn_norm", il);
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm,
+                NULL,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    NULL,                      NULL,
-                    model.layers[il].ffn_down, NULL,
-                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
-            cb(cur, "ffn_out", il);
-        }
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
-        cur = ggml_add(ctx0, cur, ffn_inp);
-        cb(cur, "l_out", il);
+        ggml_build_forward_expand(gf, cur);
 
-        // input for next layer
-        inpL = cur;
+        return gf;
     }
-
-    cur = inpL;
-
-    cur = llm_build_norm(ctx0, cur,
-            model.output_norm,
-            NULL,
-            LLM_NORM, norm_eps, cb, -1);
-    cb(cur, "result_norm", -1);
-
-    cur = ggml_mul_mat(ctx0, model.output, cur);
-    cb(cur, "result_output", -1);
-
-    ggml_build_forward_expand(gf, cur);
-
-    ggml_free(ctx0);
-
-    return gf;
-}
+};
 
 //
 // tensor offloading helpers
@@ -5122,43 +4862,49 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct ggml_cgraph * result = NULL;
 
+    struct llm_build_context llm(lctx, batch, cb, worst_case);
+
+    llm.init();
+
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
             {
-                result = llm_build_llama(lctx, batch, cb, worst_case);
+                result = llm.build_llama();
             } break;
         case LLM_ARCH_BAICHUAN:
             {
-                result = llm_build_baichaun(lctx, batch, cb, worst_case);
+                result = llm.build_baichuan();
             } break;
         case LLM_ARCH_FALCON:
             {
-                result = llm_build_falcon(lctx, batch, cb, worst_case);
+                result = llm.build_falcon();
             } break;
         case LLM_ARCH_STARCODER:
             {
-                result = llm_build_starcoder(lctx, batch, cb, worst_case);
+                result = llm.build_starcoder();
             } break;
         case LLM_ARCH_PERSIMMON:
             {
-                result = llm_build_persimmon(lctx, batch, cb, worst_case);
+                result = llm.build_persimmon();
             } break;
         case LLM_ARCH_REFACT:
             {
-                result = llm_build_refact(lctx, batch, cb, worst_case);
+                result = llm.build_refact();
             } break;
         case LLM_ARCH_BLOOM:
             {
-                result = llm_build_bloom(lctx, batch, cb, worst_case);
+                result = llm.build_bloom();
             } break;
         case LLM_ARCH_MPT:
             {
-                result = llm_build_mpt(lctx, batch, cb, worst_case);
+                result = llm.build_mpt();
             } break;
         default:
             GGML_ASSERT(false);
     }
 
+    llm.free();
+
     if (worst_case) {
         int n_non_view_total = 0;