llama : add Command-R support (#6033)
author     Andrew Canis <redacted>
           Fri, 15 Mar 2024 20:41:22 +0000 (16:41 -0400)
committer  GitHub <redacted>
           Fri, 15 Mar 2024 20:41:22 +0000 (22:41 +0200)
Information about the Command-R 35B model (128k context) can be found at:
https://huggingface.co/CohereForAI/c4ai-command-r-v01

Based on the llama2 model with a few changes:

1) New hyperparameter to scale the output logits (logit_scale)
2) Uses LayerNorm instead of RMSNorm
3) Transformer layers have a single shared LayerNorm that feeds into both the
   self-attention and FFN layers in parallel; there is no post-attention
   LayerNorm (see the sketch after this list)
4) No support for Rotary Position Embeddings (RoPE) scaling
5) No biases used
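
For illustration only, a minimal NumPy sketch of changes 1-3 and 5 (a toy
single-head layer; RoPE is omitted, weight names such as "wq"/"w_gate" are
hypothetical, and logit_scale = 0.0625 is just an example value):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def silu(x):
    return x / (1.0 + np.exp(-x))

def layer_norm(x, gamma, eps=1e-5):
    # LayerNorm (change 2): subtract mean and divide by std, unlike RMSNorm
    mu = x.mean(-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def parallel_block(x, p):
    # one shared LayerNorm feeds BOTH branches (change 3); no biases (change 5)
    h = layer_norm(x, p["attn_norm"])

    q, k, v = h @ p["wq"], h @ p["wk"], h @ p["wv"]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores += np.triu(np.full(scores.shape, -np.inf), k=1)  # causal mask
    attn = softmax(scores) @ v @ p["wo"]

    # SiLU-gated FFN applied to the SAME normalized input, in parallel
    ffn = (silu(h @ p["w_gate"]) * (h @ p["w_up"])) @ p["w_down"]

    # both branches join the residual; no post-attention LayerNorm
    return x + attn + ffn

def final_logits(h, p, logit_scale=0.0625):
    # change 1: logits are scaled; the output head is the tied token embedding
    return logit_scale * (layer_norm(h, p["out_norm"]) @ p["tok_embd"].T)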

Find GGUF files here:
https://huggingface.co/andrewcanis/c4ai-command-r-v01-GGUF

To convert the model to GGUF format yourself:

1) Download Command-R Hugging Face safetensors:
git lfs install
git clone https://huggingface.co/CohereForAI/c4ai-command-r-v01

2) Run:
python3 convert-hf-to-gguf.py --outtype f16 ./c4ai-command-r-v01
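
Optionally, to sanity-check the converted file, a quick sketch using
gguf-py's GGUFReader (this assumes the converter's default output path and
the current gguf-py reader API; both may differ on your setup):

from gguf import GGUFReader

reader = GGUFReader("c4ai-command-r-v01/ggml-model-f16.gguf")
for key in ("command-r.logit_scale", "command-r.context_length"):
    field = reader.fields.get(key)
    # each KV field stores raw parts; the last part holds the scalar value
    print(key, "=", field.parts[-1][0] if field else "missing")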

README.md
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
llama.cpp

index 61bedc3f86309c8a187360030e480d17b3061044..5cbdf7e476b0984e78b4a26d36c9a7addd7b2989 100644 (file)
--- a/README.md
+++ b/README.md
@@ -112,6 +112,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [CodeShell](https://github.com/WisdomShell/codeshell)
 - [x] [Gemma](https://ai.google.dev/gemma)
 - [x] [Mamba](https://github.com/state-spaces/mamba)
+- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
 
 **Multimodal models:**
 
index 5eee320163d297555cbe60aaa8b9e1cc042d7ae2..cf1f98d66063385838ea3d8beaac22d5061f8a1b 100755 (executable)
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1965,6 +1965,23 @@ class MambaModel(Model):
             self.gguf_writer.add_tensor(new_name, data)
 
 
+@Model.register("CohereForCausalLM")
+class CommandR2Model(Model):
+    model_arch = gguf.MODEL_ARCH.COMMAND_R
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # max_position_embeddings = 8192 in config.json but model was actually
+        # trained on 128k context length
+        self.hparams["max_position_embeddings"] = self.hparams["model_max_length"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_logit_scale(self.hparams["logit_scale"])
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+
+
 ###### CONVERSION LOGIC ######
 
 
index 458a641dcd2297e58f4213601f12f8c4cded9637..4a4facb06ea14a0c4ba92f0f173ed6b3bc1604d2 100644 (file)
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -42,6 +42,7 @@ class Keys:
         EXPERT_COUNT          = "{arch}.expert_count"
         EXPERT_USED_COUNT     = "{arch}.expert_used_count"
         POOLING_TYPE          = "{arch}.pooling_type"
+        LOGIT_SCALE           = "{arch}.logit_scale"
 
     class Attention:
         HEAD_COUNT        = "{arch}.attention.head_count"
@@ -121,6 +122,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA      = auto()
     STARCODER2 = auto()
     MAMBA      = auto()
+    COMMAND_R  = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -187,6 +189,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.GEMMA:          "gemma",
     MODEL_ARCH.STARCODER2:     "starcoder2",
     MODEL_ARCH.MAMBA:          "mamba",
+    MODEL_ARCH.COMMAND_R:      "command-r",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -579,6 +582,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_D,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.COMMAND_R: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
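As a side note, the new MODEL_TENSORS entry plugs into gguf-py's existing
tensor-name mapping. A small sketch of how a Hugging Face tensor name
resolves for this arch (Command-R has 40 layers):

import gguf

tmap = gguf.TensorNameMap(gguf.MODEL_ARCH.COMMAND_R, 40)
print(tmap.get_name("model.layers.0.self_attn.q_proj.weight",
                    try_suffixes=(".weight", ".bias")))
# expected: "blk.0.attn_q.weight"
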
index 1967b633ce2619d4814161686ab76feb0b8dc9d7..2ae6c814b52de155a3ac222fce9b2ff1e9263b08 100644 (file)
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -361,6 +361,9 @@ class GGUFWriter:
     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
 
+    def add_logit_scale(self, value: float) -> None:
+        self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
+
     def add_expert_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
 
index 8e185d4bfe7731f04a4da7fea6bf8c09b9a684ef..fc5dd5cb43a7ec4aee2c4cad148cea049f2c3256 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -214,6 +214,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
+    LLM_ARCH_COMMAND_R,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -243,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA,           "gemma"      },
     { LLM_ARCH_STARCODER2,      "starcoder2" },
     { LLM_ARCH_MAMBA,           "mamba"      },
+    { LLM_ARCH_COMMAND_R,       "command-r"  },
     { LLM_ARCH_UNKNOWN,         "(unknown)"  },
 };
 
@@ -268,6 +270,7 @@ enum llm_kv {
     LLM_KV_EXPERT_COUNT,
     LLM_KV_EXPERT_USED_COUNT,
     LLM_KV_POOLING_TYPE,
+    LLM_KV_LOGIT_SCALE,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -332,6 +335,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_COUNT,                  "%s.expert_count"          },
     { LLM_KV_EXPERT_USED_COUNT,             "%s.expert_used_count"     },
     { LLM_KV_POOLING_TYPE ,                 "%s.pooling_type"          },
+    { LLM_KV_LOGIT_SCALE,                   "%s.logit_scale"           },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,          "%s.attention.head_count"             },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,       "%s.attention.head_count_kv"          },
@@ -838,6 +842,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_COMMAND_R,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1597,6 +1616,7 @@ enum e_model {
     MODEL_20B,
     MODEL_30B,
     MODEL_34B,
+    MODEL_35B,
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
@@ -1643,6 +1663,7 @@ struct llama_hparams {
 
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
+    float f_logit_scale    = 0.0f;
 
     bool causal_attn = true;
     bool need_kq_pos = false;
@@ -3231,6 +3252,7 @@ static const char * llama_model_type_name(e_model type) {
         case MODEL_20B:    return "20B";
         case MODEL_30B:    return "30B";
         case MODEL_34B:    return "34B";
+        case MODEL_35B:    return "35B";
         case MODEL_40B:    return "40B";
         case MODEL_65B:    return "65B";
         case MODEL_70B:    return "70B";
@@ -3623,6 +3645,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_COMMAND_R:
+            {
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_35B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -3944,6 +3975,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
+    LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
     LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
@@ -4918,6 +4950,37 @@ static bool llm_load_tensors(
                         layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
                     }
                 } break;
+            case LLM_ARCH_COMMAND_R:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        // init output from the input tok embed
+                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                        ml.n_created--; // artificial tensor
+                        ml.size_data += ggml_nbytes(model.output);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -8315,6 +8378,121 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_command_r() {
+
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+            struct ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            struct ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, ffn_inp,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -8497,6 +8675,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_mamba();
             } break;
+        case LLM_ARCH_COMMAND_R:
+            {
+                result = llm.build_command_r();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -13147,6 +13329,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_ORION:
         case LLM_ARCH_INTERNLM2:
         case LLM_ARCH_MINICPM:
+        case LLM_ARCH_COMMAND_R:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2