]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add StarCoder2 support (#5795)
authorSourab Mangrulkar <redacted>
Fri, 1 Mar 2024 19:30:46 +0000 (01:00 +0530)
committerGitHub <redacted>
Fri, 1 Mar 2024 19:30:46 +0000 (21:30 +0200)
* Add support for starcoder2

* handle rope type

* skip rope freq and rotary embeddings from being serialized

* resolve comments

* Update llama.cpp

* remove redundant changes

* handle `rope-theta`

* llama : change starcoder2 rope type

* address comment

---------

Co-authored-by: Georgi Gerganov <redacted>
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

index d3e8ec1f60c70ea9f8fa68a7597cd09bd33d7cc1..28b92ac385367ef8bcad666f5a6eb70edc6493a7 100755 (executable)
@@ -96,9 +96,11 @@ class Model:
         if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
             self.gguf_writer.add_head_count_kv(n_head_kv)
 
+        if (rope_theta := self.hparams.get("rope_theta")) is not None:
+            self.gguf_writer.add_rope_freq_base(rope_theta)
         if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
-        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon"], optional=True)) is not None:
+        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
             self.gguf_writer.add_layer_norm_eps(f_norm_eps)
         if (n_experts := self.hparams.get("num_local_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
@@ -220,6 +222,8 @@ class Model:
             return NomicBertModel
         if model_architecture == "GemmaForCausalLM":
             return GemmaModel
+        if model_architecture == "Starcoder2ForCausalLM":
+            return Model
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -281,6 +285,8 @@ class Model:
             return gguf.MODEL_ARCH.NOMIC_BERT
         if arch == "GemmaForCausalLM":
             return gguf.MODEL_ARCH.GEMMA
+        if arch == "Starcoder2ForCausalLM":
+            return gguf.MODEL_ARCH.STARCODER2
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
index 8f9139d1b7eca8a93b292c5eace3f4da0926cb45..5db760cb14900ec2297941b750f238d094ed102f 100644 (file)
@@ -112,6 +112,7 @@ class MODEL_ARCH(IntEnum):
     INTERNLM2  = auto()
     MINICPM    = auto()
     GEMMA      = auto()
+    STARCODER2 = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -169,6 +170,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.INTERNLM2:      "internlm2",
     MODEL_ARCH.MINICPM:        "minicpm",
     MODEL_ARCH.GEMMA:          "gemma",
+    MODEL_ARCH.STARCODER2:     "starcoder2",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -526,6 +528,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.FFN_NORM,
     ],
+    MODEL_ARCH.STARCODER2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
@@ -554,6 +571,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.STARCODER2: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }
 
 #
index 8610037767fb62b725b1397864f79b317530c3d7..db2ec9704a441ebe42c8185d12bb58147ea171ee 100644 (file)
@@ -210,6 +210,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.up_proj",                  # plamo
             "model.layers.{bid}.feed_forward.w3",                     # internlm2
             "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert
+            "model.layers.{bid}.mlp.c_fc",                            # starcoder2
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -256,6 +257,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.down_proj",                # plamo
             "model.layers.{bid}.feed_forward.w2",                     # internlm2
             "encoder.layers.{bid}.mlp.fc2",                           # nomic-bert
+            "model.layers.{bid}.mlp.c_proj",                          # starcoder2
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
index 073fd3b703037b39b3228849845ef3c3b0d08bea..b1db5b1797dc501d4f68024b4115a1d77e9181b1 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -211,6 +211,7 @@ enum llm_arch {
     LLM_ARCH_INTERNLM2,
     LLM_ARCH_MINICPM,
     LLM_ARCH_GEMMA,
+    LLM_ARCH_STARCODER2,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -238,6 +239,7 @@ static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_INTERNLM2,       "internlm2"  },
     { LLM_ARCH_MINICPM,         "minicpm"    },
     { LLM_ARCH_GEMMA,           "gemma"      },
+    { LLM_ARCH_STARCODER2,      "starcoder2" },
 };
 
 enum llm_kv {
@@ -779,6 +781,24 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_STARCODER2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -3320,6 +3340,16 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                }
             } break;
+        case LLM_ARCH_STARCODER2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 30: model.type = e_model::MODEL_3B; break;
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_15B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -4490,6 +4520,56 @@ static bool llm_load_tensors(
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                     }
                 } break;
+            case LLM_ARCH_STARCODER2:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                            ml.n_created--; // artificial tensor
+                            ml.size_data += ggml_nbytes(model.output);
+                        }
+
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        // optional bias tensors
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+
+                        // optional bias tensors
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP ,  "bias", i), {  n_ff});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -7559,6 +7639,120 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_starcoder2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+            cb(cur, "ffn_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -7705,6 +7899,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gemma();
             } break;
+        case LLM_ARCH_STARCODER2:
+            {
+                result = llm.build_starcoder2();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -12084,6 +12282,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_GEMMA:
+        case LLM_ARCH_STARCODER2:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here