]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llm : add MPT support (#3417)
authorJan Ploski <redacted>
Tue, 10 Oct 2023 07:50:23 +0000 (09:50 +0200)
committerGitHub <redacted>
Tue, 10 Oct 2023 07:50:23 +0000 (10:50 +0300)
* CUDA: added support for ggml_clamp (see also: https://github.com/ggerganov/ggml/issues/545)

* mpt : added an implementation based (mostly) on falcon integration, modified with deltas from ggml/examples/mpt

* mpt : protect against "clip_qkv": null in mpt-7b

* mpt : quick fix to avoid "Strange model" warning when quantizing MPT models

* mpt : addendum to changeset:84e30e8 - leave parameter clamp_kqv out from metadata rather than use 0.0 to indicate "no clamping" (more compliant with the current GGUF spec?)

* mpt : standardized all tensor names to follow GGUF spec

* mpt : addendum to changeset:1be89c40 - use "req" parameter of GGUF_GET_KEY macro instead of duplicate code

* mpt : fixed comment s/gptneox/mpt/

* mpt : remove tabs, trailing whitespace

* mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) and rope_shift from build_mpt

* mpt : updated convert-mpt-hf-to-gguf.py to reflect changes made to convert-gptneox-hf-to-gguf.py in pr:3252

* comment out n_past instead of marking it unused

* mpt : removed hardcoded +178 from convert script in favor of utilizing hparams["vocab_size"]

* mpt : remove unused tokenizer_json in convert script

* ggml : remove obsolete n_past assert in ggml_alibi

* llama : print clam_kqv and max_alibi_bias hparams

---------

Co-authored-by: Cebtenzzre <redacted>
Co-authored-by: Georgi Gerganov <redacted>
convert-mpt-hf-to-gguf.py [new file with mode: 0755]
ggml-cuda.cu
ggml-metal.m
ggml.c
llama.cpp

diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py
new file mode 100755 (executable)
index 0000000..73a4932
--- /dev/null
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+# HF mpt--> gguf conversion
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import struct
+import sys
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer  # type: ignore[import]
+
+if 'NO_LOCAL_GGUF' not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+import gguf
+
+
+def count_model_parts(dir_model: Path) -> int:
+    num_parts = 0
+    for filename in os.listdir(dir_model):
+        if filename.startswith("pytorch_model-"):
+            num_parts += 1
+
+    if num_parts > 0:
+        print("gguf: found " + str(num_parts) + " model parts")
+    return num_parts
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file")
+    parser.add_argument(
+        "--vocab-only", action="store_true",
+        help="extract only the vocab",
+    )
+    parser.add_argument(
+        "--outfile", type=Path,
+        help="path to write to; default: based on input",
+    )
+    parser.add_argument(
+        "model", type=Path,
+        help="directory containing model file, or model file itself (*.bin)",
+    )
+    parser.add_argument(
+        "ftype", type=int, choices=[0, 1], default=1, nargs='?',
+        help="output format - use 0 for float32, 1 for float16",
+    )
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
+    sys.exit(1)
+
+# possible tensor data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
+
+print("gguf: loading model "+dir_model.name)
+
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+    hparams = json.load(f)
+
+if hparams["architectures"][0] != "MPTForCausalLM":
+    print("Model architecture not supported: " + hparams["architectures"][0])
+
+    sys.exit()
+
+# get number of model parts
+num_parts = count_model_parts(dir_model)
+
+ARCH=gguf.MODEL_ARCH.MPT
+gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
+
+print("gguf: get model metadata")
+
+block_count = hparams["n_layers"]
+
+gguf_writer.add_name(dir_model.name)
+gguf_writer.add_context_length(hparams["max_seq_len"])
+gguf_writer.add_embedding_length(hparams["d_model"])
+gguf_writer.add_block_count(block_count)
+gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
+gguf_writer.add_head_count(hparams["n_heads"])
+gguf_writer.add_layer_norm_eps(1e-05)
+if hparams["attn_config"]["clip_qkv"] is not None:
+    gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
+gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])
+
+# TOKENIZATION
+
+print("gguf: get tokenizer metadata")
+
+tokens: list[bytearray] = []
+scores: list[float] = []
+toktypes: list[int] = []
+
+# gpt2 tokenizer
+gguf_writer.add_tokenizer_model("gpt2")
+
+print("gguf: get gpt2 tokenizer vocab")
+
+# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but
+# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to
+# accomodate some "reserved" tokens; this is causing problems down the line in
+# llama.cpp, so we pad the vocab with dummy tokens:
+
+vocab_size = hparams["vocab_size"]
+
+# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
+
+reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+
+for i in range(vocab_size):
+    tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
+    scores.append(0.0) # dummy
+    toktypes.append(gguf.TokenType.NORMAL)
+
+gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
+
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab.add_to_gguf(gguf_writer)
+
+# TENSORS
+
+tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
+
+# tensor info
+print("gguf: get tensor metadata")
+
+if num_parts == 0:
+    part_names = iter(("pytorch_model.bin",))
+else:
+    part_names = (
+        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
+    )
+
+for part_name in part_names:
+    if args.vocab_only:
+        break
+    print("gguf: loading model part '" + part_name + "'")
+    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+
+    for name in model_part.keys():
+        data = model_part[name]
+
+        old_dtype = data.dtype
+
+        # convert any unsupported data types to float32
+        if data.dtype != torch.float16 and data.dtype != torch.float32:
+            data = data.to(torch.float32)
+
+        data = data.squeeze().numpy()
+
+        # map tensor names
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
+            print("Cannot map tensor '" + name + "'")
+            continue # for the sake of compatibility with some old published models, don't quit
+            sys.exit()
+
+        n_dims = len(data.shape)
+        data_dtype = data.dtype
+
+        # if f32 desired, convert any float16 to float32
+        if ftype == 0 and data_dtype == np.float16:
+            data = data.astype(np.float32)
+
+        # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+            data = data.astype(np.float32)
+
+        # if f16 desired, convert any float32 2-dim weight tensors to float16
+        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+            data = data.astype(np.float16)
+
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+
+        gguf_writer.add_tensor(new_name, data)
+
+        # note: MPT output is tied to (same as) wte in original model;
+        # for easier implementation in llama.cpp it's duplicated in GGUF, though :/
+        if new_name == "token_embd.weight":
+            gguf_writer.add_tensor("output.weight", data)
+
+print("gguf: write header")
+gguf_writer.write_header_to_file()
+print("gguf: write metadata")
+gguf_writer.write_kv_data_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
+
+gguf_writer.close()
+
+print(f"gguf: model successfully exported to '{fname_out}'")
+print("")
index 7e92c519741b97fd85951aa0c839ed593bbe6ba8..654d3632fc179a3608b47040e78b9886f403172b 100644 (file)
@@ -415,6 +415,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
+#define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
@@ -4585,6 +4586,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
+static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+}
 
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
@@ -5475,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
+static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
+    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+}
+
 template<typename T>
 static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
