git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
feature : support Baichuan series models (#3009)
author jameswu2014 <redacted>
Thu, 14 Sep 2023 16:32:10 +0000 (00:32 +0800)
committer GitHub <redacted>
Thu, 14 Sep 2023 16:32:10 +0000 (12:32 -0400)
convert-baichuan-hf-to-gguf.py [new file with mode: 0755]
gguf-py/gguf/gguf.py
llama.cpp
prompts/chat-with-baichuan.txt [new file with mode: 0644]

diff --git a/convert-baichuan-hf-to-gguf.py b/convert-baichuan-hf-to-gguf.py
new file mode 100755 (executable)
index 0000000..5b301de
--- /dev/null
@@ -0,0 +1,292 @@
+#!/usr/bin/env python3
+# HF baichuan --> gguf conversion
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import struct
+import sys
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+import itertools
+import gguf
+import numpy as np
+import torch
+from sentencepiece import SentencePieceProcessor  # type: ignore[import]
+
+
+if TYPE_CHECKING:
+    from typing import TypeAlias
+
+NDArray: TypeAlias = 'np.ndarray[Any, Any]'
+
+# reverse HF permute back to original pth layout
+
+
+def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: int | None = None) -> NDArray:
+    if n_kv_head is not None and n_head != n_kv_head:
+        n_head //= n_kv_head
+
+    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+            .swapaxes(1, 2)
+            .reshape(weights.shape))
+
+def reverse_hf_permute_part(weights: NDArray, n_part: int, n_head: int, n_head_kv: int | None = None) -> NDArray:
+    r = weights.shape[0] // 3
+    return reverse_hf_permute(weights[r * n_part : r * n_part + r, ...], n_head, n_head_kv)
+
+def reverse_hf_part(weights: NDArray, n_part: int) -> NDArray:
+    r = weights.shape[0] // 3
+    return weights[r * n_part : r * n_part + r, ...]
+
+def count_model_parts(dir_model: str) -> int:
+    num_parts = 0
+
+    for filename in os.listdir(dir_model):
+        if filename.startswith("pytorch_model-"):
+            num_parts += 1
+
+    if num_parts > 0:
+        print("gguf: found " + str(num_parts) + " model parts")
+
+    return num_parts
+
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a HuggingFace Baichuan model to a GGML compatible file")
+    parser.add_argument("--vocab-only",  action="store_true",    help="extract only the vocab")
+    parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
+    parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype",     type=int, choices=[0, 1],   help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
+    sys.exit(1)
+
+# possible tensor data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
+
+print("gguf: loading model "+dir_model.name)
+
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+    hparams = json.load(f)
+print("hello print: ",hparams["architectures"][0])
+if hparams["architectures"][0] != "BaichuanForCausalLM":
+    print("Model architecture not supported: " + hparams["architectures"][0])
+
+    sys.exit()
+
+# get number of model parts
+num_parts = count_model_parts(dir_model)
+print(f"num_parts:{num_parts}\n")
+ARCH=gguf.MODEL_ARCH.BAICHUAN
+gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
+
+print("gguf: get model metadata")
+
+block_count = hparams["num_hidden_layers"]
+head_count = hparams["num_attention_heads"]
+
+if "num_key_value_heads" in hparams:
+    head_count_kv = hparams["num_key_value_heads"]
+else:
+    head_count_kv = head_count
+
+if "_name_or_path" in hparams:
+    hf_repo = hparams["_name_or_path"]
+else:
+    hf_repo = ""
+
+if "max_sequence_length" in hparams:
+    ctx_length = hparams["max_sequence_length"]
+elif "max_position_embeddings" in hparams:
+    ctx_length = hparams["max_position_embeddings"]
+elif "model_max_length" in hparams:
+    ctx_length = hparams["model_max_length"]
+else:
+    print("gguf: can not find ctx length parameter.")
+
+    sys.exit()
+
+
+gguf_writer.add_name(dir_model.name)
+gguf_writer.add_source_hf_repo(hf_repo)
+gguf_writer.add_tensor_data_layout("Meta AI original pth")
+gguf_writer.add_context_length(ctx_length)
+gguf_writer.add_embedding_length(hparams["hidden_size"])
+gguf_writer.add_block_count(block_count)
+gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
+gguf_writer.add_head_count(head_count)
+gguf_writer.add_head_count_kv(head_count_kv)
+gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
+
+if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
+    if "type" in hparams["rope_scaling"]:
+        if hparams["rope_scaling"]["type"] == "linear":
+            gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
+
+
+# TOKENIZATION
+
+print("gguf: get tokenizer metadata")
+
+tokens: list[bytes] = []
+scores: list[float] = []
+toktypes: list[int] = []
+
+tokenizer_model_file = dir_model / 'tokenizer.model'
+if not tokenizer_model_file.is_file():
+    print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
+    sys.exit(1)
+
+# vocab type sentencepiece
+print("gguf: get sentencepiece tokenizer vocab, scores and token types")
+
+tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
+
+for i in range(tokenizer.vocab_size()):
+    text: bytes
+    score: float
+
+    piece = tokenizer.id_to_piece(i)
+    text = piece.encode("utf-8")
+    score = tokenizer.get_score(i)
+
+    toktype = 1  # default to normal token type
+    if tokenizer.is_unknown(i):
+        toktype = 2
+    if tokenizer.is_control(i):
+        toktype = 3
+
+    # toktype = 4 is user-defined = tokens from added_tokens.json
+
+    if tokenizer.is_unused(i):
+        toktype = 5
+    if tokenizer.is_byte(i):
+        toktype = 6
+
+    tokens.append(text)
+    scores.append(score)
+    toktypes.append(toktype)
+
+added_tokens_file = dir_model / 'added_tokens.json'
+if added_tokens_file.is_file():
+    with open(added_tokens_file, "r", encoding="utf-8") as f:
+        addtokens_json = json.load(f)
+
+        print("gguf: get added tokens")
+
+        for key in addtokens_json:
+            tokens.append( key.encode("utf-8") )
+            scores.append(-1000.0)
+            toktypes.append(4) # user-defined token type
+
+
+gguf_writer.add_tokenizer_model("llama")
+gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
+
+special_vocab = gguf.SpecialVocab(dir_model)
+special_vocab.add_to_gguf(gguf_writer)
+
+# TENSORS
+
+tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
+
+# tensor info
+print("gguf: get tensor metadata")
+
+if num_parts == 0:
+    part_names = iter(("pytorch_model.bin",))
+else:
+    part_names = (
+        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
+    )
+
+
+for part_name in part_names:
+    if args.vocab_only:
+        break
+    print("gguf: loading model part '" + part_name + "'")
+    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+
+    tmp = model_part
+    for i in range(block_count):
+        if f"model.layers.{i}.self_attn.W_pack.weight" in model_part:
+            print(f"Unpacking and permuting layer {i}")
+            tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"], 0, head_count, head_count)
+            tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"], 1, head_count, head_count_kv)
+            tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = reverse_hf_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
+            del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
+
+    for name in model_part.keys():
+        data = model_part[name]
+        # we don't need these
+        if name.endswith(".rotary_emb.inv_freq"):
+            continue
+
+        old_dtype = data.dtype
+
+        # convert any unsupported data types to float32
+        if data.dtype != torch.float16 and data.dtype != torch.float32:
+            data = data.to(torch.float32)
+
+        data = data.squeeze().numpy()
+
+        # map tensor names
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
+            print("Can not map tensor '" + name + "'")
+            sys.exit()
+
+        n_dims = len(data.shape)
+        data_dtype = data.dtype
+
+        # if f32 desired, convert any float16 to float32
+        if ftype == 0 and data_dtype == np.float16:
+            data = data.astype(np.float32)
+
+        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+            data = data.astype(np.float32)
+
+        # if f16 desired, convert any float32 2-dim weight tensors to float16
+        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+            data = data.astype(np.float16)
+
+        print(name + " -> " +  new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        gguf_writer.add_tensor(new_name, data)
+
+
+print("gguf: write header")
+gguf_writer.write_header_to_file()
+print("gguf: write metadata")
+gguf_writer.write_kv_data_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
+
+gguf_writer.close()
+
+print(f"gguf: model successfully exported to '{fname_out}'")
+print("")
diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index d377cd56d88e792ddc61a7b5ded1e1c14af065cc..bda13ac0067c2aae4eda32ff6fc28e8a3c1395d2 100644 (file)
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -79,6 +79,7 @@ KEY_TOKENIZER_RWKV       = "tokenizer.rwkv.world"
 class MODEL_ARCH(IntEnum):
     LLAMA  : int = auto()
     FALCON : int = auto()
+    BAICHUAN: int = auto()
     GPT2   : int = auto()
     GPTJ   : int = auto()
     GPTNEOX: int = auto()
@@ -108,6 +109,7 @@ class MODEL_TENSOR(IntEnum):
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LLAMA:   "llama",
     MODEL_ARCH.FALCON:  "falcon",
+    MODEL_ARCH.BAICHUAN: "baichuan",
     MODEL_ARCH.GPT2:    "gpt2",
     MODEL_ARCH.GPTJ:    "gptj",
     MODEL_ARCH.GPTNEOX: "gptneox",
@@ -153,6 +155,22 @@ MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
         MODEL_TENSOR.FFN_DOWN:    "blk.{bid}.ffn_down",
         MODEL_TENSOR.FFN_UP:      "blk.{bid}.ffn_up",
     },
+    MODEL_ARCH.BAICHUAN: {
+        MODEL_TENSOR.TOKEN_EMBD:    "token_embd",
+        MODEL_TENSOR.OUTPUT_NORM:   "output_norm",
+        MODEL_TENSOR.OUTPUT:        "output",
+        MODEL_TENSOR.ROPE_FREQS:    "rope_freqs",
+        MODEL_TENSOR.ATTN_NORM:     "blk.{bid}.attn_norm",
+        MODEL_TENSOR.ATTN_Q:        "blk.{bid}.attn_q",
+        MODEL_TENSOR.ATTN_K:        "blk.{bid}.attn_k",
+        MODEL_TENSOR.ATTN_V:        "blk.{bid}.attn_v",
+        MODEL_TENSOR.ATTN_OUT:      "blk.{bid}.attn_output",
+        MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+        MODEL_TENSOR.FFN_NORM:      "blk.{bid}.ffn_norm",
+        MODEL_TENSOR.FFN_GATE:      "blk.{bid}.ffn_gate",
+        MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
+        MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
+    },
     MODEL_ARCH.GPT2: {
         # TODO
     },
@@ -165,6 +183,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.BAICHUAN: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }
 
 
@@ -187,7 +209,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out", # gptneox
-            "lm_head",   # gpt2 mpt falcon llama-hf
+            "lm_head",   # gpt2 mpt falcon llama-hf baichuan
             "output",    # llama-pth
         ),
 
@@ -195,7 +217,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
             "transformer.ln_f",          # gpt2 falcon
-            "model.norm",                # llama-hf
+            "model.norm",                # llama-hf baichuan
             "norm",                      # llama-pth
         ),
 
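The new MODEL_ARCH.BAICHUAN entries above plug into the existing TensorNameMap, so the converter can translate HF checkpoint names into GGUF names. A small lookup sketch, assuming the gguf-py package from this tree is on the import path (the result is shown as a comment and is what the mapping is expected to produce):

import gguf

tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.BAICHUAN, 32)  # 32 blocks, i.e. the 7B layout
new_name = tensor_map.get_name("model.layers.0.self_attn.q_proj.weight", try_suffixes=(".weight", ".bias"))
# expected to resolve to "blk.0.attn_q.weight"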
diff --git a/llama.cpp b/llama.cpp
index cbaf8edac0ba88134e583a3a50c7af28a776d0da..146605d44fa1a4d546ddabf0686118799706cba2 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -155,6 +155,7 @@ static std::string format(const char * fmt, ...) {
 enum llm_arch {
     LLM_ARCH_LLAMA,
     LLM_ARCH_FALCON,
+    LLM_ARCH_BAICHUAN,
     LLM_ARCH_GPT2,
     LLM_ARCH_GPTJ,
     LLM_ARCH_GPTNEOX,
@@ -169,6 +170,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
     { LLM_ARCH_GPTJ,    "gptj"    },
     { LLM_ARCH_GPTNEOX, "gptneox" },
     { LLM_ARCH_MPT,     "mpt"     },
+    { LLM_ARCH_BAICHUAN,"baichuan" },
 };
 
 enum llm_kv {
@@ -309,6 +311,25 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_BAICHUAN,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_FALCON,
         {
@@ -1683,6 +1704,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_BAICHUAN:
+            {
+                GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_13B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     };
 
@@ -1923,7 +1953,6 @@ static void llm_load_tensors(
         const int64_t n_vocab    = hparams.n_vocab;
 
         const auto tn = LLM_TN(model.arch);
-
         switch (model.arch) {
             case LLM_ARCH_LLAMA:
                 {
@@ -1966,6 +1995,72 @@ static void llm_load_tensors(
 
                     model.layers.resize(n_layer);
 
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+
+                        layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd},     backend_split);
+                        layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, backend_split);
+                        layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, backend_split);
+                        layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},     backend_split);
+
+                        layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
+
+                        layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, backend_split);
+                        layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
+                        layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)       +
+                                ggml_nbytes(layer.wv)        + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
+                                ggml_nbytes(layer.w1)        + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_BAICHUAN:
+                {
+                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    {
+                        ggml_backend backend_norm;
+                        ggml_backend backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+                            // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+                            backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#else
+                            backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
+                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output      = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                        }
+                        if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                            vram_weights += ggml_nbytes(model.output);
+                        }
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
+
+                    model.layers.resize(n_layer);
+
                     for (uint32_t i = 0; i < n_layer; ++i) {
                         const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
                         const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
@@ -2542,6 +2637,367 @@ static struct ggml_cgraph * llm_build_llama(
     return gf;
 }
 
+
+static struct ggml_cgraph * llm_build_baichaun(
+         llama_context & lctx,
+     const llama_token * tokens,
+           const float * embd,
+                   int   n_tokens,
+                   int   n_past) {
+
+    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
+
+    const int N = n_tokens;
+
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+
+    const auto & kv_self = lctx.kv_self;
+
+    GGML_ASSERT(!!kv_self.ctx);
+
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_head_kv   = hparams.n_head_kv;
+    const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    const float freq_base    = hparams.rope_freq_base;
+    const float freq_scale   = hparams.rope_freq_scale;
+    const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+    const int n_gpu_layers = model.n_gpu_layers;
+
+    auto & buf_compute = lctx.buf_compute;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute.size,
+        /*.mem_buffer =*/ buf_compute.data,
+        /*.no_alloc   =*/ false,
+    };
+
+    params.no_alloc = true;
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * cur;
+    struct ggml_tensor * inpL;
+
+    if (tokens) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+
+        ggml_allocr_alloc(lctx.alloc, inp_tokens);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+        }
+        ggml_set_name(inp_tokens, "inp_tokens");
+
+        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+    } else {
+#ifdef GGML_USE_MPI
+        GGML_ASSERT(false && "not implemented");
+#endif
+
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+
+        ggml_allocr_alloc(lctx.alloc, inpL);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+        }
+    }
+
+    const int i_gpu_start = n_layer - n_gpu_layers;
+    (void) i_gpu_start;
+
+    // offload functions set the tensor output backend to GPU
+    // tensors are GPU-accelerated if any input or the output has been offloaded
+    //
+    // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
+    // in that case ggml_cuda_assign_buffers has no effect
+    offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+    offload_func_t offload_func_kq = llama_nop;
+    offload_func_t offload_func_v  = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+    if (n_gpu_layers > n_layer) {
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 1) {
+        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 2) {
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+    }
+#endif // GGML_USE_CUBLAS
+
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_allocr_alloc(lctx.alloc, KQ_scale);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_format_name(inpL, "layer_inp_%d", il);
+
+        offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+        if (il >= i_gpu_start) {
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
+        }
+#endif // GGML_USE_CUBLAS
+
+        struct ggml_tensor * inpSA = inpL;
+
+        // norm
+        {
+            cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
+            offload_func(cur);
+            ggml_set_name(cur, "rms_norm_0");
+
+            // cur = cur*attn_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
+            offload_func(cur);
+            ggml_set_name(cur, "attention_norm_0");
+        }
+
+        // self-attention
+        {
+            // compute Q and K and RoPE them
+            struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+            offload_func_kq(tmpk);
+            ggml_set_name(tmpk, "tmpk");
+
+            struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            offload_func_kq(tmpq);
+            ggml_set_name(tmpq, "tmpq");
+
+            struct ggml_tensor * Kcur;
+            struct ggml_tensor * Qcur;
+            switch (model.type) {
+                case MODEL_7B:
+                    Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
+                    Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N),    n_past, n_embd_head, 0, 0, freq_base, freq_scale);
+                    break;
+                case MODEL_13B:
+                    Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N);
+                    Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N);
+                    break;
+                default:
+                    GGML_ASSERT(false);
+            }
+
+            offload_func_kq(Kcur);
+            ggml_set_name(Kcur, "Kcur");
+
+            offload_func_kq(Qcur);
+            ggml_set_name(Qcur, "Qcur");
+
+            // store key and value to memory
+            {
+                // compute the transposed [N, n_embd] V matrix
+
+                struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                offload_func_v(tmpv);
+                ggml_set_name(tmpv, "tmpv");
+
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
+                offload_func_v(Vcur);
+                ggml_set_name(Vcur, "Vcur");
+
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                offload_func_kq(k);
+                ggml_set_name(k, "k");
+
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
+                offload_func_v(v);
+                ggml_set_name(v, "v");
+
+                // important: storing RoPE-ed version of K in the KV cache!
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            offload_func_kq(Q);
+            ggml_set_name(Q, "Q");
+
+            struct ggml_tensor * K =
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_past + N, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            offload_func_kq(K);
+            ggml_set_name(K, "K");
+
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            offload_func_kq(KQ);
+            ggml_set_name(KQ, "KQ");
+
+            // KQ_scaled = KQ / sqrt(n_embd_head)
+            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            offload_func_kq(KQ_scaled);
+            ggml_set_name(KQ_scaled, "KQ_scaled");
+
+            struct ggml_tensor * KQ_masked;
+            struct ggml_tensor * KQ_scaled_alibi;
+
+            switch (model.type) {
+                case MODEL_7B:
+                    KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+                    break;
+                case MODEL_13B:
+                    KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
+                    ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
+                    KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
+                    break;
+                default:
+                    GGML_ASSERT(false);
+            }
+            // KQ_masked = mask_past(KQ_scaled)
+            // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+            // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
+            // offload_func_kq(KQ_masked);
+            // ggml_set_name(KQ_masked, "KQ_masked");
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            offload_func_v(KQ_soft_max);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+            // split cached V into n_head heads
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+            offload_func_v(V);
+            ggml_set_name(V, "V");
+
+#if 1
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            offload_func_v(KQV);
+            ggml_set_name(KQV, "KQV");
+#else
+            // make V contiguous in memory to speed up the matmul, however we waste time on the copy
+            // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
+            // is there a better way?
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            offload_func_v(KQV_merged);
+            ggml_set_name(KQV_merged, "KQV_merged");
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            cur = ggml_cpy(ctx0,
+                    KQV_merged,
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            offload_func_v(cur);
+            ggml_set_name(cur, "KQV_merged_contiguous");
+
+            // projection (no bias)
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].wo,
+                    cur);
+            offload_func(cur);
+            ggml_set_name(cur, "result_wo");
+        }
+
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+        offload_func(inpFF);
+        ggml_set_name(inpFF, "inpFF");
+
+        // feed-forward network
+        {
+            // norm
+            {
+                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
+                offload_func(cur);
+                ggml_set_name(cur, "rms_norm_1");
+
+                // cur = cur*ffn_norm(broadcasted)
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+                offload_func(cur);
+                ggml_set_name(cur, "ffn_norm");
+            }
+
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model.layers[il].w3,
+                    cur);
+            offload_func(tmp);
+            ggml_set_name(tmp, "result_w3");
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w1,
+                    cur);
+            offload_func(cur);
+            ggml_set_name(cur, "result_w1");
+
+            // SILU activation
+            cur = ggml_silu(ctx0, cur);
+            offload_func(cur);
+            ggml_set_name(cur, "silu");
+
+            cur = ggml_mul(ctx0, cur, tmp);
+            offload_func(cur);
+            ggml_set_name(cur, "silu_x_result_w3");
+
+            cur = ggml_mul_mat(ctx0,
+                    model.layers[il].w2,
+                    cur);
+            offload_func(cur);
+            ggml_set_name(cur, "result_w2");
+        }
+
+        cur = ggml_add(ctx0, cur, inpFF);
+        offload_func(cur);
+        ggml_set_name(cur, "inpFF_+_result_w2");
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    // norm
+    {
+        cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
+        offload_func_nr(cur);
+        ggml_set_name(cur, "rms_norm_2");
+
+        // cur = cur*norm(broadcasted)
+        cur = ggml_mul(ctx0, cur, model.output_norm);
+        // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
+        ggml_set_name(cur, "result_norm");
+    }
+
+    // lm_head
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    ggml_set_name(cur, "result_output");
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
 static struct ggml_cgraph * llm_build_falcon(
          llama_context & lctx,
      const llama_token * tokens,
@@ -2864,6 +3320,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past);
             } break;
+        case LLM_ARCH_BAICHUAN:
+            {
+                result = llm_build_baichaun(lctx, tokens, embd, n_tokens, n_past);
+            } break;
         case LLM_ARCH_FALCON:
             {
                 result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
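One design note on llm_build_baichaun above: the builder switches on model.type because the two Baichuan variants encode position differently. The 7B model applies RoPE to Q and K (the ggml_rope_custom_inplace path), while the 13B model leaves Q/K unrotated and instead adds ALiBi biases to the attention scores (the ggml_alibi call with a maximum bias of 8). Schematically, the 13B attention weights are softmax_j(q_i·k_j/√d_head − m_h·(i−j)) for j ≤ i, with a per-head slope m_h that decreases geometrically across heads.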
diff --git a/prompts/chat-with-baichuan.txt b/prompts/chat-with-baichuan.txt
new file mode 100644 (file)
index 0000000..11626b6
--- /dev/null
@@ -0,0 +1,4 @@
+以下内容为人类用户与一位智能助手的对话。
+
+用户:你好!
+助手:
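With these changes, a converted Baichuan GGUF file runs through the usual llama.cpp tooling; for example (standard options only, nothing Baichuan-specific): ./main -m ggml-model-f16.gguf -f prompts/chat-with-baichuan.txt -i -r "用户:" starts an interactive chat using the prompt file added above.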