llama : add support for the cohere2 model architecture (#10900)
author     DAN™ <redacted>
           Sat, 4 Jan 2025 14:33:31 +0000 (09:33 -0500)
committer  GitHub <redacted>
           Sat, 4 Jan 2025 14:33:31 +0000 (16:33 +0200)
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama.cpp

index 4e6c0f60c0621a230c5d65ba12636e6a3ff2c1c0..d4441bbe9d3435e9d4856c9513059e4706c0f4a6 100755 (executable)
@@ -3373,6 +3373,24 @@ class CommandR2Model(Model):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
 
 
+@Model.register("Cohere2ForCausalLM")
+class Cohere2Model(Model):
+    model_arch = gguf.MODEL_ARCH.COHERE2
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+
+        rotary_pct = self.hparams["rotary_pct"]
+        hidden_size = self.hparams["hidden_size"]
+        num_attention_heads = self.hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads)))
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+
+
 @Model.register("OlmoForCausalLM")
 @Model.register("OLMoForCausalLM")
 class OlmoModel(Model):
index 273370370e6ca15891de82bdd6aa188f2ed093a2..cdf79673b5d9a218ddb5a7e19e2f0c7e16d002da 100644 (file)
@@ -255,6 +255,7 @@ class MODEL_ARCH(IntEnum):
     MAMBA            = auto()
     XVERSE           = auto()
     COMMAND_R        = auto()
+    COHERE2          = auto()
     DBRX             = auto()
     OLMO             = auto()
     OLMO2            = auto()
@@ -437,6 +438,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.MAMBA:            "mamba",
     MODEL_ARCH.XVERSE:           "xverse",
     MODEL_ARCH.COMMAND_R:        "command-r",
+    MODEL_ARCH.COHERE2:          "cohere2",
     MODEL_ARCH.DBRX:             "dbrx",
     MODEL_ARCH.OLMO:             "olmo",
     MODEL_ARCH.OLMO2:            "olmo2",
@@ -1136,6 +1138,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_K_NORM,
         MODEL_TENSOR.ATTN_Q_NORM,
     ],
+    MODEL_ARCH.COHERE2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.DBRX: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index a6003838559a7fba9c51a0c22f75796d2536e746..fea4b21d3be939ad35d98184557d4514559d8858 100644 (file)
@@ -39,6 +39,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_XVERSE,           "xverse"           },
     { LLM_ARCH_COMMAND_R,        "command-r"        },
+    { LLM_ARCH_COHERE2,          "cohere2"          },
     { LLM_ARCH_DBRX,             "dbrx"             },
     { LLM_ARCH_OLMO,             "olmo"             },
     { LLM_ARCH_OLMO2,            "olmo2"            },
@@ -807,6 +808,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
         },
     },
+    {
+        LLM_ARCH_COHERE2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_DBRX,
         {
index 446e72eebf6d6432b534df2a7a2d327a47ca9955..10bd619a4b9877f06e4e022bef9d66f042d689f7 100644 (file)
@@ -43,6 +43,7 @@ enum llm_arch {
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
+    LLM_ARCH_COHERE2,
     LLM_ARCH_DBRX,
     LLM_ARCH_OLMO,
     LLM_ARCH_OLMO2,
index ace0ba2629f680ff148534ab062af992de6d9c4b..c356abded6c19dd96073519dc07481b40747db37 100644 (file)
@@ -786,6 +786,16 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_COHERE2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_8B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DBRX:
         {
             ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2031,6 +2041,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_XVERSE:
         case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_COHERE2:
         case LLM_ARCH_OLMO:
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK:
index d7110b90bcce0d44e879dbc53d6368f7871e1f53..50e9191fa574ef3600e7696c74dec182682c36f6 100644 (file)
@@ -1552,6 +1552,32 @@ static bool llm_load_tensors(
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_COHERE2:
+                {
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    // init output from the input tok embed
+                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
+                                                      llama_model_loader::TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+                    }
+                }
+                break;
             case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
                 {
                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -7633,6 +7659,137 @@ struct llm_build_context {
 
     }
 
+    struct ggml_cgraph * build_cohere2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        // cohere2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
+
+        // sliding window switch pattern
+        const int32_t sliding_window_pattern = 4;
+
+        for (int il = 0; il < n_layer; ++il) {
+            // three layers sliding window attention (window size 4096) and ROPE
+            // fourth layer uses global attention without positional embeddings
+            const bool           is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+            struct ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // rope freq factors for 128k context
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                if (is_sliding) {
+                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+                                        beta_fast, beta_slow);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                                        rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                                        attn_factor, beta_fast, beta_slow);
+                    cb(Kcur, "Kcur", il);
+                } else {
+                    // For non-sliding layers, just reshape without applying RoPE
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
+                                   KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur                              = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL                             = ggml_get_rows(ctx0, inpL, inp_out_ids);
+                ffn_inp                          = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
+            struct ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
+                                    NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
+                                    cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // ref: https://allenai.org/olmo
     // based on the original build_llama() function, changes:
     //   * non-parametric layer norm
@@ -10384,6 +10541,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_command_r();
             } break;
+        case LLM_ARCH_COHERE2:
+            {
+                result = llm.build_cohere2();
+            } break;
         case LLM_ARCH_DBRX:
             {
                 result = llm.build_dbrx();
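
A minimal Python sketch of the per-layer attention pattern that the new build_cohere2() above implements (assumptions: the 32-layer count is taken from the MODEL_8B mapping added in llm_load_hparams(); the helper name is illustrative and not part of the commit):

    # Cohere2 interleaves attention types: within every group of 4 layers,
    # the first 3 use sliding-window attention with RoPE and the 4th uses
    # global attention without positional embeddings.
    SLIDING_WINDOW_PATTERN = 4

    def uses_sliding_window(layer_idx: int) -> bool:
        # mirrors: il % sliding_window_pattern < (sliding_window_pattern - 1)
        return layer_idx % SLIDING_WINDOW_PATTERN < (SLIDING_WINDOW_PATTERN - 1)

    if __name__ == "__main__":
        n_layer = 32  # matches the MODEL_8B case added in llama.cpp
        for il in range(n_layer):
            kind = "sliding-window + RoPE" if uses_sliding_window(il) else "global, no RoPE"
            print(f"layer {il:2d}: {kind}")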