@@ -6419,12 +6434,12 @@ inline void ggml_cuda_op_alibi(
     const int64_t ne02 = src0->ne[2];
     const int64_t nrows = ggml_nrows(src0);
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    GGML_ASSERT(ne01 + n_past == ne00);
+    //GGML_ASSERT(ne01 + n_past == ne00);
     GGML_ASSERT(n_head == ne02);
 
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -6500,6 +6515,24 @@ inline void ggml_cuda_op_scale(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_clamp(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const float min = ((float *) dst->op_params)[0];
+    const float max = ((float *) dst->op_params)[1];
+
+    clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
     const int64_t nrows0 = ggml_nrows(src0);
 
@@ -7061,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
 
+static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
+}
+
 static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -7470,6 +7507,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
+        case GGML_OP_CLAMP:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_clamp;
+            break;
         case GGML_OP_CPY:
             func = ggml_cuda_cpy;
             break;
index 5a23144d0c89133d70ed01ae8d44287ca4350c9f..87fa172161405a769893a479f53b58233501cb20 100644 (file)
@@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute(
 
                             const int nth = MIN(1024, ne00);
 
-                            const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                            //const int n_past = ((int32_t *) dst->op_params)[0];
                             const int n_head = ((int32_t *) dst->op_params)[1];
                             float max_bias;
                             memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
diff --git a/ggml.c b/ggml.c
index 5bb1da31ba624d4e54ce083925ab8456d2c8e40c..1f5598fa6af8f93316169ecabb17237fb312368e 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -13059,13 +13059,11 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    assert(n_past >= 0);
-
     const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int64_t ne1 = src0->ne[1]; // seq_len_without_past
     const int64_t ne2 = src0->ne[2]; // n_head -> this is k
index 24f07daca6181a28c12ed528d3cf33735c55e221..3b63b64010b0f190941fe77a7677e590fb9c6467 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -424,6 +424,14 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
         LLM_ARCH_MPT,
         {
             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
     {
@@ -1011,6 +1019,9 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
 
+    float f_clamp_kqv;
+    float f_max_alibi_bias;
+
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
         if (this->n_vocab != other.n_vocab) return true;
@@ -2060,6 +2071,20 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MPT:
+            {
+                hparams.f_clamp_kqv = 0.0f;
+
+                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+                GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV));
+                GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS));
+
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_7B; break;
+                    case 48: model.type = e_model::MODEL_30B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -2204,6 +2229,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
     LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
     LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+    LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
+    LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
@@ -2649,6 +2676,73 @@ static void llm_load_tensors(
                         layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),   {64}, backend);
                     }
                 } break;
