llama : add OpenELM support (#7359)
author     Icecream95 <redacted>
           Thu, 4 Jul 2024 17:14:21 +0000 (05:14 +1200)
committer  GitHub <redacted>
           Thu, 4 Jul 2024 17:14:21 +0000 (20:14 +0300)
* Initial OpenELM support (270M only so far)

* Fill out missing entries in llama_model_type_name

* fixup! Initial OpenELM support (270M only so far)

Fix formatting

* llama : support all OpenELM models

* llama : add variable GQA and variable FFN sizes

Some metadata keys can now also be arrays, so their value can be set
per layer for models like OpenELM (see the sketch below, after the trailers).

* llama : minor spacing changes

Co-authored-by: Georgi Gerganov <redacted>
* llama : use std::array for per-layer hparams

* llama : fix save/load state

* llama : do not print hparams for vocab-only models

* llama : handle n_head == 0

* llama : use const ref for print_f and fix division by zero

* llama : fix t5 uses of n_head and n_ff

* llama : minor comment

---------

Co-authored-by: Francis Couture-Harpin <redacted>
Co-authored-by: Georgi Gerganov <redacted>
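
For illustration only, a minimal sketch (not part of this commit) of how a conversion
script could emit per-layer values as GGUF arrays via the add_* methods extended in
gguf_writer.py below. The hyperparameter numbers are made up, and make_divisible mirrors
the OpenELMModel._make_divisible helper added to convert_hf_to_gguf.py:

```python
# Sketch: writing per-layer (array) metadata with the extended GGUFWriter API.
# The hparams below are illustrative, not taken from a real OpenELM config.
import gguf

def make_divisible(v: float, divisor: int) -> int:
    # same rounding rule as OpenELMModel._make_divisible in convert_hf_to_gguf.py
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # do not round down by more than 10%
        new_v += divisor
    return new_v

n_embd          = 1280
ffn_multipliers = [0.5, 1.0, 1.5, 2.0]  # one multiplier per layer (example values)
num_query_heads = [12, 12, 16, 16]      # per-layer head counts (example values)
num_kv_heads    = [3, 3, 4, 4]

writer = gguf.GGUFWriter("openelm-example.gguf", "openelm")
writer.add_block_count(len(ffn_multipliers))
writer.add_embedding_length(n_embd)
# scalars still work as before; a sequence now becomes a per-layer array in the file
writer.add_feed_forward_length([make_divisible(m * n_embd, 256) for m in ffn_multipliers])
writer.add_head_count(num_query_heads)
writer.add_head_count_kv(num_kv_heads)
```

On the C++ side, llama_model_loader::get_key_or_arr (added in src/llama.cpp below) accepts
either the scalar or the array form of these keys and broadcasts a scalar across all
n_layer entries of the per-layer std::array hparams.
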
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama.cpp

index a1d165023ce73c4457f96339ee51a10cc4f0e608..bdc336644e626c0fa957869f3d0e67df0e0052ee 100755 (executable)
@@ -13,7 +13,7 @@ import sys
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 
 import math
 import numpy as np
@@ -677,6 +677,51 @@ class Model:
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
+        tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
+        logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
+        vocab_reader = gguf.GGUFReader(tokenizer_path, "r")
+
+        default_pre = "mpt" if model_name == "gpt-neox" else "default"
+
+        field = vocab_reader.get_field(gguf.Keys.Tokenizer.MODEL)
+        assert field  # tokenizer model
+        self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8"))
+
+        field = vocab_reader.get_field(gguf.Keys.Tokenizer.PRE)
+        self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else default_pre)
+
+        field = vocab_reader.get_field(gguf.Keys.Tokenizer.LIST)
+        assert field  # token list
+        self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
+
+        if model_name == "llama-spm":
+            field = vocab_reader.get_field(gguf.Keys.Tokenizer.SCORES)
+            assert field  # token scores
+            self.gguf_writer.add_token_scores([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
+
+        field = vocab_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
+        assert field  # token types
+        self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
+
+        if model_name != "llama-spm":
+            field = vocab_reader.get_field(gguf.Keys.Tokenizer.MERGES)
+            assert field  # token merges
+            self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
+
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)) is not None:
+            self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)) is not None:
+            self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)) is not None:
+            self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.PAD_ID)) is not None:
+            self.gguf_writer.add_pad_token_id(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_BOS)) is not None:
+            self.gguf_writer.add_add_bos_token(field.parts[-1].tolist()[0])
+        if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None:
+            self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
+
 
 @Model.register("GPTNeoXForCausalLM")
 class GPTNeoXModel(Model):
@@ -2439,39 +2484,7 @@ class MambaModel(Model):
             self._set_vocab_sentencepiece()
         else:
             # Use the GPT-NeoX tokenizer when no tokenizer files are present
-            tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
-            logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
-            neox_reader = gguf.GGUFReader(tokenizer_path, "r")
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
-            self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8") if field else "gpt2")
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.PRE)
-            self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else "mpt")
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
-            assert field
-            self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
-            assert field
-            self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
-            assert field
-            self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
-            self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0] if field else 1)
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
-            self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0] if field else 0)
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
-            self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0] if field else 0)
-
-            field = neox_reader.get_field(gguf.Keys.Tokenizer.PAD_ID)
-            self.gguf_writer.add_pad_token_id(field.parts[-1].tolist()[0] if field else 0)
+            self._set_vocab_builtin("gpt-neox", vocab_size)
 
     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size",       "d_model"])
@@ -2623,6 +2636,82 @@ class JinaBertV2Model(BertModel):
         self.gguf_writer.add_add_eos_token(True)
 
 
+@Model.register("OpenELMForCausalLM")
+class OpenELMModel(Model):
+    model_arch = gguf.MODEL_ARCH.OPENELM
+
+    @staticmethod
+    def _make_divisible(v: float | int, divisor: int) -> int:
+        # ref: https://huggingface.co/apple/OpenELM-270M-Instruct/blob/eb111ff2e6724348e5b905984063d4064d4bc579/configuration_openelm.py#L34-L38
+        new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
+        # Make sure that round down does not go down by more than 10%.
+        if new_v < 0.9 * v:
+            new_v += divisor
+        return new_v
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        ffn_multipliers: list[float] = self.hparams["ffn_multipliers"]
+        ffn_dim_divisor: int = self.hparams["ffn_dim_divisor"]
+        self._n_embd: int = self.hparams["model_dim"]
+        self._num_kv_heads: list[int] = self.hparams["num_kv_heads"]
+        self._num_query_heads: list[int] = self.hparams["num_query_heads"]
+        self._ffn_dims: list[int] = [
+            OpenELMModel._make_divisible(multiplier * self._n_embd, ffn_dim_divisor)
+            for multiplier in ffn_multipliers
+        ]
+        assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int)
+        assert isinstance(self._num_query_heads, list) and isinstance(self._num_query_heads[0], int)
+
+    # Uses the tokenizer from meta-llama/Llama-2-7b-hf
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_builtin("llama-spm", self.hparams["vocab_size"])
+
+    def set_gguf_parameters(self):
+        n_embd = self._n_embd
+        head_dim = self.hparams["head_dim"]
+        rot_pct = 1.0
+        assert self.block_count == len(self._num_kv_heads)
+        assert self.block_count == len(self._num_query_heads)
+        assert self.block_count == len(self._ffn_dims)
+
+        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams["max_context_length"])
+        self.gguf_writer.add_embedding_length(n_embd)
+        self.gguf_writer.add_feed_forward_length(self._ffn_dims)
+        self.gguf_writer.add_head_count(self._num_query_heads)
+        self.gguf_writer.add_head_count_kv(self._num_kv_heads)
+        self.gguf_writer.add_rope_freq_base(self.hparams["rope_freq_constant"])
+        # https://huggingface.co/apple/OpenELM-270M-Instruct/blob/c401df2/modeling_openelm.py#L30
+        self.gguf_writer.add_layer_norm_rms_eps(1e-6)
+        self.gguf_writer.add_rope_dimension_count(int(rot_pct * head_dim))
+        self.gguf_writer.add_key_length(head_dim)
+        self.gguf_writer.add_value_length(head_dim)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        if "n_layers" in keys:
+            return self.hparams["num_transformer_layers"]
+
+        return super().find_hparam(keys, optional)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        # split ff
+        if bid is not None and name == f"transformer.layers.{bid}.ffn.proj_1.weight":
+            ff_dim = self._ffn_dims[bid]
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim])
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:])
+            return
+
+        yield (self.map_tensor_name(name), data_torch)
+
+
 @Model.register("ArcticForCausalLM")
 class ArcticModel(Model):
     model_arch = gguf.MODEL_ARCH.ARCTIC
index 419f10cee74a7ad4a546d0b081713eda4c0e8b6e..d78fb1e44b29ee7a72b6d6139370ad2306fcaad6 100644 (file)
@@ -160,6 +160,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R    = auto()
     DBRX         = auto()
     OLMO         = auto()
+    OPENELM      = auto()
     ARCTIC       = auto()
     DEEPSEEK2    = auto()
     BITNET       = auto()
@@ -285,6 +286,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.COMMAND_R:      "command-r",
     MODEL_ARCH.DBRX:           "dbrx",
     MODEL_ARCH.OLMO:           "olmo",
+    MODEL_ARCH.OPENELM:        "openelm",
     MODEL_ARCH.ARCTIC:         "arctic",
     MODEL_ARCH.DEEPSEEK2:      "deepseek2",
     MODEL_ARCH.BITNET:         "bitnet",
@@ -861,6 +863,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.OPENELM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.ARCTIC: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 75a8b2636a6a216b79b5348e2ca030daabe34ddc..cf955416290323258ec2dfd2e1927f115fe14ea8 100644 (file)
@@ -480,8 +480,11 @@ class GGUFWriter:
     def add_leading_dense_block_count(self, length: int) -> None:
         self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
 
-    def add_feed_forward_length(self, length: int) -> None:
-        self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+    def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
+        if isinstance(length, int):
+            self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+        else:
+            self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
     def add_expert_feed_forward_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
@@ -495,11 +498,17 @@ class GGUFWriter:
     def add_decoder_start_token_id(self, id: int) -> None:
         self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
 
-    def add_head_count(self, count: int) -> None:
-        self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+    def add_head_count(self, count: int | Sequence[int]) -> None:
+        if isinstance(count, int):
+            self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
+        else:
+            self.add_array(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
 
-    def add_head_count_kv(self, count: int) -> None:
-        self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+    def add_head_count_kv(self, count: int | Sequence[int]) -> None:
+        if isinstance(count, int):
+            self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
+        else:
+            self.add_array(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count)
 
     def add_key_length(self, length: int) -> None:
         self.add_uint32(Keys.Attention.KEY_LENGTH.format(arch=self.arch), length)
index 20e28423b9a99ce521c38c86e240485d8dcf0bb3..6054f61595e135e1e0b7a27aabe0ec6806fd28cc 100644 (file)
@@ -24,6 +24,7 @@ class TensorNameMap:
             "backbone.embedding",                        # mamba
             "backbone.embeddings",                       # mamba-hf
             "transformer.in_out_embed",                  # Grok
+            "transformer.token_embeddings",              # openelm
             "shared",                                    # t5
         ),
 
@@ -37,6 +38,7 @@ class TensorNameMap:
             "word_embeddings_layernorm",  # bloom
             "embeddings.LayerNorm",       # bert
             "emb_ln",                     # nomic-bert
+            "transformer.norm",           # openelm
         ),
 
         # Position embeddings
@@ -69,6 +71,7 @@ class TensorNameMap:
             "model.norm_f",                            # mamba-qbert
             "backbone.norm_f",                         # mamba
             "transformer.rms_norm",                    # Grok
+            "transformer.norm",                        # openelm
         ),
 
         # Rope frequencies
@@ -98,6 +101,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.norm",                           # mamba
             "transformer.decoder_layer.{bid}.rms_norm",             # Grok
             "transformer.blocks.{bid}.norm_attn_norm.norm_1",       # dbrx
+            "transformer.layers.{bid}.attn_norm",                   # openelm
         ),
 
         # Attention norm 2
@@ -119,7 +123,8 @@ class TensorNameMap:
             "h.{bid}.attn.c_attn",                                                 # gpt2
             "transformer.h.{bid}.mixer.Wqkv",                                      # phi2
             "encoder.layers.{bid}.attn.Wqkv",                                      # nomic-bert
-            "model.layers.{bid}.self_attn.qkv_proj"                                # phi3
+            "model.layers.{bid}.self_attn.qkv_proj",                               # phi3
+            "transformer.layers.{bid}.attn.qkv_proj",                              # openelm
         ),
 
         # Attention query
@@ -177,6 +182,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
             "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
+            "transformer.layers.{bid}.attn.out_proj",                       # openelm
         ),
 
         # Attention output norm
@@ -212,6 +218,7 @@ class TensorNameMap:
             "h.{bid}.ln_2",                                                  # gpt2
             "model.layers.{bid}.ffn_norm",                                   # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2",                    # Grok
+            "transformer.layers.{bid}.ffn_norm",                             # openelm
         ),
 
         # Post feed-forward norm
@@ -327,6 +334,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.fc2",                           # nomic-bert
             "model.layers.{bid}.mlp.c_proj",                          # starcoder2
             "encoder.layer.{bid}.mlp.wo",                             # jina-bert-v2
+            "transformer.layers.{bid}.ffn.proj_2",                    # openelm
             "model.layers.{bid}.residual_mlp.w2",                     # arctic
             "encoder.layer.{bid}.mlp.down_layer",                     # jina-bert-v2
         ),
@@ -348,7 +356,8 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.q_norm",                            # cohere
             "transformer.blocks.{bid}.attn.q_ln",                             # sea-lion
-            "encoder.layer.{bid}.attention.self.layer_norm_q"                 # jina-bert-v2
+            "encoder.layer.{bid}.attention.self.layer_norm_q",                # jina-bert-v2
+            "transformer.layers.{bid}.attn.q_norm",                           # openelm
         ),
 
         MODEL_TENSOR.ATTN_K_NORM: (
@@ -356,7 +365,8 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.k_layernorm",                       # persimmon
             "model.layers.{bid}.self_attn.k_norm",                            # cohere
             "transformer.blocks.{bid}.attn.k_ln",                             # sea-lion
-            "encoder.layer.{bid}.attention.self.layer_norm_k"                 # jina-bert-v2
+            "encoder.layer.{bid}.attention.self.layer_norm_k",                # jina-bert-v2
+            "transformer.layers.{bid}.attn.k_norm",                           # openelm
         ),
 
         MODEL_TENSOR.ROPE_FREQS: (
index 1fe2b9f79d320d17620019bc59970a9b47f310c2..721b8f4e5931b7128d5f0677335949de83567774 100644 (file)
 #define LLAMA_ATTRIBUTE_FORMAT(...)
 #endif
 
+// bump if necessary
 #define LLAMA_MAX_NODES   8192
-#define LLAMA_MAX_EXPERTS 160
+#define LLAMA_MAX_LAYERS  256
+#define LLAMA_MAX_EXPERTS 160  // DeepSeekV2
 
 //
 // logging
@@ -224,6 +226,7 @@ enum llm_arch {
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_DBRX,
     LLM_ARCH_OLMO,
+    LLM_ARCH_OPENELM,
     LLM_ARCH_ARCTIC,
     LLM_ARCH_DEEPSEEK2,
     LLM_ARCH_BITNET,
@@ -266,6 +269,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_COMMAND_R,       "command-r"    },
     { LLM_ARCH_DBRX,            "dbrx"         },
     { LLM_ARCH_OLMO,            "olmo"         },
+    { LLM_ARCH_OPENELM,         "openelm"      },
     { LLM_ARCH_ARCTIC,          "arctic"       },
     { LLM_ARCH_DEEPSEEK2,       "deepseek2"    },
     { LLM_ARCH_BITNET,          "bitnet"       },
@@ -1134,6 +1138,22 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_OPENELM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_ARCTIC,
         {
@@ -2052,8 +2072,10 @@ enum e_model {
     MODEL_160M,
     MODEL_220M,
     MODEL_250M,
+    MODEL_270M,
     MODEL_335M,
     MODEL_410M,
+    MODEL_450M,
     MODEL_770M,
     MODEL_780M,
     MODEL_0_5B,
@@ -2108,19 +2130,20 @@ struct llama_hparams {
     uint32_t n_vocab;
     uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
-    uint32_t n_head;
-    uint32_t n_head_kv;
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_swa = 0; // sliding window attention (SWA)
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
-    uint32_t n_ff;
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
     uint32_t n_vocab_type = 0; // for BERT-style token types
     uint32_t n_rel_attn_bkts = 0;
 
+    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
+    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
+    std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
+
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q = 0;
     uint32_t n_lora_kv = 0;
@@ -2168,17 +2191,18 @@ struct llama_hparams {
         if (this->n_vocab       != other.n_vocab)       return true;
         if (this->n_ctx_train   != other.n_ctx_train)   return true;
         if (this->n_embd        != other.n_embd)        return true;
-        if (this->n_head        != other.n_head)        return true;
-        if (this->n_head_kv     != other.n_head_kv)     return true;
         if (this->n_layer       != other.n_layer)       return true;
         if (this->n_rot         != other.n_rot)         return true;
         if (this->n_swa         != other.n_swa)         return true;
         if (this->n_embd_head_k != other.n_embd_head_k) return true;
         if (this->n_embd_head_v != other.n_embd_head_v) return true;
-        if (this->n_ff          != other.n_ff)          return true;
         if (this->n_expert      != other.n_expert)      return true;
         if (this->n_expert_used != other.n_expert_used) return true;
 
+        if (this->n_head_arr    != other.n_head_arr)    return true;
+        if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
+        if (this->n_ff_arr      != other.n_ff_arr)      return true;
+
         if (this->n_rel_attn_bkts    != other.n_rel_attn_bkts)    return true;
         if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
         if (this->n_lora_q           != other.n_lora_q)           return true;
@@ -2210,18 +2234,53 @@ struct llama_hparams {
         return false;
     }
 
-    uint32_t n_gqa() const {
+    uint32_t n_head(uint32_t il = 0) const {
+        if (il < n_layer) {
+            return n_head_arr[il];
+        }
+
+        GGML_ASSERT(false);
+        return 0;
+    }
+
+    uint32_t n_head_kv(uint32_t il = 0) const {
+        if (il < n_layer) {
+            return n_head_kv_arr[il];
+        }
+
+        GGML_ASSERT(false);
+        return 0;
+    }
+
+    uint32_t n_ff(uint32_t il = 0) const {
+        if (il < n_layer) {
+            return n_ff_arr[il];
+        }
+
+        GGML_ASSERT(false);
+        return 0;
+    }
+
+    uint32_t n_gqa(uint32_t il = 0) const {
+        const uint32_t n_head    = this->n_head(il);
+        const uint32_t n_head_kv = this->n_head_kv(il);
+
         if (n_head_kv == 0) {
             return 0;
         }
+
         return n_head/n_head_kv;
     }
 
-    uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads
+    uint32_t n_embd_k_gqa(uint32_t il = 0) const { // dimension of key embeddings across all k-v heads
+        const uint32_t n_head_kv = this->n_head_kv(il);
+
         return n_embd_head_k * n_head_kv;
     }
 
-    uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads
+    uint32_t n_embd_v_gqa(uint32_t il = 0) const { // dimension of value embeddings across all k-v heads
+        const uint32_t n_head_kv = this->n_head_kv(il);
+
         return n_embd_head_v * n_head_kv;
     }
 
@@ -2238,6 +2297,8 @@ struct llama_hparams {
     }
 };
 
+static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
+
 struct llama_cparams {
     uint32_t n_ctx;           // context size used during inference
     uint32_t n_batch;
@@ -2857,9 +2918,7 @@ static bool llama_kv_cache_init(
 
     const struct llama_hparams & hparams = model.hparams;
 
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
-    const int64_t  n_layer      = hparams.n_layer;
+    const int64_t  n_layer = hparams.n_layer;
 
     cache.has_shift = false;
 
@@ -2867,13 +2926,6 @@ static bool llama_kv_cache_init(
     cache.recurrent = model.arch == LLM_ARCH_MAMBA;
     cache.v_trans   = !cparams.flash_attn;
 
-    // TODO: support mixed recurrent Transformer architectures
-    // NOTE: (!a || b) is a logical implication (a -> b)
-    GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s());
-    GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s());
-    GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa());
-    GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa());
-
     cache.head = 0;
     cache.size = kv_size;
     cache.used = 0;
@@ -2923,6 +2975,9 @@ static bool llama_kv_cache_init(
     cache.v_l.reserve(n_layer);
 
     for (int i = 0; i < (int) n_layer; i++) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+
         struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
         ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
@@ -3798,9 +3853,9 @@ struct llama_model_loader {
     bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
         const int kid = gguf_find_key(meta, key.c_str());
 
-        if (kid < 0) {
+        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
             return false;
         }
@@ -3808,22 +3863,53 @@ struct llama_model_loader {
         struct GGUFMeta::ArrayInfo arr_info =
             GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
 
-        if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) {
-            throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str()));
+        switch (arr_info.gt) {
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+            case GGUF_TYPE_INT32:   GGML_ASSERT(
+                                            (std::is_same<T,  int32_t>::value) ||
+                                            (std::is_same<T, uint32_t>::value));  break;
+            default:
+                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
         }
 
-        // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
-        GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value));
-        GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32   || std::is_same<T, int>::value));
-
         result.resize(arr_info.length);
         result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
 
         return true;
     }
 
+    template<typename T, size_t N_MAX>
+    bool get_arr(const std::string & key, std::array<T, N_MAX> & result, const bool required = true) {
+        const int kid = gguf_find_key(meta, key.c_str());
+
+        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
+            if (required) {
+                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
+            }
+            return false;
+        }
+
+        struct GGUFMeta::ArrayInfo arr_info =
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
+
+        switch (arr_info.gt) {
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
+            case GGUF_TYPE_INT32:   GGML_ASSERT(
+                                            (std::is_same<T,  int32_t>::value) ||
+                                            (std::is_same<T, uint32_t>::value));  break;
+            default:
+                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
+        }
+
+        GGML_ASSERT(arr_info.length <= N_MAX);
+
+        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+
+        return true;
+    }
+
     template<typename T>
-    bool get_arr(const enum llm_kv kid, T& result, const bool required = true) {
+    bool get_arr(const enum llm_kv kid, T & result, const bool required = true) {
         return get_arr(llm_kv(kid), result, required);
     }
 
@@ -3848,6 +3934,50 @@ struct llama_model_loader {
         return get_key(llm_kv(kid), result, required);
     }
 
+    // get array of n <= N_MAX elements, or a single element repeated n times
+    template<typename T, size_t N_MAX>
+    bool get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, const bool required = true) {
+        GGML_ASSERT(n <= N_MAX);
+
+        const int kid = gguf_find_key(meta, key.c_str());
+
+        if (kid < 0) {
+            if (required) {
+                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+            }
+            return false;
+        }
+
+        if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) {
+            struct GGUFMeta::ArrayInfo arr_info =
+                GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
+
+            if (n != arr_info.length) {
+                throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
+            }
+
+            return get_arr(key, result, required);
+        } else {
+            T value;
+
+            bool ok = get_key(key, value, required);
+            if (!ok) {
+                return false;
+            }
+
+            for (uint32_t i = 0; i < n; i++) {
+                result[i] = value;
+            }
+
+            return true;
+        }
+    }
+
+    template<typename T>
+    bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true) {
+        return get_key_or_arr(llm_kv(kid), result, n, required);
+    }
+
     std::string get_arch_name() const {
         return arch_name;
     }
@@ -4340,8 +4470,10 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_160M:          return "160M";
         case MODEL_220M:          return "220M";
         case MODEL_250M:          return "250M";
+        case MODEL_270M:          return "270M";
         case MODEL_335M:          return "335M";
         case MODEL_410M:          return "410M";
+        case MODEL_450M:          return "450M";
         case MODEL_770M:          return "770M";
         case MODEL_780M:          return "780M";
         case MODEL_0_5B:          return "0.5B";
@@ -4425,20 +4557,18 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
 
     // get hparams kv
-    ml.get_key(LLM_KV_VOCAB_SIZE,           hparams.n_vocab,       false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
+    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
 
     // everything past this point is not vocab-related
     if (hparams.vocab_only) {
         return;
     }
 
-    ml.get_key(LLM_KV_CONTEXT_LENGTH,       hparams.n_ctx_train);
-    ml.get_key(LLM_KV_EMBEDDING_LENGTH,     hparams.n_embd);
-    ml.get_key(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
-    ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
-    ml.get_key(LLM_KV_BLOCK_COUNT,          hparams.n_layer);
-    ml.get_key(LLM_KV_EXPERT_COUNT,         hparams.n_expert,      false);
-    ml.get_key(LLM_KV_EXPERT_USED_COUNT,    hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_CONTEXT_LENGTH,    hparams.n_ctx_train);
+    ml.get_key(LLM_KV_EMBEDDING_LENGTH,  hparams.n_embd);
+    ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
+    ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
+    ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
 
     GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
     GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
@@ -4448,9 +4578,18 @@ static void llm_load_hparams(
         GGML_ASSERT(hparams.n_expert_used == 0);
     }
 
+    // zero-out the per-layer hparams
+    std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
+    std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
+    std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
+
+    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer);
+    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
+
     // n_head_kv is optional, default to n_head
-    hparams.n_head_kv = hparams.n_head;
-    ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
+    hparams.n_head_kv_arr = hparams.n_head_arr;
+
+    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false);
 
     bool rope_finetuned = false;
     ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
@@ -4478,26 +4617,31 @@ static void llm_load_hparams(
 
     ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
 
-    // sanity check for n_rot (optional)
-    {
-        hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
+    // non-transformer models do not have attention heads
+    if (hparams.n_head() > 0) {
+        // sanity check for n_rot (optional)
+        hparams.n_rot = hparams.n_embd / hparams.n_head();
 
         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
         if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
-            if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
-                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
+            if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
+                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
             }
         }
         // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
         // gpt-j n_rot = rotary_dim
-    }
 
-    hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
-    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
+        hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
+        ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
 
-    hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
-    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
+        hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
+        ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
+    } else {
+        hparams.n_rot = 0;
+        hparams.n_embd_head_k = 0;
+        hparams.n_embd_head_v = 0;
+    }
 
     // arch-specific KVs
     switch (model.arch) {
@@ -4521,7 +4665,7 @@ static void llm_load_hparams(
                         case 40: model.type = e_model::MODEL_13B; break;
                         case 48: model.type = e_model::MODEL_34B; break;
                         case 60: model.type = e_model::MODEL_30B; break;
-                        case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break;
+                        case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break;
                         default: model.type = e_model::MODEL_UNKNOWN;
                     }
                 }
@@ -4690,7 +4834,7 @@ static void llm_load_hparams(
                 switch (hparams.n_layer) {
                     case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break;
                     case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = hparams.n_head == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break;
+                    case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break;
                     case 80: model.type = e_model::MODEL_70B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
@@ -4882,46 +5026,58 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_OPENELM:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                case 16: model.type = e_model::MODEL_270M; break;
+                case 20: model.type = e_model::MODEL_450M; break;
+                case 28: model.type = e_model::MODEL_1B; break;
+                case 36: model.type = e_model::MODEL_3B; break;
+                default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_GPTNEOX:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
                 switch (hparams.n_layer) {
                     case 6:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 512: model.type = e_model::MODEL_14M; break;
                             case 2048: model.type = e_model::MODEL_70M; break;
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
                     case 12:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 3072: model.type = e_model::MODEL_160M; break;
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
                     case 16:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 8192: model.type = e_model::MODEL_1B; break;
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
                     case 24:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 4096: model.type = e_model::MODEL_410M; break;
                             case 8192: model.type = e_model::MODEL_1_4B; break;
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
                     case 32:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 10240: model.type = e_model::MODEL_2_8B; break;
                             case 16384: model.type = e_model::MODEL_6_9B; break;
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
                     case 36:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 20480: model.type = e_model::MODEL_12B; break;
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
                     case 44:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 24576: model.type = e_model::MODEL_20B; break;
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
@@ -4984,13 +5140,13 @@ static void llm_load_hparams(
                     case 6:  model.type = e_model::MODEL_60M;  break; // t5-small
                     case 8:  model.type = e_model::MODEL_80M;  break; // flan-t5-small
                     case 12:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 3072: model.type = e_model::MODEL_220M; break; // t5-base
                             case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base
                             default: model.type = e_model::MODEL_UNKNOWN;
                         } break;
                     case 24:
-                        switch (hparams.n_ff) {
+                        switch (hparams.n_ff()) {
                             case 4096:  model.type = e_model::MODEL_770M; break; // t5-large
                             case 2816:  model.type = e_model::MODEL_780M; break; // flan-t5-large
                             case 16384: model.type = e_model::MODEL_3B;   break; // t5-3b
@@ -5533,44 +5689,78 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 
     const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
 
+    auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
+        bool is_var = false;
+
+        std::vector<uint32_t> v;
+        for (uint32_t i = 0; i < n; ++i) {
+            v.push_back(f(i));
+            if (v[i] != v[0]) {
+                is_var = true;
+            }
+        }
+
+        std::stringstream ss;
+
+        if (is_var) {
+            ss << "[";
+            for (uint32_t i = 0; i < n; ++i) {
+                ss << v[i];
+                if (i < n - 1) {
+                    ss << ", ";
+                }
+            }
+            ss << "]";
+        } else {
+            ss << v[0];
+        }
+
+        return ss.str();
+    };
+
     // hparams
     LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch));
     LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
     LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
-    LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
-    LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
-    LLAMA_LOG_INFO("%s: n_head           = %u\n",     __func__, hparams.n_head);
-    LLAMA_LOG_INFO("%s: n_head_kv        = %u\n",     __func__, hparams.n_head_kv);
-    LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
-    LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
-    LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
-    LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
-    LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
-    LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
-    LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %u\n",     __func__, hparams.n_embd_k_gqa());
-    LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %u\n",     __func__, hparams.n_embd_v_gqa());
-    LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
-    LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
-    LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
-    LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
-    LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
-    LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
-    LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
-    LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
-    LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
-    LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
-    LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
-    LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
-    LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
-    LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
-    LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
-    LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
-    LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
-    LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
-    LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
-    LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+    LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
+
+    if (!hparams.vocab_only) {
+        LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
+        LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
+        LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
+        LLAMA_LOG_INFO("%s: n_head           = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head(il);    }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
+        LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
+        LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
+        LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
+        LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
+        LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+        LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
+        LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
+        LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
+        LLAMA_LOG_INFO("%s: n_ff             = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
+        LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
+        LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
+        LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
+        LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
+        LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
+        LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
+        LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
+        LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
+        LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
+        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
+        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
+        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
+        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+    }
+
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
     if (ml.n_elements >= 1e12) {
@@ -5765,13 +5955,13 @@ static bool llm_load_tensors(
     // create tensors for the weights
     {
         const int64_t n_embd       = hparams.n_embd;
-        const int64_t n_embd_head  = (hparams.n_head == 0) ? 0 : n_embd / hparams.n_head;
+        const int64_t n_embd_head  = hparams.n_head() > 0 ? n_embd / hparams.n_head() : 0;
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
         const int64_t n_embd_gqa   = n_embd_v_gqa;
         const int64_t n_vocab      = hparams.n_vocab;
         const int64_t n_vocab_type = hparams.n_vocab_type;
-        const int64_t n_ff         = hparams.n_ff;
+        const int64_t n_ff         = hparams.n_ff();
         const int64_t n_expert     = hparams.n_expert;
 
         if (n_expert > 0 && hparams.n_expert_used == 0) {
@@ -6292,8 +6482,8 @@ static bool llm_load_tensors(
                         layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // optional q and k layernorms, present in StableLM 2 12B
-                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head()}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv()}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
                         layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -6664,7 +6854,7 @@ static bool llm_load_tensors(
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                     model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
 
-                    const int64_t n_ff          = hparams.n_ff;
+                    const int64_t n_ff          = hparams.n_ff();
                     const int64_t n_embd_head_k = hparams.n_embd_head_k;
                     const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
                     const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
@@ -6677,10 +6867,10 @@ static bool llm_load_tensors(
 
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * hparams.n_head()});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head(), n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
@@ -6696,7 +6886,7 @@ static bool llm_load_tensors(
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                     model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
 
-                    const int64_t n_ff          = hparams.n_ff;
+                    const int64_t n_ff          = hparams.n_ff();
                     const int64_t n_embd_head_k = hparams.n_embd_head_k;
                     const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
                     const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
@@ -6709,10 +6899,10 @@ static bool llm_load_tensors(
 
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * hparams.n_head()});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head(), n_embd});
                         layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
@@ -6861,8 +7051,8 @@ static bool llm_load_tensors(
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
                         if (n_layer >= 64){
-                            layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head});
-                            layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv});
+                            layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head()});
+                            layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv()});
                         }
 
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
@@ -6904,6 +7094,42 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_OPENELM:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        // init output from the input tok embed
+                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        const int64_t n_head = hparams.n_head(i);
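+                        // the fused QKV projection packs n_head query heads followed by n_head_kv key and n_head_kv value heads, each of size n_embd_head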
+                        const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head;
+                        const int64_t n_embd_head = hparams.n_embd_head_k;
+
+                        const int64_t n_ff = hparams.n_ff(i);
+
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head});
+                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head});
+                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                    }
+                } break;
             case LLM_ARCH_GPTNEOX:
                 {
                     model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
@@ -7011,13 +7237,13 @@ static bool llm_load_tensors(
 
                         if (!is_lite) {
                             layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A,   "weight", i), {n_embd, q_lora_rank});
-                            layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B,   "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k});
+                            layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B,   "weight", i), {q_lora_rank, hparams.n_head() * hparams.n_embd_head_k});
                         } else {
                             layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
                         }
                         layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA,   "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope});
-                        layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,   "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd});
+                        layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head() * (n_embd_head_qk_nope + hparams.n_embd_head_v)});
+                        layer.wo    = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,  "weight", i), {              hparams.n_head() * (                      hparams.n_embd_head_v), n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
@@ -7104,7 +7330,7 @@ static bool llm_load_tensors(
                         auto & layer = model.layers[i];
 
                         layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
-                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {hparams.n_head(), hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
                         layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
@@ -7117,7 +7343,7 @@ static bool llm_load_tensors(
                         layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
 
                         layer.attn_norm  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd});
-                        layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {hparams.n_head(), hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
@@ -7126,7 +7352,7 @@ static bool llm_load_tensors(
 
                         layer.attn_norm_cross  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd});
                         // this tensor seems to be unused in HF transformers implementation
-                        layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {hparams.n_head, hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {hparams.n_head(), hparams.n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
                         layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
@@ -7455,8 +7681,8 @@ static void llm_build_kv_store(
                     int64_t   il) {
     const int64_t n_ctx = cparams.n_ctx;
 
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
     GGML_ASSERT(kv.size == n_ctx);
 
@@ -7749,12 +7975,12 @@ static struct ggml_tensor * llm_build_kqv(
          const llm_build_cb & cb,
                     int       il) {
     const int64_t n_ctx         = cparams.n_ctx;
-    const int64_t n_head        = hparams.n_head;
-    const int64_t n_head_kv     = hparams.n_head_kv;
+    const int64_t n_head        = hparams.n_head(il);
+    const int64_t n_head_kv     = hparams.n_head_kv(il);
     const int64_t n_embd_head_k = hparams.n_embd_head_k;
-    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_head_v = hparams.n_embd_head_v;
-    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(il);
 
     struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
     cb(q, "q", il);
@@ -7961,8 +8187,8 @@ struct llm_build_context {
         n_layer          (hparams.n_layer),
         n_rot            (hparams.n_rot),
         n_ctx            (cparams.n_ctx),
-        n_head           (hparams.n_head),
-        n_head_kv        (hparams.n_head_kv),
+        n_head           (hparams.n_head()),
+        n_head_kv        (hparams.n_head_kv()),
         n_embd_head_k    (hparams.n_embd_head_k),
         n_embd_k_gqa     (hparams.n_embd_k_gqa()),
         n_embd_head_v    (hparams.n_embd_head_v),
@@ -8034,6 +8260,8 @@ struct llm_build_context {
         ggml_set_input(lctx.inp_K_shift);
 
         for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             struct ggml_tensor * rope_factors = build_rope_factors(il);
             struct ggml_tensor * tmp =
                 // we rotate only the first n_rot dimensions
@@ -8093,6 +8321,9 @@ struct llm_build_context {
             }
 
             for (int il = 0; il < n_layer; ++il) {
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
                 ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
                         n_embd_k_gqa, nm,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
@@ -11974,6 +12205,131 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_openelm() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
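+            // OpenELM uses per-layer attention head counts (variable GQA)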
+            const int64_t n_head    = hparams.n_head(il);
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_qkv = 2*n_head_kv + n_head;
+
+            cur = inpL;
+            struct ggml_tensor * residual = cur;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
+
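+                // Q occupies the first n_head heads, K the next n_head_kv, and V the last n_head_kv heads of the fused tensor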
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+                cb(Vcur, "Vcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                        model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                        model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
+                cb(Vcur, "Vcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // norm
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_gptneox() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -13242,6 +13598,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_OPENELM:
+            {
+                result = llm.build_openelm();
+            } break;
         case LLM_ARCH_GPTNEOX:
             {
                 result = llm.build_gptneox();
@@ -18772,6 +19132,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
             return LLAMA_ROPE_TYPE_NEOX;
 
@@ -19365,8 +19726,6 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
         const auto & hparams = ctx->model.hparams;
 
         const uint32_t n_layer      = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
         // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
         const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
@@ -19386,6 +19745,9 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
 
             std::vector<uint8_t> tmp_buf;
             for (int il = 0; il < (int) n_layer; ++il) {
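+                // K/V embedding sizes can differ per layer (e.g. for OpenELM), so they are computed inside the loop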
+                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
                 const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
 
                 tmp_buf.resize(k_size);
@@ -19518,8 +19880,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
         const auto & hparams = ctx->model.hparams;
 
         const uint32_t n_layer      = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
         size_t   kv_buf_size;
         uint32_t kv_head;
@@ -19551,6 +19911,9 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
             GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
 
             for (int il = 0; il < (int) n_layer; ++il) {
+                const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
                 const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
 
                 ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
@@ -19713,8 +20076,6 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     const auto & hparams = ctx->model.hparams;
 
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     for (uint32_t i = 0; i < kv_self.size; ++i) {
         const auto & cell = kv_self.cells[i];
@@ -19725,6 +20086,9 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     }
 
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
         // types of keys and values
         s_cell_data_size += sizeof(int32_t) * 2;
         // k_size_row and v_size_el values of layer
@@ -19799,14 +20163,15 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
 
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     // Write the layer count
     data_ctx.write(&n_layer, sizeof(n_layer));
 
-    // Write n_embd_v_gqa
-    data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+    // Write n_embd_v_gqa (reference value)
+    {
+        const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
+        data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+    }
 
     // Iterate the ranges and write all the pos (this is the token position in the prompt)
     for (const auto & range : cell_ranges) {
@@ -19820,6 +20185,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // Get whole range at a time
     std::vector<uint8_t> tmp_buf;
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Write key type
         const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
         data_ctx.write(&k_type_i, sizeof(k_type_i));
@@ -19840,6 +20207,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19860,6 +20229,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
         // For the values, they are transposed, so we also need the element size and get the element ranges from each row
         const uint32_t kv_size = kv_self.size;
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19928,14 +20299,14 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // Sanity check model compatibility
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+
     if (n_layer != n_layer_ref) {
         LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
         return 0;
     }
-    if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
+
+    if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
+        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
         return 0;
     }
 
@@ -19975,6 +20346,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Read type of key
         int32_t k_type_i_ref;
         memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
@@ -20007,6 +20380,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
@@ -20038,6 +20413,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     } else {
         // For each layer, read the values for each cell (transposed)
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Read type of value
             int32_t v_type_i_ref;
             memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));