GPTNEOX : int = auto()
MPT : int = auto()
STARCODER : int = auto()
+ PERSIMMON : int = auto()
REFACT : int = auto()
BERT : int = auto()
FFN_DOWN : int = auto()
FFN_UP : int = auto()
FFN_NORM : int = auto()
+ ATTN_Q_NORM : int = auto()
+ ATTN_K_NORM : int = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.STARCODER: "starcoder",
+ MODEL_ARCH.PERSIMMON: "persimmon",
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
}
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
-
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+ MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
+ MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.PERSIMMON: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
MODEL_ARCH.REFACT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
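+ # persimmon: the per-layer rotary inv_freq tensors are read from the checkpoint but not written to GGUF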
+ MODEL_ARCH.PERSIMMON: [
+ MODEL_TENSOR.ROPE_FREQS,
+ ]
}
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
- "gpt_neox.embed_in", # gptneox
- "transformer.wte", # gpt2 gpt-j mpt refact
- "transformer.word_embeddings", # falcon
- "model.embed_tokens", # llama-hf
- "tok_embeddings", # llama-pth
- "embeddings.word_embeddings", # bert
+ "gpt_neox.embed_in", # gptneox
+ "transformer.wte", # gpt2 gpt-j mpt refact
+ "transformer.word_embeddings", # falcon
+ "model.embed_tokens", # llama-hf
+ "tok_embeddings", # llama-pth
+ "embeddings.word_embeddings", # bert
+ "language_model.embedding.word_embeddings", # persimmon
),
# Token type embeddings
# Output
MODEL_TENSOR.OUTPUT: (
- "embed_out", # gptneox
- "lm_head", # gpt2 gpt-j mpt falcon llama-hf baichuan
- "output", # llama-pth
+ "embed_out", # gptneox
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan
+ "output", # llama-pth
+ "word_embeddings_for_head", # persimmon
),
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
- "gpt_neox.final_layer_norm", # gptneox
- "transformer.ln_f", # gpt2 gpt-j falcon
- "model.norm", # llama-hf baichuan
- "norm", # llama-pth
- "embeddings.LayerNorm", # bert
- "transformer.norm_f", # mpt
- "ln_f", # refact
+ "gpt_neox.final_layer_norm", # gptneox
+ "transformer.ln_f", # gpt2 gpt-j falcon
+ "model.norm", # llama-hf baichuan
+ "norm", # llama-pth
+ "embeddings.LayerNorm", # bert
+ "transformer.norm_f", # mpt
+ "ln_f", # refact
+ "language_model.encoder.final_layernorm", # persimmon
),
# Rope frequencies
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
- "gpt_neox.layers.{bid}.input_layernorm", # gptneox
- "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact
- "transformer.blocks.{bid}.norm_1", # mpt
- "transformer.h.{bid}.input_layernorm", # falcon7b
- "transformer.h.{bid}.ln_mlp", # falcon40b
- "model.layers.{bid}.input_layernorm", # llama-hf
- "layers.{bid}.attention_norm", # llama-pth
- "encoder.layer.{bid}.attention.output.LayerNorm", # bert
+ "gpt_neox.layers.{bid}.input_layernorm", # gptneox
+ "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact
+ "transformer.blocks.{bid}.norm_1", # mpt
+ "transformer.h.{bid}.input_layernorm", # falcon7b
+ "transformer.h.{bid}.ln_mlp", # falcon40b
+ "model.layers.{bid}.input_layernorm", # llama-hf
+ "layers.{bid}.attention_norm", # llama-pth
+ "encoder.layer.{bid}.attention.output.LayerNorm", # bert
+ "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
),
# Attention norm 2
# Attention query-key-value
MODEL_TENSOR.ATTN_QKV: (
- "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
- "transformer.h.{bid}.attn.c_attn", # gpt2
- "transformer.blocks.{bid}.attn.Wqkv", # mpt
- "transformer.h.{bid}.self_attention.query_key_value", # falcon
+ "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
+ "transformer.h.{bid}.attn.c_attn", # gpt2
+ "transformer.blocks.{bid}.attn.Wqkv", # mpt
+ "transformer.h.{bid}.self_attention.query_key_value", # falcon
+ "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
),
# Attention query
# Attention output
MODEL_TENSOR.ATTN_OUT: (
- "gpt_neox.layers.{bid}.attention.dense", # gptneox
- "transformer.h.{bid}.attn.c_proj", # gpt2 refact
- "transformer.blocks.{bid}.attn.out_proj", # mpt
- "transformer.h.{bid}.self_attention.dense", # falcon
- "model.layers.{bid}.self_attn.o_proj", # llama-hf
- "layers.{bid}.attention.wo", # llama-pth
- "encoder.layer.{bid}.attention.output.dense", # bert
- "transformer.h.{bid}.attn.out_proj", # gpt-j
+ "gpt_neox.layers.{bid}.attention.dense", # gptneox
+ "transformer.h.{bid}.attn.c_proj", # gpt2 refact
+ "transformer.blocks.{bid}.attn.out_proj", # mpt
+ "transformer.h.{bid}.self_attention.dense", # falcon
+ "model.layers.{bid}.self_attn.o_proj", # llama-hf
+ "layers.{bid}.attention.wo", # llama-pth
+ "encoder.layer.{bid}.attention.output.dense", # bert
+ "transformer.h.{bid}.attn.out_proj", # gpt-j
+ "language_model.encoder.layers.{bid}.self_attention.dense" # persimmon
),
# Rotary embeddings
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
- "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
- "transformer.h.{bid}.ln_2", # gpt2 refact
- "transformer.blocks.{bid}.norm_2", # mpt
- "model.layers.{bid}.post_attention_layernorm", # llama-hf
- "layers.{bid}.ffn_norm", # llama-pth
- "encoder.layer.{bid}.output.LayerNorm", # bert
+ "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
+ "transformer.h.{bid}.ln_2", # gpt2 refact
+ "transformer.blocks.{bid}.norm_2", # mpt
+ "model.layers.{bid}.post_attention_layernorm", # llama-hf
+ "layers.{bid}.ffn_norm", # llama-pth
+ "encoder.layer.{bid}.output.LayerNorm", # bert
+ "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
),
# Feed-forward up
MODEL_TENSOR.FFN_UP: (
- "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
- "transformer.h.{bid}.mlp.c_fc", # gpt2
- "transformer.blocks.{bid}.ffn.up_proj", # mpt
- "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
- "model.layers.{bid}.mlp.up_proj", # llama-hf refact
- "layers.{bid}.feed_forward.w3", # llama-pth
- "encoder.layer.{bid}.intermediate.dense", # bert
- "transformer.h.{bid}.mlp.fc_in", # gpt-j
+ "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
+ "transformer.h.{bid}.mlp.c_fc", # gpt2
+ "transformer.blocks.{bid}.ffn.up_proj", # mpt
+ "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
+ "model.layers.{bid}.mlp.up_proj", # llama-hf refact
+ "layers.{bid}.feed_forward.w3", # llama-pth
+ "encoder.layer.{bid}.intermediate.dense", # bert
+ "transformer.h.{bid}.mlp.fc_in", # gpt-j
+ "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
),
# Feed-forward gate
# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
- "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
- "transformer.h.{bid}.mlp.c_proj", # gpt2 refact
- "transformer.blocks.{bid}.ffn.down_proj", # mpt
- "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
- "model.layers.{bid}.mlp.down_proj", # llama-hf
- "layers.{bid}.feed_forward.w2", # llama-pth
- "encoder.layer.{bid}.output.dense", # bert
- "transformer.h.{bid}.mlp.fc_out", # gpt-j
+ "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
+ "transformer.h.{bid}.mlp.c_proj", # gpt2 refact
+ "transformer.blocks.{bid}.ffn.down_proj", # mpt
+ "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
+ "model.layers.{bid}.mlp.down_proj", # llama-hf
+ "layers.{bid}.feed_forward.w2", # llama-pth
+ "encoder.layer.{bid}.output.dense", # bert
+ "transformer.h.{bid}.mlp.fc_out", # gpt-j
+ "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
+ ),
+
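+ # persimmon-specific per-layer tensors; e.g. "language_model.encoder.layers.3.self_attention.q_layernorm" maps to "blk.3.attn_q_norm"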
+ MODEL_TENSOR.ATTN_Q_NORM: (
+ "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
),
+
+ MODEL_TENSOR.ATTN_K_NORM: (
+ "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
+ ),
+
+ MODEL_TENSOR.ROPE_FREQS: (
+ "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
+ )
}
mapping: dict[str, tuple[MODEL_TENSOR, str]]
LLM_ARCH_GPTNEOX,
LLM_ARCH_MPT,
LLM_ARCH_STARCODER,
+ LLM_ARCH_PERSIMMON,
LLM_ARCH_REFACT,
LLM_ARCH_UNKNOWN,
};
{ LLM_ARCH_MPT, "mpt" },
{ LLM_ARCH_BAICHUAN, "baichuan" },
{ LLM_ARCH_STARCODER, "starcoder" },
+ { LLM_ARCH_PERSIMMON, "persimmon" },
{ LLM_ARCH_REFACT, "refact" },
};
LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_ATTN_Q_NORM,
+ LLM_TENSOR_ATTN_K_NORM,
};
static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_PERSIMMON,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+ },
+ },
{
LLM_ARCH_MPT,
{
MODEL_1B,
MODEL_3B,
MODEL_7B,
+ MODEL_8B,
MODEL_13B,
MODEL_15B,
MODEL_30B,
struct ggml_tensor * attn_norm_b;
struct ggml_tensor * attn_norm_2;
struct ggml_tensor * attn_norm_2_b;
+ struct ggml_tensor * attn_q_norm;
+ struct ggml_tensor * attn_q_norm_b;
+ struct ggml_tensor * attn_k_norm;
+ struct ggml_tensor * attn_k_norm_b;
// attention
struct ggml_tensor * wq;
case MODEL_1B: return "1B";
case MODEL_3B: return "3B";
case MODEL_7B: return "7B";
+ case MODEL_8B: return "8B";
case MODEL_13B: return "13B";
case MODEL_15B: return "15B";
case MODEL_30B: return "30B";
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_PERSIMMON:
+ {
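+ // persimmon uses standard layernorm rather than RMS norm, so read the non-RMS epsilon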
+ GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+ switch (hparams.n_layer) {
+ case 36: model.type = e_model::MODEL_8B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_REFACT:
{
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
}
}
} break;
+ case LLM_ARCH_PERSIMMON:
+ {
+ model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+ {
+ ggml_backend backend_norm;
+ ggml_backend backend_output;
+
+ if (n_gpu_layers > int(n_layer)) {
+ // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+ // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+ backend_norm = LLAMA_BACKEND_OFFLOAD;
+#else
+ backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
+ backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+ } else {
+ backend_norm = GGML_BACKEND_CPU;
+ backend_output = GGML_BACKEND_CPU;
+ }
+
+ model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
+ model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
+ model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
+
+ if (backend_norm == GGML_BACKEND_GPU) {
+ vram_weights += ggml_nbytes(model.output_norm);
+ vram_weights += ggml_nbytes(model.output_norm_b);
+ }
+ if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+ vram_weights += ggml_nbytes(model.output);
+ }
+ }
+
+ const uint32_t n_ff = hparams.n_ff;
+ const int i_gpu_start = n_layer - n_gpu_layers;
+ model.layers.resize(n_layer);
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+ const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT;
+ auto & layer = model.layers[i];
+ layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+ layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
+ layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+ layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
+ layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+ layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
+ layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
+ layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
+ layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+ layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split);
+ layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
+ layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
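+ // per-head q/k layernorm weights; the size is hard-coded to the head dimension (presumably n_embd_head == 64 for the 8B checkpoint)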
+ layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend);
+ layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend);
+ layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend);
+ layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend);
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
}
static struct ggml_cgraph * llm_build_llama(
- llama_context & lctx,
- const llama_batch & batch) {
+ llama_context & lctx,
+ const llama_batch & batch) {
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ false,
+ /*.no_alloc =*/ true,
};
- params.no_alloc = true;
-
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ false,
+ /*.no_alloc =*/ true,
};
- params.no_alloc = true;
-
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ false,
+ /*.no_alloc =*/ true,
};
- params.no_alloc = true;
-
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ false,
+ /*.no_alloc =*/ true,
};
- params.no_alloc = true;
-
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
- /*.no_alloc =*/ false,
+ /*.no_alloc =*/ true,
};
- params.no_alloc = true;
-
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
return gf;
}
+
+static struct ggml_cgraph * llm_build_persimmon(
+ llama_context & lctx,
+ const llama_batch & batch) {
+ const auto & model = lctx.model;
+ const auto & hparams = model.hparams;
+
+ const auto & kv_self = lctx.kv_self;
+
+ GGML_ASSERT(!!kv_self.ctx);
+
+ const auto & cparams = lctx.cparams;
+ const int64_t n_embd = hparams.n_embd;
+ const int64_t n_layer = hparams.n_layer;
+ const int64_t n_ctx = cparams.n_ctx;
+ const int64_t n_head_kv = hparams.n_head_kv;
+ const int64_t n_head = hparams.n_head;
+ const int64_t n_embd_head = hparams.n_embd_head();
+ const int64_t n_embd_gqa = hparams.n_embd_gqa();
+ const size_t n_rot = n_embd_head / 2;
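+ // persimmon applies rotary embeddings to only the first n_rot (= half) of each head's dimensions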
+
+ const float freq_base = cparams.rope_freq_base;
+ const float freq_scale = cparams.rope_freq_scale;
+ const float norm_eps = hparams.f_norm_eps;
+
+ const int n_gpu_layers = model.n_gpu_layers;
+
+ const int32_t n_tokens = batch.n_tokens;
+ const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
+ const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
+
+ const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
+
+ auto & buf_compute = lctx.buf_compute;
+ struct ggml_init_params params = {
+ /*.mem_size =*/ buf_compute.size,
+ /*.mem_buffer =*/ buf_compute.data,
+ /*.no_alloc =*/ true,
+ };
+
+ struct ggml_context * ctx0 = ggml_init(params);
+
+ ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ if (batch.token) {
+ struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+
+ ggml_allocr_alloc(lctx.alloc, inp_tokens);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
+ }
+ ggml_set_name(inp_tokens, "inp_tokens");
+ inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+ } else {
+ inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+ ggml_allocr_alloc(lctx.alloc, inpL);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
+ }
+ }
+ const int i_gpu_start = n_layer - n_gpu_layers;
+ (void) i_gpu_start;
+ offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+ offload_func_t offload_func_kq = llama_nop;
+ offload_func_t offload_func_v = llama_nop;
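+ // note: GPU offloading is not wired up for this graph; the offload functions are currently no-ops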
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ ggml_allocr_alloc(lctx.alloc, KQ_scale);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
+ }
+ ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ offload_func_kq(KQ_mask);
+ ggml_set_name(KQ_mask, "KQ_mask");
+ ggml_allocr_alloc(lctx.alloc, KQ_mask);
+
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ float * data = (float *) KQ_mask->data;
+ memset(data, 0, ggml_nbytes(KQ_mask));
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ const llama_pos pos = batch.pos[j];
+ const llama_seq_id seq_id = batch.seq_id[j];
+ for (int i = 0; i < n_kv; ++i) {
+ if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+ }
+ }
+ }
+ }
+ }
+
+ struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ offload_func_kq(KQ_pos);
+ ggml_set_name(KQ_pos, "KQ_pos");
+ ggml_allocr_alloc(lctx.alloc, KQ_pos);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ int * data = (int *) KQ_pos->data;
+ for (int i = 0; i < n_tokens; ++i) {
+ data[i] = batch.pos[i];
+ }
+ }
+ if (do_rope_shift) {
+ struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
+ offload_func_kq(K_shift);
+ ggml_set_name(K_shift, "K_shift");
+ ggml_allocr_alloc(lctx.alloc, K_shift);
+ if (!ggml_allocr_is_measure(lctx.alloc)) {
+ int * data = (int *) K_shift->data;
+ for (int i = 0; i < n_ctx; ++i) {
+ data[i] = kv_self.cells[i].delta;
+ }
+ }
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * tmp =
+ // we rotate only the first n_rot dimensions.
+ ggml_rope_custom_inplace(ctx0,
+ ggml_view_3d(ctx0, kv_self.k,
+ n_rot, n_head, n_ctx,
+ ggml_element_size(kv_self.k)*n_embd_gqa,
+ ggml_element_size(kv_self.k)*n_embd_head,
+ ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il)
+ ),
+ K_shift, n_rot, 2, 0, freq_base, freq_scale);
+ offload_func_kq(tmp);
+ ggml_build_forward_expand(gf, tmp);
+ }
+ }
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * residual = inpL;
+ offload_func_t offload_func = llama_nop;
+ {
+ cur = ggml_norm(ctx0, inpL, norm_eps);
+ offload_func(cur);
+ cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
+ offload_func(cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
+ offload_func(cur);
+ ggml_format_name(cur, "input_layernorm_%d", il);
+ }
+ // self attention
+ {
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ offload_func_kq(cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+ offload_func_kq(cur);
+
+ // split qkv
+ GGML_ASSERT(n_head_kv == n_head);
+ ggml_set_name(cur, format("qkv_%d", il).c_str());
+ struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens);
+ offload_func_kq(tmpqkv);
+ struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
+ offload_func_kq(tmpqkv_perm);
+ ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il);
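+ // after the permute + cont, tmpqkv_perm holds three contiguous [n_embd_head, n_head, n_tokens] slabs: q, then k, then v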
+ struct ggml_tensor * tmpq = ggml_view_3d(
+ ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+ ggml_element_size(tmpqkv_perm) * n_embd_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+ 0
+ );
+ offload_func_kq(tmpq);
+ struct ggml_tensor * tmpk = ggml_view_3d(
+ ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+ ggml_element_size(tmpqkv_perm) * n_embd_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens
+ );
+ offload_func_kq(tmpk);
+ // Q/K Layernorm
+ tmpq = ggml_norm(ctx0, tmpq, norm_eps);
+ offload_func_kq(tmpq);
+ tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
+ offload_func_kq(tmpq);
+ tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
+ offload_func_kq(tmpq);
+
+ tmpk = ggml_norm(ctx0, tmpk, norm_eps);
+ offload_func_v(tmpk);
+ tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm);
+ offload_func_v(tmpk);
+ tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b);
+ offload_func_v(tmpk);
+
+ // RoPE the first n_rot of q/k, pass the other half, and concat.
+ struct ggml_tensor * qrot = ggml_view_3d(
+ ctx0, tmpq, n_rot, n_head, n_tokens,
+ ggml_element_size(tmpq) * n_embd_head,
+ ggml_element_size(tmpq) * n_embd_head * n_head,
+ 0
+ );
+ offload_func_kq(qrot);
+ ggml_format_name(qrot, "qrot_%d", il);
+ struct ggml_tensor * krot = ggml_view_3d(
+ ctx0, tmpk, n_rot, n_head, n_tokens,
+ ggml_element_size(tmpk) * n_embd_head,
+ ggml_element_size(tmpk) * n_embd_head * n_head,
+ 0
+ );
+ offload_func_kq(krot);
+ ggml_format_name(krot, "krot_%d", il);
+
+ // get the second (non-rotated) half of tmpq, i.e. tmpq[n_rot:, :, :]
+ struct ggml_tensor * qpass = ggml_view_3d(
+ ctx0, tmpq, n_rot, n_head, n_tokens,
+ ggml_element_size(tmpq) * n_embd_head,
+ ggml_element_size(tmpq) * n_embd_head * n_head,
+ ggml_element_size(tmpq) * n_rot
+ );
+ offload_func_kq(qpass);
+ ggml_format_name(qpass, "qpass_%d", il);
+ struct ggml_tensor * kpass = ggml_view_3d(
+ ctx0, tmpk, n_rot, n_head, n_tokens,
+ ggml_element_size(tmpk) * n_embd_head,
+ ggml_element_size(tmpk) * n_embd_head * n_head,
+ ggml_element_size(tmpk) * n_rot
+ );
+ offload_func_kq(kpass);
+ ggml_format_name(kpass, "kpass_%d", il);
+
+ struct ggml_tensor * qrotated = ggml_rope_custom(
+ ctx0, qrot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale
+ );
+ offload_func_kq(qrotated);
+ struct ggml_tensor * krotated = ggml_rope_custom(
+ ctx0, krot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale
+ );
+ offload_func_kq(krotated);
+ // ggml currently only supports concatenation on dim=2
+ // so we need to permute qrot, qpass, concat, then permute back.
+ qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
+ offload_func_kq(qrotated);
+ krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
+ offload_func_kq(krotated);
+
+ qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
+ offload_func_kq(qpass);
+ kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
+ offload_func_kq(kpass);
+
+ struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
+ offload_func_kq(Qcur);
+ struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
+ offload_func_kq(Kcur);
+
+ struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3));
+ offload_func_kq(Q);
+
+ Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
+ offload_func_kq(Kcur);
+ {
+ struct ggml_tensor * tmpv = ggml_view_3d(
+ ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens,
+ ggml_element_size(tmpqkv_perm) * n_embd_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head,
+ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2
+ );
+ offload_func_v(tmpv);
+ // store K, V in cache
+ struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
+ offload_func_v(Vcur);
+ ggml_set_name(Vcur, "Vcur");
+
+ struct ggml_tensor * k = ggml_view_1d(
+ ctx0, kv_self.k, n_tokens*n_embd_gqa,
+ (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)
+ );
+ offload_func_kq(k);
+ ggml_set_name(k, "k");
+
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
+ ( n_ctx)*ggml_element_size(kv_self.v),
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+ offload_func_v(v);
+ ggml_set_name(v, "v");
+
+ // important: storing RoPE-ed version of K in the KV cache!
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+ }
+ struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k,
+ n_embd_head, n_kv, n_head_kv,
+ ggml_element_size(kv_self.k)*n_embd_gqa,
+ ggml_element_size(kv_self.k)*n_embd_head,
+ ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+
+ offload_func_kq(K);
+ ggml_format_name(K, "K_%d", il);
+
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ offload_func_kq(KQ);
+ ggml_set_name(KQ, "KQ");
+
+ struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+ offload_func_kq(KQ_scaled);
+ ggml_set_name(KQ_scaled, "KQ_scaled");
+
+ struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
+ offload_func_kq(KQ_masked);
+ ggml_set_name(KQ_masked, "KQ_masked");
+
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+ offload_func_kq(KQ_soft_max);
+ ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+ struct ggml_tensor * V =
+ ggml_view_3d(ctx0, kv_self.v,
+ n_kv, n_embd_head, n_head_kv,
+ ggml_element_size(kv_self.v)*n_ctx,
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+ ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+ offload_func_v(V);
+ ggml_set_name(V, "V");
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ offload_func_v(KQV);
+ ggml_set_name(KQV, "KQV");
+
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ offload_func_v(KQV_merged);
+ ggml_set_name(KQV_merged, "KQV_merged");
+
+ cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+ offload_func_v(cur);
+ ggml_set_name(cur, "KQV_merged_contiguous");
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+ offload_func(cur);
+ cur = ggml_add(ctx0, cur, model.layers[il].bo);
+ offload_func(cur);
+ ggml_set_name(cur, "result_wo");
+ }
+
+ struct ggml_tensor * inpFF = ggml_add(ctx0, residual, cur);
+ offload_func(inpFF);
+ ggml_set_name(inpFF, "inpFF");
+ {
+ // MLP
+ {
+ // Norm
+ cur = ggml_norm(ctx0, inpFF, norm_eps);
+ offload_func(cur);
+ cur = ggml_add(ctx0,
+ ggml_mul(ctx0, cur, model.layers[il].ffn_norm),
+ model.layers[il].ffn_norm_b
+ );
+ ggml_set_name(cur, "ffn_norm");
+ offload_func(cur);
+ }
+ cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
+ offload_func(cur);
+
+ cur = ggml_add(ctx0, cur, model.layers[il].b3);
+ offload_func(cur);
+ ggml_set_name(cur, "result_ffn_up");
+
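+ // persimmon's MLP activation is squared ReLU: relu(x)^2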
+ cur = ggml_sqr(ctx0, ggml_relu(ctx0, cur));
+ ggml_set_name(cur, "result_ffn_act");
+ offload_func(cur);
+ offload_func(cur->src[0]);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
+ offload_func(cur);
+ cur = ggml_add(ctx0,
+ cur,
+ model.layers[il].b2);
+ offload_func(cur);
+ ggml_set_name(cur, "outFF");
+ }
+ cur = ggml_add(ctx0, cur, inpFF);
+ offload_func(cur);
+ ggml_set_name(cur, "inpFF_+_outFF");
+ inpL = cur;
+ }
+ cur = inpL;
+ {
+ cur = ggml_norm(ctx0, cur, norm_eps);
+ offload_func_nr(cur);
+ cur = ggml_mul(ctx0, cur, model.output_norm);
+ offload_func_nr(cur);
+
+ cur = ggml_add(ctx0, cur, model.output_norm_b);
+ // offload_func_nr(cur);
+
+ ggml_set_name(cur, "result_norm");
+ }
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ ggml_set_name(cur, "result_output");
+ ggml_build_forward_expand(gf, cur);
+ ggml_free(ctx0);
+ return gf;
+}
+
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_batch & batch) {
{
result = llm_build_starcoder(lctx, batch);
} break;
+ case LLM_ARCH_PERSIMMON:
+ {
+ result = llm_build_persimmon(lctx, batch);
+ } break;
case LLM_ARCH_REFACT:
{
result = llm_build_refact(lctx, batch);