]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add PLaMo model (#3557)
authorShintarou Okada <redacted>
Sun, 24 Dec 2023 13:35:49 +0000 (22:35 +0900)
committerGitHub <redacted>
Sun, 24 Dec 2023 13:35:49 +0000 (15:35 +0200)
* add plamo mock

* add tensor loading

* plamo convert

* update norm

* able to compile

* fix norm_rms_eps hparam

* runnable

* use inp_pos

* seems ok

* update kqv code

* remove develop code

* update README

* shuffle attn_q.weight and attn_output.weight for broadcasting

* remove plamo_llm_build_kqv and use llm_build_kqv

* fix style

* update

* llama : remove obsolete KQ_scale

* plamo : fix tensor names for correct GPU offload

---------

Co-authored-by: Georgi Gerganov <redacted>
README.md
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

index 649c3b3334387e4cba29ad954beee376afeb33de..09338d2264ca7287a1d183eee1a55d6cb989707b 100644 (file)
--- a/README.md
+++ b/README.md
@@ -102,6 +102,7 @@ as the main playground for developing new features for the [ggml](https://github
 - [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek)
 - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
 - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
+- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
 
 **Multimodal models:**
 
index e71a96c483313e205d9c789f20331ef3876b8858..303d08170ecb04644f699226d42708ec6dabb487 100755 (executable)
@@ -184,6 +184,8 @@ class Model:
             return MixtralModel
         if model_architecture == "PhiForCausalLM":
             return Phi2Model
+        if model_architecture == "PlamoForCausalLM":
+            return PlamoModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -225,6 +227,8 @@ class Model:
             return gguf.MODEL_ARCH.LLAMA
         if arch == "PhiForCausalLM":
             return gguf.MODEL_ARCH.PHI2
+        if arch == "PlamoForCausalLM":
+            return gguf.MODEL_ARCH.PLAMO
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1002,11 +1006,91 @@ class Phi2Model(Model):
         self.gguf_writer.add_add_bos_token(False)
 
 
+class PlamoModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name("PLaMo")
+        self.gguf_writer.add_context_length(4096)  # not in config.json
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(5)  # hparams["num_key_value_heads"]) is wrong
+        self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
+
+    def shuffle_attn_q_weight(self, data_torch):
+        assert data_torch.size() == (5120, 5120)
+        data_torch = data_torch.reshape(8, 5, 128, 5120)
+        data_torch = torch.permute(data_torch, (1, 0, 2, 3))
+        data_torch = torch.reshape(data_torch, (5120, 5120))
+        return data_torch
+
+    def shuffle_attn_output_weight(self, data_torch):
+        assert data_torch.size() == (5120, 5120)
+        data_torch = data_torch.reshape(5120, 8, 5, 128)
+        data_torch = torch.permute(data_torch, (0, 2, 1, 3))
+        data_torch = torch.reshape(data_torch, (5120, 5120))
+        return data_torch
+
+    def write_tensors(self):
+        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            if "self_attn.rotary_emb.inv_freq" in name:
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            # shuffle for broadcasting of gqa in ggml_mul_mat
+            if new_name.endswith("attn_q.weight"):
+                data_torch = self.shuffle_attn_q_weight(data_torch)
+            elif new_name.endswith("attn_output.weight"):
+                data_torch = self.shuffle_attn_output_weight(data_torch)
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######
 
 
 def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
+    parser = argparse.ArgumentParser(
+        description="Convert a huggingface model to a GGML compatible file")
     parser.add_argument(
         "--vocab-only", action="store_true",
         help="extract only the vocab",
index 390dca049ebee5bbb240cf91be75e2c2d2b7118a..4cd87cdda8b7e2947691b76776b9eddd41a784c1 100644 (file)
@@ -96,6 +96,7 @@ class MODEL_ARCH(IntEnum):
     STABLELM  = auto()
     QWEN      = auto()
     PHI2      = auto()
+    PLAMO     = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -142,6 +143,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.STABLELM:       "stablelm",
     MODEL_ARCH.QWEN:           "qwen",
     MODEL_ARCH.PHI2:           "phi2",
+    MODEL_ARCH.PLAMO:          "plamo",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -349,6 +351,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PLAMO: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
index 6fcbdbc1c0d4cd43bb0c1abce8ac2f0196107c56..446c6b6883be9c5b6ef4ba4c21247a08008d5af2 100644 (file)
@@ -79,6 +79,7 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",                               # yi
             "transformer.h.{bid}.ln",                               # phi2
+            "model.layers.layers.{bid}.norm",                       # plamo
         ),
 
         # Attention norm 2
@@ -99,26 +100,29 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",       # llama-hf
-            "layers.{bid}.attention.wq",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.query",  # bert
-            "transformer.h.{bid}.attn.q_proj",           # gpt-j
+            "model.layers.{bid}.self_attn.q_proj",         # llama-hf
+            "layers.{bid}.attention.wq",                   # llama-pth
+            "encoder.layer.{bid}.attention.self.query",    # bert
+            "transformer.h.{bid}.attn.q_proj",             # gpt-j
+            "model.layers.layers.{bid}.self_attn.q_proj",  # plamo
         ),
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",     # llama-hf
-            "layers.{bid}.attention.wk",               # llama-pth
-            "encoder.layer.{bid}.attention.self.key",  # bert
-            "transformer.h.{bid}.attn.k_proj",         # gpt-j
+            "model.layers.{bid}.self_attn.k_proj",         # llama-hf
+            "layers.{bid}.attention.wk",                   # llama-pth
+            "encoder.layer.{bid}.attention.self.key",      # bert
+            "transformer.h.{bid}.attn.k_proj",             # gpt-j
+            "model.layers.layers.{bid}.self_attn.k_proj",  # plamo
         ),
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",       # llama-hf
-            "layers.{bid}.attention.wv",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.value",  # bert
-            "transformer.h.{bid}.attn.v_proj",           # gpt-j
+            "model.layers.{bid}.self_attn.v_proj",         # llama-hf
+            "layers.{bid}.attention.wv",                   # llama-pth
+            "encoder.layer.{bid}.attention.self.value",    # bert
+            "transformer.h.{bid}.attn.v_proj",             # gpt-j
+            "model.layers.layers.{bid}.self_attn.v_proj",  # plamo
         ),
 
         # Attention output
