--- /dev/null
+#!/usr/bin/env python3
+# HF baichuan --> gguf conversion
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import struct
+import sys
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+import itertools
+import gguf
+import numpy as np
+import torch
+from sentencepiece import SentencePieceProcessor # type: ignore[import]
+
+
+if TYPE_CHECKING:
+ from typing import TypeAlias
+
+NDArray: TypeAlias = 'np.ndarray[Any, Any]'
+
+# reverse HF permute back to original pth layout
+
+
+def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: int | None = None) -> NDArray:
+ if n_kv_head is not None and n_head != n_kv_head:
+ n_head //= n_kv_head
+
+ return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+ .swapaxes(1, 2)
+ .reshape(weights.shape))
+
+def reverse_hf_permute_part(weights: NDArray, n_part: int, n_head: int, n_head_kv: int| None = None) -> NDArray:
+ r = weights.shape[0] // 3
+ return (reverse_hf_permute(weights[r * n_part : r * n_part + r, ...], n_head, n_head_kv))
+
+def reverse_hf_part(weights: NDArray, n_part: int) -> NDArray:
+ r = weights.shape[0] // 3
+ return weights[r * n_part : r * n_part + r, ...]
+
+def count_model_parts(dir_model: str) -> int:
+ num_parts = 0
+
+ for filename in os.listdir(dir_model):
+ if filename.startswith("pytorch_model-"):
+ num_parts += 1
+
+ if num_parts > 0:
+ print("gguf: found " + str(num_parts) + " model parts")
+
+ return num_parts
+
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Convert a HuggingFace LLaMA model to a GGML compatible file")
+ parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+ parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+ parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+ parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+ return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+ print(f'Error: {args.model} is not a directory', file = sys.stderr)
+ sys.exit(1)
+
+# possible tensor data types
+# ftype == 0 -> float32
+# ftype == 1 -> float16
+
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+if args.outfile is not None:
+ fname_out = args.outfile
+else:
+ # output in the same directory as the model by default
+ fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
+
+print("gguf: loading model "+dir_model.name)
+
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+ hparams = json.load(f)
+print("hello print: ",hparams["architectures"][0])
+if hparams["architectures"][0] != "BaichuanForCausalLM":
+ print("Model architecture not supported: " + hparams["architectures"][0])
+
+ sys.exit()
+
+# get number of model parts
+num_parts = count_model_parts(dir_model)
+print(f"num_parts:{num_parts}\n")
+ARCH=gguf.MODEL_ARCH.BAICHUAN
+gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
+
+print("gguf: get model metadata")
+
+block_count = hparams["num_hidden_layers"]
+head_count = hparams["num_attention_heads"]
+
+if "num_key_value_heads" in hparams:
+ head_count_kv = hparams["num_key_value_heads"]
+else:
+ head_count_kv = head_count
+
+if "_name_or_path" in hparams:
+ hf_repo = hparams["_name_or_path"]
+else:
+ hf_repo = ""
+
+if "max_sequence_length" in hparams:
+ ctx_length = hparams["max_sequence_length"]
+elif "max_position_embeddings" in hparams:
+ ctx_length = hparams["max_position_embeddings"]
+elif "model_max_length" in hparams:
+ ctx_length = hparams["model_max_length"]
+else:
+ print("gguf: can not find ctx length parameter.")
+
+ sys.exit()
+
+
+gguf_writer.add_name(dir_model.name)
+gguf_writer.add_source_hf_repo(hf_repo)
+gguf_writer.add_tensor_data_layout("Meta AI original pth")
+gguf_writer.add_context_length(ctx_length)
+gguf_writer.add_embedding_length(hparams["hidden_size"])
+gguf_writer.add_block_count(block_count)
+gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
+gguf_writer.add_head_count(head_count)
+gguf_writer.add_head_count_kv(head_count_kv)
+gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
+
+if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
+ if "type" in hparams["rope_scaling"]:
+ if hparams["rope_scaling"]["type"] == "linear":
+ gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
+
+
+# TOKENIZATION
+
+print("gguf: get tokenizer metadata")
+
+tokens: list[bytes] = []
+scores: list[float] = []
+toktypes: list[int] = []
+
+tokenizer_model_file = dir_model / 'tokenizer.model'
+if not tokenizer_model_file.is_file():
+ print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
+ sys.exit(1)
+
+# vocab type sentencepiece
+print("gguf: get sentencepiece tokenizer vocab, scores and token types")
+
+tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
+
+for i in range(tokenizer.vocab_size()):
+ text: bytes
+ score: float
+
+ piece = tokenizer.id_to_piece(i)
+ text = piece.encode("utf-8")
+ score = tokenizer.get_score(i)
+
+ toktype = 1 # defualt to normal token type
+ if tokenizer.is_unknown(i):
+ toktype = 2
+ if tokenizer.is_control(i):
+ toktype = 3
+
+ # toktype = 4 is user-defined = tokens from added_tokens.json
+
+ if tokenizer.is_unused(i):
+ toktype = 5
+ if tokenizer.is_byte(i):
+ toktype = 6
+
+ tokens.append(text)
+ scores.append(score)
+ toktypes.append(toktype)
+
+added_tokens_file = dir_model / 'added_tokens.json'
+if added_tokens_file.is_file():
+ with open(added_tokens_file, "r", encoding="utf-8") as f:
+ addtokens_json = json.load(f)
+
+ print("gguf: get added tokens")
+
+ for key in addtokens_json:
+ tokens.append( key.encode("utf-8") )
+ scores.append(-1000.0)
+ toktypes.append(4) # user-defined token type
+
+
+gguf_writer.add_tokenizer_model("llama")
+gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
+
+special_vocab = gguf.SpecialVocab(dir_model)
+special_vocab.add_to_gguf(gguf_writer)
+
+# TENSORS
+
+tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
+
+# tensor info
+print("gguf: get tensor metadata")
+
+if num_parts == 0:
+ part_names = iter(("pytorch_model.bin",))
+else:
+ part_names = (
+ f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
+ )
+
+
+for part_name in part_names:
+ if args.vocab_only:
+ break
+ print("gguf: loading model part '" + part_name + "'")
+ model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+
+ tmp=model_part
+ for i in range(block_count):
+ if f"model.layers.{i}.self_attn.W_pack.weight" in model_part:
+ print(f"Unpacking and permuting layer {i}")
+ tmp[f"model.layers.{i}.self_attn.q_proj.weight"]=reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],0,head_count,head_count)
+ tmp[f"model.layers.{i}.self_attn.k_proj.weight"]=reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],1,head_count,head_count_kv)
+ tmp[f"model.layers.{i}.self_attn.v_proj.weight"]=reverse_hf_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],2)
+ del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
+
+ for name in model_part.keys():
+ data = model_part[name]
+ # we don't need these
+ if name.endswith(".rotary_emb.inv_freq"):
+ continue
+
+ old_dtype = data.dtype
+
+ # convert any unsupported data types to float32
+ if data.dtype != torch.float16 and data.dtype != torch.float32:
+ data = data.to(torch.float32)
+
+ data = data.squeeze().numpy()
+
+ # map tensor names
+ new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+ if new_name is None:
+ print("Can not map tensor '" + name + "'")
+ sys.exit()
+
+ n_dims = len(data.shape)
+ data_dtype = data.dtype
+
+ # if f32 desired, convert any float16 to float32
+ if ftype == 0 and data_dtype == np.float16:
+ data = data.astype(np.float32)
+
+ # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+ if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+ data = data.astype(np.float32)
+
+ # if f16 desired, convert any float32 2-dim weight tensors to float16
+ if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+ data = data.astype(np.float16)
+
+ print(name + " -> " + new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+ gguf_writer.add_tensor(new_name, data)
+
+
+print("gguf: write header")
+gguf_writer.write_header_to_file()
+print("gguf: write metadata")
+gguf_writer.write_kv_data_to_file()
+if not args.vocab_only:
+ print("gguf: write tensors")
+ gguf_writer.write_tensors_to_file()
+
+gguf_writer.close()
+
+print(f"gguf: model successfully exported to '{fname_out}'")
+print("")
enum llm_arch {
LLM_ARCH_LLAMA,
LLM_ARCH_FALCON,
+ LLM_ARCH_BAICHUAN,
LLM_ARCH_GPT2,
LLM_ARCH_GPTJ,
LLM_ARCH_GPTNEOX,
{ LLM_ARCH_GPTJ, "gptj" },
{ LLM_ARCH_GPTNEOX, "gptneox" },
{ LLM_ARCH_MPT, "mpt" },
+ { LLM_ARCH_BAICHUAN,"baichuan" },
};
enum llm_kv {
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_BAICHUAN,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_FALCON,
{
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_BAICHUAN:
+ {
+ GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+ switch (hparams.n_layer) {
+ case 32: model.type = e_model::MODEL_7B; break;
+ case 40: model.type = e_model::MODEL_13B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
};
const int64_t n_vocab = hparams.n_vocab;
const auto tn = LLM_TN(model.arch);
-
switch (model.arch) {
case LLM_ARCH_LLAMA:
{
model.layers.resize(n_layer);
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+ const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+
+ layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split);
+ layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split);
+ layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
+ layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+
+ layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
+
+ layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
+ layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
+ layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+
+ if (backend == GGML_BACKEND_GPU) {
+ vram_weights +=
+ ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
+ ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
+ ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+ }
+ }
+ } break;
+ case LLM_ARCH_BAICHUAN:
+ {
+ model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+ {
+ ggml_backend backend_norm;
+ ggml_backend backend_output;
+
+ if (n_gpu_layers > int(n_layer)) {
+ // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+ // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+ backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#else
+ backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
+ backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+ } else {
+ backend_norm = GGML_BACKEND_CPU;
+ backend_output = GGML_BACKEND_CPU;
+ }
+
+ model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
+ model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
+
+ if (backend_norm == GGML_BACKEND_GPU) {
+ vram_weights += ggml_nbytes(model.output_norm);
+ }
+ if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+ vram_weights += ggml_nbytes(model.output);
+ }
+ }
+
+ const uint32_t n_ff = hparams.n_ff;
+
+ const int i_gpu_start = n_layer - n_gpu_layers;
+
+ model.layers.resize(n_layer);
+
for (uint32_t i = 0; i < n_layer; ++i) {
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
return gf;
}
+
+static struct ggml_cgraph * llm_build_baichaun(
+ llama_context & lctx,
+ const llama_token * tokens,
+ const float * embd,
+ int n_tokens,
+ int n_past) {
+
+ GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
+
+ const int N = n_tokens;
+
+ const auto & model = lctx.model;
+ const auto & hparams = model.hparams;
+
+ const auto & kv_self = lctx.kv_self;
+
+ GGML_ASSERT(!!kv_self.ctx);
+
+ const int64_t n_embd = hparams.n_embd;
+ const int64_t n_layer = hparams.n_layer;
+ const int64_t n_ctx = hparams.n_ctx;
+ const int64_t n_head = hparams.n_head;
+ const int64_t n_head_kv = hparams.n_head_kv;
+ const int64_t n_embd_head = hparams.n_embd_head();
+ const int64_t n_embd_gqa = hparams.n_embd_gqa();
+
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ const float freq_base = hparams.rope_freq_base;
+ const float freq_scale = hparams.rope_freq_scale;
+ const float norm_rms_eps = hparams.f_norm_rms_eps;
+
+ const int n_gpu_layers = model.n_gpu_layers;
+
+ auto & buf_compute = lctx.buf_compute;
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ buf_compute.size,
+ /*.mem_buffer =*/ buf_compute.data,
+ /*.no_alloc =*/ false,
+ };
+
+ params.no_alloc = true;
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ if (tokens) {
+ struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+
+ ggml_allocr_alloc(lctx.alloc, inp_tokens);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+ }
+ ggml_set_name(inp_tokens, "inp_tokens");
+
+ inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+ } else {
+#ifdef GGML_USE_MPI
+ GGML_ASSERT(false && "not implemented");
+#endif
+
+ inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+
+ ggml_allocr_alloc(lctx.alloc, inpL);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+ }
+ }
+
+ const int i_gpu_start = n_layer - n_gpu_layers;
+ (void) i_gpu_start;
+
+ // offload functions set the tensor output backend to GPU
+ // tensors are GPU-accelerated if any input or the output has been offloaded
+ //
+ // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
+ // in that case ggml_cuda_assign_buffers has no effect
+ offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+ offload_func_t offload_func_kq = llama_nop;
+ offload_func_t offload_func_v = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+ if (n_gpu_layers > n_layer) {
+ offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+ }
+ if (n_gpu_layers > n_layer + 1) {
+ offload_func_v = ggml_cuda_assign_buffers_no_alloc;
+ }
+ if (n_gpu_layers > n_layer + 2) {
+ offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+ }
+#endif // GGML_USE_CUBLAS
+
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ ggml_allocr_alloc(lctx.alloc, KQ_scale);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+ }
+ ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_format_name(inpL, "layer_inp_%d", il);
+
+ offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+ if (il >= i_gpu_start) {
+ offload_func = ggml_cuda_assign_buffers_no_alloc;
+ }
+#endif // GGML_USE_CUBLAS
+
+ struct ggml_tensor * inpSA = inpL;
+
+ // norm
+ {
+ cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
+ offload_func(cur);
+ ggml_set_name(cur, "rms_norm_0");
+
+ // cur = cur*attn_norm(broadcasted)
+ cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
+ offload_func(cur);
+ ggml_set_name(cur, "attention_norm_0");
+ }
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ offload_func_kq(tmpk);
+ ggml_set_name(tmpk, "tmpk");
+
+ struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ offload_func_kq(tmpq);
+ ggml_set_name(tmpq, "tmpq");
+
+ struct ggml_tensor * Kcur;
+ struct ggml_tensor * Qcur;
+ switch (model.type) {
+ case MODEL_7B:
+ Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
+ Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
+ break;
+ case MODEL_13B:
+ Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N);
+ Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N);
+ break;
+ default:
+ GGML_ASSERT(false);
+ }
+
+ offload_func_kq(Kcur);
+ ggml_set_name(Kcur, "Kcur");
+
+ offload_func_kq(Qcur);
+ ggml_set_name(Qcur, "Qcur");
+
+ // store key and value to memory
+ {
+ // compute the transposed [N, n_embd] V matrix
+
+ struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ offload_func_v(tmpv);
+ ggml_set_name(tmpv, "tmpv");
+
+ struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
+ offload_func_v(Vcur);
+ ggml_set_name(Vcur, "Vcur");
+
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+ offload_func_kq(k);
+ ggml_set_name(k, "k");
+
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+ ( n_ctx)*ggml_element_size(kv_self.v),
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
+ offload_func_v(v);
+ ggml_set_name(v, "v");
+
+ // important: storing RoPE-ed version of K in the KV cache!
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+ }
+
+ struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+ offload_func_kq(Q);
+ ggml_set_name(Q, "Q");
+
+ struct ggml_tensor * K =
+ ggml_view_3d(ctx0, kv_self.k,
+ n_embd_head, n_past + N, n_head_kv,
+ ggml_element_size(kv_self.k)*n_embd_gqa,
+ ggml_element_size(kv_self.k)*n_embd_head,
+ ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+ offload_func_kq(K);
+ ggml_set_name(K, "K");
+
+ // K * Q
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ offload_func_kq(KQ);
+ ggml_set_name(KQ, "KQ");
+
+ // KQ_scaled = KQ / sqrt(n_embd_head)
+ // KQ_scaled shape [n_past + N, N, n_head, 1]
+ struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+ offload_func_kq(KQ_scaled);
+ ggml_set_name(KQ_scaled, "KQ_scaled");
+
+ struct ggml_tensor * KQ_masked;
+ struct ggml_tensor * KQ_scaled_alibi;
+
+ switch (model.type) {
+ case MODEL_7B:
+ KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+ break;
+ case MODEL_13B:
+ KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
+ ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
+ KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
+ break;
+ default:
+ GGML_ASSERT(false);
+ }
+ // KQ_masked = mask_past(KQ_scaled)
+ // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+ // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
+ // offload_func_kq(KQ_masked);
+ // ggml_set_name(KQ_masked, "KQ_masked");
+
+ // KQ = soft_max(KQ_masked)
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+ offload_func_v(KQ_soft_max);
+ ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+ // split cached V into n_head heads
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_past + N, n_embd_head, n_head_kv,
+ ggml_element_size(kv_self.v)*n_ctx,
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+ offload_func_v(V);
+ ggml_set_name(V, "V");
+
+#if 1
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ offload_func_v(KQV);
+ ggml_set_name(KQV, "KQV");
+#else
+ // make V contiguous in memory to speed up the matmul, however we waste time on the copy
+ // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
+ // is there a better way?
+ struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
+#endif
+
+ // KQV_merged = KQV.permute(0, 2, 1, 3)
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ offload_func_v(KQV_merged);
+ ggml_set_name(KQV_merged, "KQV_merged");
+
+ // cur = KQV_merged.contiguous().view(n_embd, N)
+ cur = ggml_cpy(ctx0,
+ KQV_merged,
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+ offload_func_v(cur);
+ ggml_set_name(cur, "KQV_merged_contiguous");
+
+ // projection (no bias)
+ cur = ggml_mul_mat(ctx0,
+ model.layers[il].wo,
+ cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_wo");
+ }
+
+ struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+ offload_func(inpFF);
+ ggml_set_name(inpFF, "inpFF");
+
+ // feed-forward network
+ {
+ // norm
+ {
+ cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
+ offload_func(cur);
+ ggml_set_name(cur, "rms_norm_1");
+
+ // cur = cur*ffn_norm(broadcasted)
+ cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+ offload_func(cur);
+ ggml_set_name(cur, "ffn_norm");
+ }
+
+ struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+ model.layers[il].w3,
+ cur);
+ offload_func(tmp);
+ ggml_set_name(tmp, "result_w3");
+
+ cur = ggml_mul_mat(ctx0,
+ model.layers[il].w1,
+ cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_w1");
+
+ // SILU activation
+ cur = ggml_silu(ctx0, cur);
+ offload_func(cur);
+ ggml_set_name(cur, "silu");
+
+ cur = ggml_mul(ctx0, cur, tmp);
+ offload_func(cur);
+ ggml_set_name(cur, "silu_x_result_w3");
+
+ cur = ggml_mul_mat(ctx0,
+ model.layers[il].w2,
+ cur);
+ offload_func(cur);
+ ggml_set_name(cur, "result_w2");
+ }
+
+ cur = ggml_add(ctx0, cur, inpFF);
+ offload_func(cur);
+ ggml_set_name(cur, "inpFF_+_result_w2");
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ // norm
+ {
+ cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
+ offload_func_nr(cur);
+ ggml_set_name(cur, "rms_norm_2");
+
+ // cur = cur*norm(broadcasted)
+ cur = ggml_mul(ctx0, cur, model.output_norm);
+ // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
+ ggml_set_name(cur, "result_norm");
+ }
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ ggml_set_name(cur, "result_output");
+
+ ggml_build_forward_expand(gf, cur);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
+
static struct ggml_cgraph * llm_build_falcon(
llama_context & lctx,
const llama_token * tokens,
{
result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past);
} break;
+ case LLM_ARCH_BAICHUAN:
+ {
+ result = llm_build_baichaun(lctx, tokens, embd, n_tokens, n_past);
+ } break;
case LLM_ARCH_FALCON:
{
result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);