MPT : support GQA for replit-code-v1.5 (#3627)
author     cebtenzzre <redacted>
           Sun, 15 Oct 2023 06:32:06 +0000 (02:32 -0400)
committer  GitHub <redacted>
           Sun, 15 Oct 2023 06:32:06 +0000 (09:32 +0300)
convert-mpt-hf-to-gguf.py
llama.cpp

index 73a4932f7c831b8b2d353572108e221ba5e2858b..19a66820dceab53a5f7aef3733a97a803f49dab3 100755 (executable)
--- a/convert-mpt-hf-to-gguf.py
+++ b/convert-mpt-hf-to-gguf.py
@@ -98,6 +98,8 @@ gguf_writer.add_embedding_length(hparams["d_model"])
 gguf_writer.add_block_count(block_count)
 gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
 gguf_writer.add_head_count(hparams["n_heads"])
+if kv_n_heads := hparams["attn_config"].get("kv_n_heads"):
+    gguf_writer.add_head_count_kv(kv_n_heads)
 gguf_writer.add_layer_norm_eps(1e-05)
 if hparams["attn_config"]["clip_qkv"] is not None:
     gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
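
For context, a minimal sketch (not part of this commit) of how the KV head count is resolved from an MPT-style config: attn_config.kv_n_heads is only present for GQA checkpoints such as replit-code-v1.5, and when it is absent the KV head count falls back to n_heads, i.e. plain multi-head attention. The config values below are illustrative, not taken from the commit.

    # Minimal sketch: resolve the KV head count from an MPT-style config dict.
    # The values are illustrative only.
    hparams = {
        "n_heads": 24,
        "attn_config": {"kv_n_heads": 8},  # key is absent in pre-GQA MPT configs
    }
    n_head = hparams["n_heads"]
    # Fall back to n_heads when kv_n_heads is missing (no GQA/MQA).
    n_head_kv = hparams["attn_config"].get("kv_n_heads") or n_head
    print(n_head, n_head_kv)  # 24 8 for this illustrative config
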
index 2cd2dad7f3bb0f1a883af6d7af18a86814749255..5329bd828a12575fed2c990ca31327af5beeb711 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2839,8 +2839,8 @@ static void llm_load_tensors(
                         auto & layer = model.layers[i];
 
                         layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
-                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
-                        layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},     backend_split);
+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+                        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},                backend_split);
 
                         layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
 
@@ -5368,7 +5368,7 @@ static struct ggml_cgraph * llm_build_mpt(
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
+    const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
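
To see why the fused QKV weight becomes {n_embd, n_embd + 2*n_embd_gqa}: under GQA the Q projection keeps the full n_embd width, while K and V each shrink to n_embd_gqa = n_embd_head * n_head_kv. A rough sketch of the arithmetic, with illustrative values (not taken from this commit):

    # Rough sketch of the fused QKV output width under GQA; values are illustrative.
    n_embd    = 3072   # hparams.n_embd (d_model)
    n_head    = 24     # hparams.n_head
    n_head_kv = 8      # hparams.n_head_kv (== n_head when there is no GQA)

    n_embd_head = n_embd // n_head            # per-head dimension
    n_embd_gqa  = n_embd_head * n_head_kv     # width of each of K and V

    # Q keeps the full width; K and V contribute n_embd_gqa each,
    # so the fused wqkv weight is {n_embd, n_embd + 2*n_embd_gqa}.
    qkv_out = n_embd + 2 * n_embd_gqa
    print(n_embd_head, n_embd_gqa, qkv_out)   # 128 1024 5120

When n_head_kv == n_head this reduces to the old {n_embd, 3*n_embd} shape, so non-GQA MPT models load exactly as before.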