@@ -134,12 +138,14 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.out_proj",                         # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
             "transformer.h.{bid}.mixer.out_proj",                        # phi2
+            "model.layers.layers.{bid}.self_attn.o_proj",                # plamo
         ),
 
         # Rotary embeddings
         MODEL_TENSOR.ATTN_ROT_EMBD: (
-            "model.layers.{bid}.self_attn.rotary_emb.inv_freq",   # llama-hf
-            "layers.{bid}.attention.inner_attention.rope.freqs",  # llama-pth
+            "model.layers.{bid}.self_attn.rotary_emb.inv_freq",        # llama-hf
+            "layers.{bid}.attention.inner_attention.rope.freqs",       # llama-pth
+            "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
         ),
 
         # Feed-forward norm
@@ -174,6 +180,7 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
             "transformer.h.{bid}.mlp.w1",                             # qwen
             "transformer.h.{bid}.mlp.fc1",                            # phi2
+            "model.layers.layers.{bid}.mlp.up_proj",                  # plamo
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -186,6 +193,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.gate_proj",           # llama-hf refact
             "layers.{bid}.feed_forward.w1",               # llama-pth
             "transformer.h.{bid}.mlp.w2",                 # qwen
+            "model.layers.layers.{bid}.mlp.gate_proj",    # plamo
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -206,6 +214,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.fc_out",                         # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
             "transformer.h.{bid}.mlp.fc2",                            # phi2
+            "model.layers.layers.{bid}.mlp.down_proj",                # plamo
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
index a24621539f6bd9d048a7d104afe3ca9054dec6c2..0b99f1e03f5273b320b5265cfcd4ac1ff69380bd 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -198,6 +198,7 @@ enum llm_arch {
     LLM_ARCH_STABLELM,
     LLM_ARCH_QWEN,
     LLM_ARCH_PHI2,
+    LLM_ARCH_PLAMO,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -216,6 +217,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
     { LLM_ARCH_STABLELM,        "stablelm"  },
     { LLM_ARCH_QWEN,            "qwen"      },
     { LLM_ARCH_PHI2,            "phi2"      },
+    { LLM_ARCH_PLAMO,           "plamo"     },
 };
 
 enum llm_kv {
@@ -567,6 +569,24 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PLAMO,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
 
     {
         LLM_ARCH_UNKNOWN,
@@ -2749,6 +2769,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_PLAMO:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_13B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+               }
+            } break;
 
         default: (void)0;
     }
@@ -3630,6 +3659,51 @@ static bool llm_load_tensors(
                         layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i),   {n_ff},         backend);
                     }
                 } break;
+            case LLM_ARCH_PLAMO:
+                {
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend_type backend_norm;
+                        ggml_backend_type backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            backend_norm   = llama_backend_offload;
+                            backend_output = llama_backend_offload_split;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
+
+                    model.layers.resize(n_layer);
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+
+                        layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd},     backend_split);
+                        layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, backend_split);
+                        layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, backend_split);
+                        layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},     backend_split);
+
+                        layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, backend_split);
+                        layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
+                        layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -5555,6 +5629,109 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_plamo() {
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+        cb(inpL, "inp_embd", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        cb(inp_pos, "inp_pos", -1);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+        cb(KQ_mask, "KQ_mask", -1);
+
+        // shift the entire K-cache if needed
+        if (do_rope_shift) {
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+        }
+
+        for (int il = 0; il < n_layer; ++il) {
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            struct ggml_tensor * attention_norm = cur;
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_custom(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                        n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                        n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+                cur = llm_build_kqv(ctx0, model, hparams, kv_self,
+                        model.layers[il].wo, NULL,
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                cb(cur, "kqv_out", il);
+            }
+            struct ggml_tensor * sa_out = cur;
+
+            cur = attention_norm;
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up, NULL,
+                        model.layers[il].ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, sa_out);
+            cb(cur, "l_out", il);
+
+            cur = ggml_add(ctx0, cur, inpL);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 //
@@ -6065,6 +6242,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_phi2();
             } break;
+        case LLM_ARCH_PLAMO:
+            {
+                result = llm.build_plamo();
+            } break;
         default:
             GGML_ASSERT(false);
     }