]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
[Model] Add support for xverse (#6301)
authorhxer7963 <redacted>
Fri, 29 Mar 2024 13:37:03 +0000 (21:37 +0800)
committerGitHub <redacted>
Fri, 29 Mar 2024 13:37:03 +0000 (14:37 +0100)
* Support xverse model convert to gguf format.

* 1. Convert xverse models to gguf;
2. Add LLM_ARCH_XVERSE inference in llama.cpp;
3. Add xverse item in Supported models in README.md;

* * gguf-py: remove redundant logs
* llama: remove the init_mapping_prefetch custom parameter

* llama.cpp: Include the changes from #6122 to exclude the unused outputs of the last layers.

* - Fix format issues
- Remove duplicate set kqv_out to llm_build_kv

* Update llama.cpp

---------

Co-authored-by: willhe <redacted>
Co-authored-by: willhe <redacted>
README.md
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
llama.cpp

index 74d9d3faed9024711514958520836c574a78cc0e..42925c2144de00bdeb756bddbbeb48b9c0a9b72d 100644 (file)
--- a/README.md
+++ b/README.md
@@ -115,6 +115,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [CodeShell](https://github.com/WisdomShell/codeshell)
 - [x] [Gemma](https://ai.google.dev/gemma)
 - [x] [Mamba](https://github.com/state-spaces/mamba)
+- [x] [Xverse](https://huggingface.co/models?search=xverse)
 - [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
 
 **Multimodal models:**
index 6a2ce187c1994aec807ea7dfeb16b3b19348dea9..18337839ab72df4600db6158476dc1fc2c8d6e4f 100755 (executable)
@@ -773,6 +773,148 @@ class BaichuanModel(Model):
         return weights[r * n_part:r * n_part + r, ...]
 
 
+@Model.register("XverseForCausalLM")
+class XverseModel(Model):
+    model_arch = gguf.MODEL_ARCH.XVERSE
+
+    def set_vocab(self):
+        assert (self.dir_model / "tokenizer.json").is_file()
+        dir_model = self.dir_model
+        hparams = self.hparams
+
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model)
+        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
+        assert max(tokenizer.vocab.values()) < vocab_size
+
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
+        added_vocab = tokenizer.get_added_vocab()
+
+        for token_id in range(vocab_size):
+            token_text = reverse_vocab[token_id].encode('utf-8')
+            # replace "\x00" to string with length > 0
+            if token_text == b"\x00":
+                toktype = gguf.TokenType.BYTE  # special
+                token_text = f"<{token_text}>".encode('utf-8')
+            elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
+                toktype = gguf.TokenType.BYTE  # special
+            elif reverse_vocab[token_id] in added_vocab:
+                if tokenizer.added_tokens_decoder[token_id].special:
+                    toktype = gguf.TokenType.CONTROL
+                else:
+                    toktype = gguf.TokenType.USER_DEFINED
+            else:
+                toktype = gguf.TokenType.NORMAL
+
+            tokens.append(token_text)
+            toktypes.append(toktype)
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        head_count = self.hparams["num_attention_heads"]
+        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
+        hf_repo = self.hparams.get("_name_or_path", "")
+
+        ctx_length = 0
+        if "max_sequence_length" in self.hparams:
+            ctx_length = self.hparams["max_sequence_length"]
+        elif "max_position_embeddings" in self.hparams:
+            ctx_length = self.hparams["max_position_embeddings"]
+        elif "model_max_length" in self.hparams:
+            ctx_length = self.hparams["model_max_length"]
+        else:
+            print("gguf: can not find ctx length parameter.")
+            sys.exit()
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_source_hf_repo(hf_repo)
+        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
+        self.gguf_writer.add_context_length(ctx_length)
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count(head_count)
+        self.gguf_writer.add_head_count_kv(head_count_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+
+    def write_tensors(self):
+        # Collect tensors from generator object
+        model_kv = dict(self.get_tensors())
+        block_count = self.hparams["num_hidden_layers"]
+        head_count = self.hparams["num_attention_heads"]
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
+
+        for name, data_torch in model_kv.items():
+            # we don't need these
+            if name.endswith(".rotary_emb.inv_freq"):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            # HF models permute some of the tensors, so we need to undo that
+            if name.endswith(("q_proj.weight")):
+                data_torch = self._reverse_hf_permute(data_torch, head_count, head_count)
+            if name.endswith(("k_proj.weight")):
+                data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            self.gguf_writer.add_tensor(new_name, data)
+
+    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
+        if n_kv_head is not None and n_head != n_kv_head:
+            n_head //= n_kv_head
+
+        return (
+            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+            .swapaxes(1, 2)
+            .reshape(weights.shape)
+        )
+
+
 @Model.register("FalconForCausalLM", "RWForCausalLM")
 class FalconModel(Model):
     model_arch = gguf.MODEL_ARCH.FALCON
index 4ab026482a19efa858c51f753ec5cb38bac1dca1..27eaf723cd85fdb9fb282b6b8a12f8379fffe15f 100644 (file)
@@ -123,6 +123,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA      = auto()
     STARCODER2 = auto()
     MAMBA      = auto()
+    XVERSE     = auto()
     COMMAND_R  = auto()
 
 
@@ -191,6 +192,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.GEMMA:          "gemma",
     MODEL_ARCH.STARCODER2:     "starcoder2",
     MODEL_ARCH.MAMBA:          "mamba",
+    MODEL_ARCH.XVERSE:         "xverse",
     MODEL_ARCH.COMMAND_R:      "command-r",
 }
 
@@ -606,6 +608,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.SSM_D,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.XVERSE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.COMMAND_R: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -650,6 +668,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.XVERSE: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }
 
 #
index 1875e247168411249d407d63af2ec4b78746fb77..97408ba1621e36275b92882f96b1b51e6f2630ba 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -218,6 +218,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
+    LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_UNKNOWN,
 };
@@ -249,6 +250,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA,           "gemma"      },
     { LLM_ARCH_STARCODER2,      "starcoder2" },
     { LLM_ARCH_MAMBA,           "mamba"      },
+    { LLM_ARCH_XVERSE,          "xverse"     },
     { LLM_ARCH_COMMAND_R,       "command-r"  },
     { LLM_ARCH_UNKNOWN,         "(unknown)"  },
 };
@@ -878,6 +880,25 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_XVERSE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_COMMAND_R,
         {
@@ -3847,6 +3868,16 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_XVERSE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_13B; break;
+                    case 80: model.type = e_model::MODEL_65B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_COMMAND_R:
             {
                 ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
@@ -5200,6 +5231,28 @@ static bool llm_load_tensors(
                         layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
                     }
                 } break;
+            case LLM_ARCH_XVERSE:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+                        auto & layer = model.layers[i];
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
             case LLM_ARCH_COMMAND_R:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -5238,7 +5291,7 @@ static bool llm_load_tensors(
 
     ml.done_getting_tensors();
 
-    ml.init_mappings(true, &model.mlock_mmaps);
+    ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
     model.mappings.reserve(ml.mappings.size());
 
     // create the backend buffers
@@ -6411,6 +6464,111 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_xverse() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,      cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_falcon() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -9389,6 +9547,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_mamba();
             } break;
+        case LLM_ARCH_XVERSE:
+            {
+                result = llm.build_xverse();
+            } break;
         case LLM_ARCH_COMMAND_R:
             {
                 result = llm.build_command_r();
@@ -14188,6 +14350,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_ORION:
         case LLM_ARCH_INTERNLM2:
         case LLM_ARCH_MINICPM:
+        case LLM_ARCH_XVERSE:
         case LLM_ARCH_COMMAND_R:
             return LLAMA_ROPE_TYPE_NORM;