]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add `JAIS` model(s) (#8118)
authorFaisal Zaghloul <redacted>
Tue, 2 Jul 2024 14:36:00 +0000 (10:36 -0400)
committerGitHub <redacted>
Tue, 2 Jul 2024 14:36:00 +0000 (16:36 +0200)
* Add `JAIS` model(s)

* cleanup

* address review comments

* remove hack

* un-hardcode max-alibi-bias

* minor tweaks

---------

Co-authored-by: fmz <redacted>
convert-hf-to-gguf-update.py
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
include/llama.h
src/llama.cpp

index 2758214fa8730bfa284055f95a2aacc5768a0726..944e9d15abdcc447b85351e3ed6baa95bff8980f 100755 (executable)
@@ -86,6 +86,7 @@ models = [
     {"name": "poro-chat",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
     {"name": "jina-v2-code",   "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
     {"name": "viking",         "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
+    {"name": "jais",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
 ]
 
 
index 05fd70171de384899614ea1591f5da93bca06114..6add27cbb11a6dcd12283421f3f5fe76c89d74bc 100755 (executable)
@@ -490,6 +490,9 @@ class Model:
         if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
             # ref: https://huggingface.co/LumiOpen/Viking-7B
             res = "viking"
+        if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
+            # ref: https://huggingface.co/core42/jais-13b
+            res = "jais"
 
         if res is None:
             logger.warning("\n")
@@ -2965,6 +2968,96 @@ class T5Model(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("JAISLMHeadModel")
+class JaisModel(Model):
+    model_arch = gguf.MODEL_ARCH.JAIS
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # SwigLU activation
+        assert self.hparams["activation_function"] == "swiglu"
+        # ALiBi position embedding
+        assert self.hparams["position_embedding_type"] == "alibi"
+
+        # Embeddings scale
+        self.embeddings_scale = 1.0
+        # note: For some JAIS flavors, output is tied to (same as) wte in original model
+        self.output_is_wte = False
+        if 'mup_embeddings_scale' in self.hparams:
+            self.output_is_wte = True   # Hack (?)
+            self.embeddings_scale = self.hparams['mup_embeddings_scale']
+        elif 'embeddings_scale' in self.hparams:
+            self.embeddings_scale = self.hparams['embeddings_scale']
+        else:
+            assert False
+
+        self.width_scale = 1.0
+        if 'mup_output_alpha' in self.hparams:
+            assert 'mup_width_scale' in self.hparams
+            self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale']
+        elif 'width_scale' in self.hparams:
+            self.width_scale = self.hparams['width_scale']
+        else:
+            assert False
+
+        self.max_alibi_bias = 8.0
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        tensors: list[tuple[str, Tensor]] = []
+
+        # we don't need these
+        if name.endswith((".attn.bias")):
+            return tensors
+
+        if name.endswith(("relative_pe.slopes")):
+            # Calculate max ALiBi bias (this is the inverse of the ALiBi calculation)
+            # Some other models has max_alibi_bias spelled out explicitly in the hyperparams,
+            # but Jais's PyTorch model simply precalculates the slope values and places them
+            # in relative_pes.slopes
+            n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
+            first_val = float(data_torch._data[0])
+            self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)
+
+            return tensors
+
+        if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")):
+            data_torch = data_torch.transpose(1, 0)
+
+        new_name = self.map_tensor_name(name)
+
+        if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
+            tensors.append((new_name, data_torch * self.embeddings_scale))
+            if self.output_is_wte:
+                tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
+        elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
+            assert not self.output_is_wte
+            tensors.append((new_name, data_torch * self.width_scale))
+        else:
+            tensors.append((new_name, data_torch))
+
+        return tensors
+
+    def write_tensors(self):
+        super().write_tensors()
+        self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
+
+
 ###### CONVERSION LOGIC ######
 
 
index e87c58266158a1a7c70eb4c051b934b90868ffa1..419f10cee74a7ad4a546d0b081713eda4c0e8b6e 100644 (file)
@@ -164,6 +164,7 @@ class MODEL_ARCH(IntEnum):
     DEEPSEEK2    = auto()
     BITNET       = auto()
     T5           = auto()
+    JAIS         = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -288,6 +289,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.DEEPSEEK2:      "deepseek2",
     MODEL_ARCH.BITNET:         "bitnet",
     MODEL_ARCH.T5:             "t5",
+    MODEL_ARCH.JAIS:           "jais",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -954,6 +956,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ENC_FFN_UP,
         MODEL_TENSOR.ENC_OUTPUT_NORM,
     ],
+    MODEL_ARCH.JAIS: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
index 0bed439397bcdbba3f28558a212de2aca3145157..20e28423b9a99ce521c38c86e240485d8dcf0bb3 100644 (file)
@@ -10,7 +10,7 @@ class TensorNameMap:
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
             "gpt_neox.embed_in",                         # gptneox
-            "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx
+            "transformer.wte",                           # gpt2 gpt-j mpt refact qwen dbrx jais
             "transformer.word_embeddings",               # falcon
             "word_embeddings",                           # bloom
             "model.embed_tokens",                        # llama-hf
@@ -49,7 +49,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
@@ -58,7 +58,7 @@ class TensorNameMap:
         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",               # gptneox
-            "transformer.ln_f",                        # gpt2 gpt-j falcon
+            "transformer.ln_f",                        # gpt2 gpt-j falcon jais
             "model.norm",                              # llama-hf baichuan internlm2
             "norm",                                    # llama-pth
             "transformer.norm_f",                      # mpt dbrx
@@ -81,7 +81,7 @@ class TensorNameMap:
         # Attention norm
         MODEL_TENSOR.ATTN_NORM: (
             "gpt_neox.layers.{bid}.input_layernorm",                # gptneox
-            "transformer.h.{bid}.ln_1",                             # gpt2 gpt-j refact qwen
+            "transformer.h.{bid}.ln_1",                             # gpt2 gpt-j refact qwen jais
             "transformer.blocks.{bid}.norm_1",                      # mpt
             "transformer.h.{bid}.input_layernorm",                  # falcon7b
             "h.{bid}.input_layernorm",                              # bloom
@@ -109,7 +109,7 @@ class TensorNameMap:
         # Attention query-key-value
         MODEL_TENSOR.ATTN_QKV: (
             "gpt_neox.layers.{bid}.attention.query_key_value",                     # gptneox
-            "transformer.h.{bid}.attn.c_attn",                                     # gpt2 qwen
+            "transformer.h.{bid}.attn.c_attn",                                     # gpt2 qwen jais
             "transformer.blocks.{bid}.attn.Wqkv",                                  # mpt
             "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv",                   # dbrx
             "transformer.h.{bid}.self_attention.query_key_value",                  # falcon
@@ -160,7 +160,7 @@ class TensorNameMap:
         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
             "gpt_neox.layers.{bid}.attention.dense",                        # gptneox
-            "transformer.h.{bid}.attn.c_proj",                              # gpt2 refact qwen
+            "transformer.h.{bid}.attn.c_proj",                              # gpt2 refact qwen jais
             "transformer.blocks.{bid}.attn.out_proj",                       # mpt
             "transformer.h.{bid}.self_attention.dense",                     # falcon
             "h.{bid}.self_attention.dense",                                 # bloom
@@ -202,7 +202,7 @@ class TensorNameMap:
         # Feed-forward norm
         MODEL_TENSOR.FFN_NORM: (
             "gpt_neox.layers.{bid}.post_attention_layernorm",                # gptneox
-            "transformer.h.{bid}.ln_2",                                      # gpt2 refact qwen
+            "transformer.h.{bid}.ln_2",                                      # gpt2 refact qwen jais
             "h.{bid}.post_attention_layernorm",                              # bloom
             "transformer.blocks.{bid}.norm_2",                               # mpt
             "model.layers.{bid}.post_attention_layernorm",                   # llama-hf
@@ -239,7 +239,7 @@ class TensorNameMap:
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",                # gptneox
-            "transformer.h.{bid}.mlp.c_fc",                           # gpt2
+            "transformer.h.{bid}.mlp.c_fc",                           # gpt2 jais
             "transformer.blocks.{bid}.ffn.up_proj",                   # mpt
             "transformer.h.{bid}.mlp.dense_h_to_4h",                  # falcon
             "h.{bid}.mlp.dense_h_to_4h",                              # bloom
@@ -285,6 +285,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.gate_proj",           # llama-hf refact
             "layers.{bid}.feed_forward.w1",               # llama-pth
             "transformer.h.{bid}.mlp.w2",                 # qwen
+            "transformer.h.{bid}.mlp.c_fc2",              # jais
             "model.layers.layers.{bid}.mlp.gate_proj",    # plamo
             "model.layers.{bid}.feed_forward.w1",         # internlm2
             "encoder.layers.{bid}.mlp.fc12",              # nomic-bert
@@ -308,7 +309,7 @@ class TensorNameMap:
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",                # gptneox
-            "transformer.h.{bid}.mlp.c_proj",                         # gpt2 refact qwen
+            "transformer.h.{bid}.mlp.c_proj",                         # gpt2 refact qwen jais
             "transformer.blocks.{bid}.ffn.down_proj",                 # mpt
             "transformer.h.{bid}.mlp.dense_4h_to_h",                  # falcon
             "h.{bid}.mlp.dense_4h_to_h",                              # bloom
index cafeafb85dbc7f5dfddef9b8eea437f188805b89..c5b61829204285991c73762f1cf7f227f766593d 100644 (file)
@@ -89,6 +89,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
         LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
         LLAMA_VOCAB_PRE_TYPE_VIKING         = 16,
+        LLAMA_VOCAB_PRE_TYPE_JAIS           = 17,
     };
 
     // note: these values should be synchronized with ggml_rope
index eea532f6ac2ff32115b6277cb2155d16d9e0f724..73f52435a503efe62d35c6323962850f9580ed68 100644 (file)
@@ -228,6 +228,7 @@ enum llm_arch {
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
+    LLM_ARCH_JAIS,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -269,6 +270,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DEEPSEEK2,       "deepseek2"    },
     { LLM_ARCH_BITNET,          "bitnet"       },
     { LLM_ARCH_T5,              "t5"           },
+    { LLM_ARCH_JAIS,            "jais"         },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
 
@@ -1236,6 +1238,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_JAIS,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2035,6 +2052,7 @@ enum e_model {
     MODEL_410M,
     MODEL_0_5B,
     MODEL_1B,
+    MODEL_1_3B,
     MODEL_1_4B,
     MODEL_2B,
     MODEL_2_8B,
@@ -4276,6 +4294,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_410M:          return "410M";
         case MODEL_0_5B:          return "0.5B";
         case MODEL_1B:            return "1B";
+        case MODEL_1_3B:          return "1.3B";
         case MODEL_1_4B:          return "1.4B";
         case MODEL_2B:            return "2B";
         case MODEL_2_8B:          return "2.8B";
@@ -4898,6 +4917,18 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_JAIS:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
+
+                switch (hparams.n_layer) {
+                    case 24: model.type = e_model::MODEL_1_3B; break;
+                    case 40: model.type = e_model::MODEL_13B; break;
+                    /* TODO: add variants */
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -5129,6 +5160,9 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "viking") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
+            } else if (
+                tokenizer_pre == "jais") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -6962,6 +6996,44 @@ static bool llm_load_tensors(
                         layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
                     }
                 } break;
+            case LLM_ARCH_JAIS:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // Output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+
+                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+
+                        layer.ffn_gate     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -12354,6 +12426,97 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_jais() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -12585,6 +12748,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_bitnet();
             } break;
+        case LLM_ARCH_JAIS:
+            {
+                result = llm.build_jais();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -13947,6 +14114,7 @@ struct llm_tokenizer_bpe {
                 break;
             case LLAMA_VOCAB_PRE_TYPE_GPT2:
             case LLAMA_VOCAB_PRE_TYPE_OLMO:
+            case LLAMA_VOCAB_PRE_TYPE_JAIS:
                 regex_exprs = {
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
                 };
@@ -17826,6 +17994,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_T5:
+        case LLM_ARCH_JAIS:
             return LLAMA_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values