import gguf
+# check for any of the given keys in the dictionary and return the value of the first key found
+def get_key_opts(d, keys):
+    for k in keys:
+        if k in d:
+            return d[k]
+    print(f"Could not find any of {keys}")
+    sys.exit()
+
+
###### MODEL DEFINITIONS ######
class SentencePieceTokenTypes(IntEnum):
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
-                if tokenizer.added_tokens_decoder[i].special:
-                    toktypes.append(gguf.TokenType.CONTROL)
-                else:
-                    toktypes.append(gguf.TokenType.USER_DEFINED)
+                if hasattr(tokenizer, "added_tokens_decoder"):
+                    if tokenizer.added_tokens_decoder[i].special:
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                else:
+                    # tokenizer exposes no added_tokens_decoder: keep toktypes aligned with tokens
+                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)
class Phi2Model(Model):
    def set_gguf_parameters(self):
-        block_count = self.hparams["n_layer"]
+        block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])
+
+        rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
+        n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
+        n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])
        self.gguf_writer.add_name("Phi2")
-        self.gguf_writer.add_context_length(self.hparams["n_positions"])
-        self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
-        self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
+        self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))
+
+        self.gguf_writer.add_embedding_length(n_embd)
+        self.gguf_writer.add_feed_forward_length(4 * n_embd)
        self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(self.hparams["n_head"])
-        self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
-        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
-        self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
+        self.gguf_writer.add_head_count(n_head)
+        self.gguf_writer.add_head_count_kv(n_head)
+        self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
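+        # RoPE is applied to only a fraction of each head's dimensions, given by partial_rotary_factor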
+        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_add_bos_token(False)
            { LLM_TENSOR_OUTPUT,      "output" },
            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_QKV,    "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
                    layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
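+                    // the fused QKV tensor is loaded as optional: some Phi-2 conversions store separate Q/K/V tensors instead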
-                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, false);
+                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa},         false);
+
+                    if (layer.wqkv == nullptr) {
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
+
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa});
+
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa});
+                    }
                    layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
                    layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
            // self-attention
            {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
-                cb(cur, "wqkv", il);
+                struct ggml_tensor * Qcur = nullptr;
+                struct ggml_tensor * Kcur = nullptr;
+                struct ggml_tensor * Vcur = nullptr;
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
+                if (model.layers[il].wqkv) {
+                    cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cb(cur, "wqkv", il);
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+
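+                    // split the fused QKV result into Q, K and V views: Q at offset 0, K at n_embd, V at n_embd + n_embd_gqa (byte offsets, hence the sizeof(float) factor)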
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                } else {
+                    Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                }
                cb(Qcur, "Qcur", il);
                cb(Kcur, "Kcur", il);