]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Implement the OLMo architecture (#6741)
authornopperl <redacted>
Fri, 19 Apr 2024 09:35:54 +0000 (09:35 +0000)
committerGitHub <redacted>
Fri, 19 Apr 2024 09:35:54 +0000 (11:35 +0200)
* implement olmo architecture

* remove unused variable

* remove unused moe branch

* remove check for weight

* remove superfluous moe, bias and rope tensors

* clarified comment

* fix clamp_kqv setting

* remove obsolete parameter name filter

README.md
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
llama.cpp

index e3eae60bb015b7b9f6db4a3a677d27e65ace8a16..08661831526fa057668a54fc587b34dd658e3e07 100644 (file)
--- a/README.md
+++ b/README.md
@@ -122,6 +122,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
 - [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
 - [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B)
+- [x] [OLMo](https://allenai.org/olmo)
 
 (instructions for supporting more models: [HOWTO-add-model.md](./docs/HOWTO-add-model.md))
 
index c14186abbc2a6cb6df585e6cf1bd2c195a960397..358dba8ed9d9010340263f273dac7e6c77d62218 100755 (executable)
@@ -2636,6 +2636,66 @@ class CommandR2Model(Model):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
 
 
+@Model.register("OlmoForCausalLM")
+@Model.register("OLMoForCausalLM")
+class OlmoModel(Model):
+    model_arch = gguf.MODEL_ARCH.OLMO
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_layer_norm_eps(1e-5)
+        if "clip_qkv" in self.hparams is not None:
+            self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"])
+
+    # Same as super class, but permuting q_proj, k_proj
+    # Copied from: LlamaModel
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        n_head = self.hparams.get("num_attention_heads")
+        n_kv_head = self.hparams.get("num_key_value_heads")
+        for name, data_torch in self.get_tensors():
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.numpy()
+
+            if name.endswith("q_proj.weight"):
+                data = permute(data, n_head, n_head)
+            if name.endswith("k_proj.weight"):
+                data = permute(data, n_head, n_kv_head)
+
+            data = data.squeeze()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # 1d tensors need to be converted to float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######
 
 
index feae03e1091732f9f67b51c00a1c85983c541e6d..ba24065a891c71830c9e1750e6f3342ef0450a4a 100644 (file)
@@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum):
     XVERSE     = auto()
     COMMAND_R  = auto()
     DBRX       = auto()
+    OLMO       = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -210,6 +211,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.XVERSE:         "xverse",
     MODEL_ARCH.COMMAND_R:      "command-r",
     MODEL_ARCH.DBRX:           "dbrx",
+    MODEL_ARCH.OLMO:           "olmo",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -695,6 +697,17 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.OLMO: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
index 18e473c095cd5729931568e8d5279ccd0fb04593..fa7c022f291304a2c1c817700c361c37d90fbcf6 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -222,6 +222,7 @@ enum llm_arch {
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_DBRX,
+    LLM_ARCH_OLMO,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -256,6 +257,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_XVERSE,          "xverse"     },
     { LLM_ARCH_COMMAND_R,       "command-r"  },
     { LLM_ARCH_DBRX,            "dbrx"       },
+    { LLM_ARCH_OLMO,            "olmo"       },
     { LLM_ARCH_UNKNOWN,         "(unknown)"  },
 };
 
@@ -990,6 +992,20 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_OLMO,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -4070,6 +4086,18 @@ static void llm_load_hparams(
                 default: model.type = e_model::MODEL_UNKNOWN;
             }
         } break;
+        case LLM_ARCH_OLMO:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,     hparams.f_clamp_kqv, false);
+
+                switch (hparams.n_layer) {
+                    case 22: model.type = e_model::MODEL_1B; break;
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 80: model.type = e_model::MODEL_70B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -5666,6 +5694,37 @@ static bool llm_load_tensors(
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
+            case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                            ml.n_created--; // artificial tensor
+                            ml.size_data += ggml_nbytes(model.output);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
@@ -10096,6 +10155,139 @@ struct llm_build_context {
         return gf;
 
     }
+
+    // ref: https://allenai.org/olmo
+    // based on the original build_llama() function, changes:
+    //   * non-parametric layer norm
+    //   * clamp qkv
+    //   * removed bias
+    //   * removed MoE
+    struct ggml_cgraph * build_olmo() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    NULL, NULL,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, nullptr,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    NULL, NULL,
+                    LLM_NORM, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, cur,
+                    model.layers[il].ffn_up,   NULL,
+                    model.layers[il].ffn_gate, NULL,
+                    model.layers[il].ffn_down, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                NULL, NULL,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -10301,6 +10493,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_dbrx();
             } break;
+        case LLM_ARCH_OLMO:
+            {
+                result = llm.build_olmo();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -15154,6 +15350,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_XVERSE:
         case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_OLMO:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2