]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add phi3 support (#6852)
authorliuwei-git <redacted>
Wed, 24 Apr 2024 07:00:37 +0000 (15:00 +0800)
committerGitHub <redacted>
Wed, 24 Apr 2024 07:00:37 +0000 (10:00 +0300)
* add explicit phi3 support

* add explicit phi3 support

* remove unused code

* convert : add BOS token

* llama : match EOT token <|end|>

* llama : minor / style

* llama : tabs -> spaces

* convert : fix lint checks

---------

Co-authored-by: Georgi Gerganov <redacted>
convert-hf-to-gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

index 4fd916cba3ed73cb0fbfbedd1bf8bc8e4b952d6b..4ace13eb631492747f077d802a82c08568ace061 100755 (executable)
@@ -1979,6 +1979,91 @@ class Phi2Model(Model):
         self.gguf_writer.add_add_bos_token(False)
 
 
+@Model.register("Phi3ForCausalLM")
+class Phi3MiniModel(Model):
+    model_arch = gguf.MODEL_ARCH.PHI3
+
+    def set_vocab(self):
+        from sentencepiece import SentencePieceProcessor
+
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        if not tokenizer_path.is_file():
+            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
+            sys.exit(1)
+
+        tokenizer = SentencePieceProcessor(str(tokenizer_path))
+
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+
+        for token_id in range(tokenizer.vocab_size()):
+
+            piece = tokenizer.id_to_piece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.get_score(token_id)
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.is_unknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.is_control(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.is_unused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.is_byte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
+
+        added_tokens_file = self.dir_model / 'added_tokens.json'
+        if added_tokens_file.is_file():
+            with open(added_tokens_file, "r", encoding="utf-8") as f:
+                added_tokens_json = json.load(f)
+
+                for key in added_tokens_json:
+                    token_id = added_tokens_json[key]
+                    if (token_id >= vocab_size):
+                        print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                        continue
+
+                    tokens[token_id] = key.encode("utf-8")
+                    scores[token_id] = -1000.0
+                    toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
+
+        rot_pct = 1.0
+        n_embd = self.find_hparam(["hidden_size", "n_embd"])
+        n_head = self.find_hparam(["num_attention_heads", "n_head"])
+        rms_eps = self.find_hparam(["rms_norm_eps"])
+
+        self.gguf_writer.add_name("Phi3")
+        self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
+
+        self.gguf_writer.add_embedding_length(n_embd)
+        self.gguf_writer.add_feed_forward_length(8192)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_head_count_kv(n_head)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
+        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
+        self.gguf_writer.add_file_type(self.ftype)
+
+
 @Model.register("PlamoForCausalLM")
 class PlamoModel(Model):
     model_arch = gguf.MODEL_ARCH.PLAMO
index 06cb26a7d495b71b31d1a48b28b156f3fba45652..d2f1de19831e341107c123a161aee1c10561dcb5 100644 (file)
@@ -124,6 +124,7 @@ class MODEL_ARCH(IntEnum):
     QWEN2      = auto()
     QWEN2MOE   = auto()
     PHI2       = auto()
+    PHI3       = auto()
     PLAMO      = auto()
     CODESHELL  = auto()
     ORION      = auto()
@@ -200,6 +201,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.QWEN2:          "qwen2",
     MODEL_ARCH.QWEN2MOE:       "qwen2moe",
     MODEL_ARCH.PHI2:           "phi2",
+    MODEL_ARCH.PHI3:           "phi3",
     MODEL_ARCH.PLAMO:          "plamo",
     MODEL_ARCH.CODESHELL:      "codeshell",
     MODEL_ARCH.ORION:          "orion",
@@ -550,6 +552,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PHI3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.CODESHELL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.POS_EMBD,
index 10de36fa883ee54b7ecb4d12a968695756cc91e6..e5750d4191f6b1aa686bd573b82e929994254280 100644 (file)
@@ -117,6 +117,7 @@ class TensorNameMap:
             "h.{bid}.attn.c_attn",                                                 # gpt2
             "transformer.h.{bid}.mixer.Wqkv",                                      # phi2
             "encoder.layers.{bid}.attn.Wqkv",                                      # nomic-bert
+            "model.layers.{bid}.self_attn.qkv_proj"                                # phi3
         ),
 
         # Attention query
@@ -234,6 +235,7 @@ class TensorNameMap:
             "h.{bid}.mlp.c_fc",                                       # gpt2
             "transformer.h.{bid}.mlp.fc1",                            # phi2
             "model.layers.{bid}.mlp.fc1",                             # phi2
+            "model.layers.{bid}.mlp.gate_up_proj",                    # phi3
             "model.layers.layers.{bid}.mlp.up_proj",                  # plamo
             "model.layers.{bid}.feed_forward.w3",                     # internlm2
             "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert
index a25d115c1d82af93c3f69df8055b626ab59dcdb9..30fe190373b430498237cf7f784474111d1f6290 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -211,6 +211,7 @@ enum llm_arch {
     LLM_ARCH_QWEN2,
     LLM_ARCH_QWEN2MOE,
     LLM_ARCH_PHI2,
+    LLM_ARCH_PHI3,
     LLM_ARCH_PLAMO,
     LLM_ARCH_CODESHELL,
     LLM_ARCH_ORION,
@@ -246,6 +247,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2,           "qwen2"      },
     { LLM_ARCH_QWEN2MOE,        "qwen2moe"   },
     { LLM_ARCH_PHI2,            "phi2"       },
+    { LLM_ARCH_PHI3,            "phi3"       },
     { LLM_ARCH_PLAMO,           "plamo"      },
     { LLM_ARCH_CODESHELL,       "codeshell"  },
     { LLM_ARCH_ORION,           "orion"      },
@@ -793,6 +795,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PHI3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_PLAMO,
         {
@@ -3955,6 +3974,16 @@ static void llm_load_hparams(
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
+                switch (hparams.n_layer) {
+                    case 24: model.type = e_model::MODEL_1B; break;
+                    case 32: model.type = e_model::MODEL_3B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
+        case LLM_ARCH_PHI3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1B; break;
                     case 32: model.type = e_model::MODEL_3B; break;
@@ -4352,6 +4381,7 @@ static void llm_load_vocab(
                         //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
                         (t.first == "<|eot_id|>" ||
                          t.first == "<|im_end|>" ||
+                         t.first == "<|end|>" ||
                          t.first == "<end_of_turn>"
                         )
                    ) {
@@ -5375,6 +5405,33 @@ static bool llm_load_tensors(
                         layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
                     }
                 } break;
+            case LLM_ARCH_PHI3:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab });
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context* ctx_layer = ctx_for_layer(i);
+                        ggml_context* ctx_split = ctx_for_layer_split(i);
+
+                        auto& layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd });
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false);
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd });
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd });
+
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd });
+                        layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff });
+                    }
+                } break;
             case LLM_ARCH_PLAMO:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -6326,7 +6383,7 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
     cb(kq, "kq", il);
 
