]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add CodeShell support (#5016)
authorchiranko <redacted>
Fri, 19 Jan 2024 09:07:27 +0000 (17:07 +0800)
committerGitHub <redacted>
Fri, 19 Jan 2024 09:07:27 +0000 (11:07 +0200)
* llama: add codeshell support

* llama.cpp: fix codeshell with NeoX rope

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

index 1178d63a231faef29426c7939a10d3aedb30b414..aae3a5e876f82aa844eb8bfad282b1accdffb20d 100755 (executable)
@@ -197,6 +197,8 @@ class Model:
             return Phi2Model
         if model_architecture == "PlamoForCausalLM":
             return PlamoModel
+        if model_architecture == "CodeShellForCausalLM":
+            return CodeShellModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -242,6 +244,8 @@ class Model:
             return gguf.MODEL_ARCH.PHI2
         if arch == "PlamoForCausalLM":
             return gguf.MODEL_ARCH.PLAMO
+        if arch == "CodeShellForCausalLM":
+            return gguf.MODEL_ARCH.CODESHELL
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1175,6 +1179,69 @@ class PlamoModel(Model):
 
             self.gguf_writer.add_tensor(new_name, data)
 
+class CodeShellModel(Model):
+    def set_gguf_parameters(self):
+        block_count = self.hparams["n_layer"]
+
+        self.gguf_writer.add_name("CodeShell")
+        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
+        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(self.hparams["n_head"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_rope_freq_base(10000.0)
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+        self.gguf_writer.add_rope_scaling_factor(1.0)
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        tensors = dict(self.get_tensors())
+        has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys()
+        for name, data_torch in tensors.items():
+            # we don't need these
+            if name.endswith((".attn.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+            if not has_lm_head and name == "transformer.wte.weight":
+                self.gguf_writer.add_tensor("output.weight", data)
+                print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
 ###### CONVERSION LOGIC ######
 
index 972b4e9a737667e1e9258125fd57f09ea2b19506..95c58b4192a8d2efa4858833af9fdbd0e862cb0a 100644 (file)
@@ -99,6 +99,7 @@ class MODEL_ARCH(IntEnum):
     QWEN      = auto()
     PHI2      = auto()
     PLAMO     = auto()
+    CODESHELL = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -147,6 +148,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.QWEN:           "qwen",
     MODEL_ARCH.PHI2:           "phi2",
     MODEL_ARCH.PLAMO:          "plamo",
+    MODEL_ARCH.CODESHELL:      "codeshell",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -396,6 +398,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+    ],
+    MODEL_ARCH.CODESHELL: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
     ]
     # TODO
 }
@@ -417,6 +432,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.CODESHELL: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }
 
 #
index e5b146106b4adb0bf87836bb93e15002b17fad54..de177af1377144c03d8b3c61b4df6f1f9f7e2df5 100644 (file)
@@ -154,6 +154,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.rotary_emb.inv_freq",        # llama-hf
             "layers.{bid}.attention.inner_attention.rope.freqs",       # llama-pth
             "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
+            "transformer.h.{bid}.attn.rotary_emb.inv_freq",            # codeshell
         ),
 
         # Feed-forward norm
index 47b4384a8b88b98ba2b3a75f489d0723b351c97c..1cee5a7911eac00debc1e9f2b22eafc82cd5dccb 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -194,6 +194,7 @@ enum llm_arch {
     LLM_ARCH_QWEN,
     LLM_ARCH_PHI2,
     LLM_ARCH_PLAMO,
+    LLM_ARCH_CODESHELL,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -213,6 +214,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN,            "qwen"      },
     { LLM_ARCH_PHI2,            "phi2"      },
     { LLM_ARCH_PLAMO,           "plamo"     },
+    { LLM_ARCH_CODESHELL,       "codeshell" },
 };
 
 enum llm_kv {
@@ -600,6 +602,26 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_CODESHELL,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
 
     {
         LLM_ARCH_UNKNOWN,
@@ -2877,6 +2899,14 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_CODESHELL:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                switch (hparams.n_layer) {
+                    case 42: model.type = e_model::MODEL_SMALL; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
 
         default: (void)0;
     }
@@ -3784,6 +3814,42 @@ static bool llm_load_tensors(
                         layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
                     }
                 } break;
+            case LLM_ARCH_CODESHELL:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+
+                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -5965,6 +6031,117 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_codeshell() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(tmpq, "tmpq", il);
+                cb(tmpk, "tmpk", il);
+                cb(Vcur, "Vcur", il);
+
+                struct ggml_tensor * Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+
+            // add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph(
@@ -6159,6 +6336,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gpt2();
             } break;
+        case LLM_ARCH_CODESHELL:
+            {
+                result = llm.build_codeshell();
+            } break;
         default:
             GGML_ASSERT(false);
     }