+            case LLM_ARCH_MPT:
+                {
+                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend_type backend_norm;
+                        ggml_backend_type backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+                            // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+#else
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
+                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                        }
+                        if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                            vram_weights += ggml_nbytes(model.output);
+                        }
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
+
+                    model.layers.resize(n_layer);
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
+                        layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},     backend_split);
+
+                        layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
+
+                        layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
+                        layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) +
+                                ggml_nbytes(layer.wqkv)      +
+                                ggml_nbytes(layer.wo)        +
+                                ggml_nbytes(layer.ffn_norm)  +
+                                ggml_nbytes(layer.w2)        +
+                                ggml_nbytes(layer.w3);
+                        }
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -4505,7 +4599,6 @@ static struct ggml_cgraph * llm_build_starcoder(
     return gf;
 }
 
-
 static struct ggml_cgraph * llm_build_persimmon(
          llama_context & lctx,
      const llama_batch & batch) {
@@ -4903,6 +4996,326 @@ static struct ggml_cgraph * llm_build_persimmon(
     return gf;
 }
 
+static struct ggml_cgraph * llm_build_mpt(
+         llama_context & lctx,
+     const llama_batch & batch) {
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    const auto & kv_self = lctx.kv_self;
+
+    GGML_ASSERT(!!kv_self.ctx);
+
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_ctx       = cparams.n_ctx;
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_head_kv   = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
+    const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+
+    const float norm_eps       = hparams.f_norm_eps;
+    const float clamp_kqv      = hparams.f_clamp_kqv;
+    const float max_alibi_bias = hparams.f_max_alibi_bias;
+
+    const int n_gpu_layers = model.n_gpu_layers;
+
+    const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
+    //        kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
+
+    auto & buf_compute = lctx.buf_compute;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute.size,
+        /*.mem_buffer =*/ buf_compute.data,
+        /*.no_alloc   =*/ false,
+    };
+
+    params.no_alloc = true;
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * cur;
+    struct ggml_tensor * inpL;
+
+    //int warmup = 0;
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+
+        ggml_allocr_alloc(lctx.alloc, inp_tokens);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
+            //warmup = ((uint32_t*) inp_tokens->data)[0] == 0;
+        }
+
+        ggml_set_name(inp_tokens, "inp_tokens");
+
+        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+    } else {
+#ifdef GGML_USE_MPI
+        GGML_ASSERT(false && "not implemented");
+#endif
+
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+
+        ggml_allocr_alloc(lctx.alloc, inpL);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
+        }
+    }
+
+    const int i_gpu_start = n_layer - n_gpu_layers;
+    (void) i_gpu_start;
+
+    // offload functions set the tensor output backend to GPU
+    // tensors are GPU-accelerated if any input or the output has been offloaded
+    offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+    offload_func_t offload_func_kq = llama_nop;
+    offload_func_t offload_func_v  = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+    if (n_gpu_layers > n_layer) {
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 1) {
+        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 2) {
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+    }
+#endif // GGML_USE_CUBLAS
+
+    // KQ_scale
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+    ggml_allocr_alloc(lctx.alloc, KQ_scale);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    offload_func_kq(KQ_mask);
+    ggml_set_name(KQ_mask, "KQ_mask");
+    ggml_allocr_alloc(lctx.alloc, KQ_mask);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        float * data = (float *) KQ_mask->data;
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+    }
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * attn_norm;
+
+        offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+        if (il >= i_gpu_start) {
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
+        }
+#endif // GGML_USE_CUBLAS
+
+        // self-attention
+        // TODO: refactor into common function (shared with LLaMA)
+        {
+            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
+            offload_func(attn_norm);
+
+            attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm);
+            offload_func(attn_norm);
+
+            if (1) {
+                cur = attn_norm;
+            }
+
+            // compute QKV
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+            offload_func_kq(cur);
+
+            if (clamp_kqv > 0.0f) {
+                cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv);
+                offload_func_kq(cur);
+            }
+
+            const size_t wsize = ggml_type_size(cur->type);
+
+            struct ggml_tensor * Qcur = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head, n_tokens,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                0);
+            offload_func_kq(Qcur);
+
+            struct ggml_tensor * Kcur = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head_kv, n_tokens,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                wsize * n_embd_head *  n_head);
+            offload_func_kq(Kcur);
+
+            struct ggml_tensor * tmpv = ggml_view_3d(
+                ctx0, cur, n_embd_head, n_head_kv, n_tokens,
+                wsize * n_embd_head,
+                wsize * n_embd_head * (n_head + 2 * n_head_kv),
+                wsize * n_embd_head * (n_head +     n_head_kv));
+            offload_func_kq(Kcur);
+
+            ggml_set_name(Qcur, "Qcur");
+            ggml_set_name(Kcur, "Kcur");
+
+            {
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
+                offload_func_v(Vcur);
+                offload_func_v(Vcur->src[0]->src[0]);
+                ggml_set_name(Vcur, "Vcur");
+
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+                offload_func_kq(k);
+                ggml_set_name(k, "k");
+
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+                offload_func_v(v);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            offload_func_kq(Q);
+            ggml_set_name(Q, "Q");
+
+            struct ggml_tensor * K =
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_kv, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            offload_func_kq(K);
+            ggml_set_name(K, "K");
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            offload_func_kq(KQ);
+            ggml_set_name(KQ, "KQ");
+
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
+            offload_func_kq(KQ_scaled);
+            ggml_set_name(KQ_scaled, "KQ_scaled");
+
+            // TODO: replace with ggml_add()
+            struct ggml_tensor * KQ_scaled_alibi =
+                ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias);
+            offload_func_kq(KQ_scaled_alibi);
+            ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
+
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
+            offload_func_kq(KQ_masked);
+            ggml_set_name(KQ_masked, "KQ_masked");
+
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            offload_func_v(KQ_soft_max);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_kv, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+            offload_func_v(V);
+            ggml_set_name(V, "V");
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            offload_func_v(KQV);
+            ggml_set_name(KQV, "KQV");
+
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            offload_func_v(KQV_merged);
+            ggml_set_name(KQV_merged, "KQV_merged");
+
+            cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+            offload_func_v(cur);
+            ggml_set_name(cur, "KQV_merged_contiguous");
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            offload_func(cur);
+            ggml_set_name(cur, "result_wo");
+        }
+
+        // Add the input
+        cur = ggml_add(ctx0, cur, inpL);
+        offload_func(cur);
+
+        struct ggml_tensor * attn_out = cur;
+
+        // feed forward
+        {
+            // Norm
+            {
+                cur = ggml_norm(ctx0, attn_out, norm_eps);
+                offload_func(cur);
+
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+                offload_func(cur);
+            }
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
+            offload_func(cur);
+
+            cur = ggml_gelu(ctx0, cur);
+            offload_func(cur);
+            cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
+            offload_func(cur);
+        }
+
+        cur = ggml_add(ctx0, cur, attn_out);
+        offload_func(cur);
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    // norm
+    {
+        cur = ggml_norm(ctx0, cur, norm_eps);
+        offload_func_nr(cur);
+
+        cur = ggml_mul(ctx0, cur, model.output_norm);
+        ggml_set_name(cur, "result_norm");
+    }
+
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    ggml_set_name(cur, "result_output");
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch) {
@@ -4935,6 +5348,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm_build_refact(lctx, batch);
             } break;
+        case LLM_ARCH_MPT:
+            {
+                result = llm_build_mpt(lctx, batch);
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -5065,7 +5482,8 @@ static int llama_decode_internal(
     const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA ||
         model.arch == LLM_ARCH_BAICHUAN ||
         model.arch == LLM_ARCH_FALCON ||
-        model.arch == LLM_ARCH_REFACT;
+        model.arch == LLM_ARCH_REFACT ||
+        model.arch == LLM_ARCH_MPT;
     const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
     if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) {
         n_threads = 1;
@@ -7161,7 +7579,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         const std::string name = ggml_get_name(meta);
 
         // TODO: avoid hardcoded tensor names - use the TN_* constants
-        if (name.find("attn_v.weight") != std::string::npos) {
+        if (name.find("attn_v.weight") != std::string::npos ||
+            name.find("attn_qkv.weight") != std::string::npos) {
             ++n_attention_wv;
         }
         else if (name.find("ffn_down.weight") != std::string::npos) {