-    if (model.arch == LLM_ARCH_PHI2) {
+    if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
         // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
         // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -8967,12 +9024,140 @@ struct llm_build_context {
 
         cur = ggml_add(ctx0, cur, model.output_b);
         cb(cur, "result_output", -1);
+        ggml_build_forward_expand(gf, cur);
+        return gf;
+    }
+
+    struct ggml_cgraph * build_phi3() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            auto residual = inpL;
+
+            // self-attention
+            {
+                struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    NULL,
+                    LLM_NORM_RMS, cb, il);
+                cb(attn_norm_output, "attn_norm", il);
+
+                struct ggml_tensor * Qcur = nullptr;
+                struct ggml_tensor * Kcur = nullptr;
+                struct ggml_tensor * Vcur = nullptr;
+
+                if (model.layers[il].wqkv) {
+                    cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cb(cur, "wqkv", il);
+
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
+                }
+                else {
+                    Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                    model.layers[il].wo, NULL,
+                    Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor* inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+            }
+
+            cur = ggml_add(ctx0, cur, residual);
+            residual = cur;
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            // FF
+            // special-case: the up and gate tensors are merged into a single tensor
+            // TOOD: support into llm_build_ffn
+            {
+                struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
+                cb(up, "ffn_up", il);
+
+                auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0));
+                auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2));
+
+                y = ggml_mul(ctx0, y, ggml_silu(ctx0, g));
+                cb(y, "ffn_gate", il);
+
+                auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y);
+                cb(down, "ffn_down", il);
+
+                cur = down;
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, residual, cur);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+            model.output_norm,
+            NULL,
+            LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
 
         return gf;
     }
 
+
     struct ggml_cgraph * build_plamo() {
         struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
@@ -10474,6 +10659,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_phi2();
             } break;
+        case LLM_ARCH_PHI3:
+            {
+                result = llm.build_phi3();
+            } break;
         case LLM_ARCH_PLAMO:
             {
                 result = llm.build_plamo();
@@ -15393,6 +15582,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_QWEN2:
         case LLM_ARCH_QWEN2MOE:
         case LLM_ARCH_PHI2:
+        case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_STARCODER2:
             return LLAMA_ROPE_TYPE_NEOX;