]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llm : add bloom models (#3553)
authorXingchen Song(宋星辰) <redacted>
Tue, 10 Oct 2023 14:48:21 +0000 (22:48 +0800)
committerGitHub <redacted>
Tue, 10 Oct 2023 14:48:21 +0000 (17:48 +0300)
* feat: Support bloom models

* fix(bloom): fix model size

---------

Co-authored-by: Georgi Gerganov <redacted>
convert-bloom-hf-to-gguf.py [new file with mode: 0755]
gguf-py/gguf/gguf.py
llama.cpp

diff --git a/convert-bloom-hf-to-gguf.py b/convert-bloom-hf-to-gguf.py
new file mode 100755 (executable)
index 0000000..7bfc95e
--- /dev/null
@@ -0,0 +1,238 @@
+#!/usr/bin/env python3
+# HF bloom --> gguf conversion
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import re
+import struct
+import sys
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer  # type: ignore[import]
+
+if 'NO_LOCAL_GGUF' not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
+import gguf
+
+
+def count_model_parts(dir_model: Path) -> int:
+    num_parts = 0
+    for filename in os.listdir(dir_model):
+        if filename.startswith("pytorch_model-"):
+            num_parts += 1
+
+    if num_parts > 0:
+        print("gguf: found " + str(num_parts) + " model parts")
+    return num_parts
+
+
+# Supported Models:
+#   https://huggingface.co/bigscience/bloom-1b7
+#   https://huggingface.co/bigscience/bloom-3b
+#   https://huggingface.co/bigscience/bloom-7b1
+#   https://huggingface.co/Langboat/bloom-1b4-zh
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a Bloom model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile",    type=Path,           help="path to write to; default: based on input")
+    parser.add_argument("model",        type=Path,           help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype",        type=int,            help="output format - use 0 for float32, 1 for float16", choices=[0, 1], default = 1)
+    return parser.parse_args()
+
+args = parse_args()
+
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
+    sys.exit(1)
+
+# possible tensor data types
+#   ftype == 0 -> float32
+#   ftype == 1 -> float16
+
+# map from ftype to string
+ftype_str = ["f32", "f16"]
+
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
+
+print("gguf: loading model "+dir_model.name)
+
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+    hparams = json.load(f)
+
+if hparams["architectures"][0] != "BloomForCausalLM":
+    print("Model architecture not supported: " + hparams["architectures"][0])
+    sys.exit(1)
+
+# get number of model parts
+num_parts = count_model_parts(dir_model)
+
+ARCH=gguf.MODEL_ARCH.BLOOM
+gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
+
+print("gguf: get model metadata")
+
+block_count = hparams["n_layer"]
+
+gguf_writer.add_name("Bloom")
+n_embed = hparams.get("hidden_size", hparams.get("n_embed"))
+n_head = hparams.get("n_head", hparams.get("num_attention_heads"))
+gguf_writer.add_context_length(hparams.get("seq_length", n_embed))
+gguf_writer.add_embedding_length(n_embed)
+gguf_writer.add_feed_forward_length(4 * n_embed)
+gguf_writer.add_block_count(block_count)
+gguf_writer.add_head_count(n_head)
+gguf_writer.add_head_count_kv(n_head)
+gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
+gguf_writer.add_file_type(ftype)
+
+# TOKENIZATION
+
+print("gguf: get tokenizer metadata")
+
+tokens: list[bytearray] = []
+scores: list[float] = []
+toktypes: list[int] = []
+
+# gpt2 tokenizer
+gguf_writer.add_tokenizer_model("gpt2")
+
+print("gguf: get gpt2 tokenizer vocab")
+
+# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
+
+# The number of tokens in tokenizer.json can differ from the expected vocab size.
+# This causes downstream issues with mismatched tensor sizes when running the inference
+vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
+assert max(tokenizer.vocab.values()) < vocab_size
+
+reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+
+for i in range(vocab_size):
+    tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
+    scores.append(0.0)  # dummy
+    toktypes.append(gguf.TokenType.NORMAL)
+
+gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
+
+special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
+special_vocab.add_to_gguf(gguf_writer)
+
+# TENSORS
+
+tensor_map = gguf.get_tensor_name_map(ARCH, block_count)
+
+# params for qkv transform
+n_head_kv = hparams.get("n_head_kv", n_head)
+head_dim = n_embed // n_head
+
+# tensor info
+print("gguf: get tensor metadata")
+
+if num_parts == 0:
+    part_names = iter(("pytorch_model.bin",))
+else:
+    part_names = (
+        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
+    )
+
+for part_name in part_names:
+    if args.vocab_only:
+        break
+    print("gguf: loading model part '" + part_name + "'")
+    model_part = torch.load(dir_model / part_name, map_location="cpu")
+
+    has_lm_head = True
+    if "lm_head.weight" not in model_part.keys() and "output.weight" not in model_part.keys():
+        has_lm_head = False
+
+    for original_name in model_part.keys():
+        data = model_part[original_name]
+        name = re.sub(r'transformer\.', '', original_name)
+
+        old_dtype = data.dtype
+
+        # convert any unsupported data types to float32
+        if data.dtype != torch.float16 and data.dtype != torch.float32:
+            data = data.to(torch.float32)
+
+        data = data.squeeze().numpy()
+
+        if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
+            # Map bloom-style qkv_linear to gpt-style qkv_linear
+            # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
+            # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
+            qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
+            data = np.concatenate(
+                (qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
+                 qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
+                 qkv_weights[:, 2, :, :].reshape((-1, n_embed))),
+                axis=0
+            )
+            print("re-format attention.linear_qkv.weight")
+        elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
+            qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
+            data = np.concatenate(
+                (qkv_bias[:, 0, :].reshape((n_embed,)),
+                 qkv_bias[:, 1, :].reshape((n_embed,)),
+                 qkv_bias[:, 2, :].reshape((n_embed,))),
+                axis=0
+            )
+            print("re-format attention.linear_qkv.bias")
+
+        # map tensor names
+        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+        if new_name is None:
+            print("Can not map tensor '" + name + "'")
+            sys.exit()
+
+        n_dims = len(data.shape)
+        data_dtype = data.dtype
+
+        # if f32 desired, convert any float16 to float32
+        if ftype == 0 and data_dtype == np.float16:
+            data = data.astype(np.float32)
+
+        # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+            data = data.astype(np.float32)
+
+        # if f16 desired, convert any float32 2-dim weight tensors to float16
+        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+            data = data.astype(np.float16)
+
+        print(name, "=>", new_name + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+
+        gguf_writer.add_tensor(new_name, data)
+
+        if not has_lm_head and name == "word_embeddings.weight":
+            gguf_writer.add_tensor("output.weight", data)
+            print(name, "=>", "output.weight" + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype))  # noqa
+
+
+print("gguf: write header")
+gguf_writer.write_header_to_file()
+print("gguf: write metadata")
+gguf_writer.write_kv_data_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
+
+gguf_writer.close()
+
+print(f"gguf: model successfully exported to '{fname_out}'")
+print("")
index fb677a6ed728393ac0b8508128b3647d9062b4cd..557ce7ac0173c036f7b6e83ef0e4ae63f8379b4f 100644 (file)
@@ -88,29 +88,31 @@ class MODEL_ARCH(IntEnum):
     PERSIMMON     : int = auto()
     REFACT        : int = auto()
     BERT          : int = auto()
+    BLOOM         : int = auto()
 
 
 class MODEL_TENSOR(IntEnum):
-    TOKEN_EMBD   : int = auto()
-    TOKEN_TYPES  : int = auto()
-    POS_EMBD     : int = auto()
-    OUTPUT       : int = auto()
-    OUTPUT_NORM  : int = auto()
-    ROPE_FREQS   : int = auto()
-    ATTN_Q       : int = auto()
-    ATTN_K       : int = auto()
-    ATTN_V       : int = auto()
-    ATTN_QKV     : int = auto()
-    ATTN_OUT     : int = auto()
-    ATTN_NORM    : int = auto()
-    ATTN_NORM_2  : int = auto()
-    ATTN_ROT_EMBD: int = auto()
-    FFN_GATE     : int = auto()
-    FFN_DOWN     : int = auto()
-    FFN_UP       : int = auto()
-    FFN_NORM     : int = auto()
-    ATTN_Q_NORM  : int = auto()
-    ATTN_K_NORM  : int = auto()
+    TOKEN_EMBD      : int = auto()
+    TOKEN_EMBD_NORM : int = auto()
+    TOKEN_TYPES     : int = auto()
+    POS_EMBD        : int = auto()
+    OUTPUT          : int = auto()
+    OUTPUT_NORM     : int = auto()
+    ROPE_FREQS      : int = auto()
+    ATTN_Q          : int = auto()
+    ATTN_K          : int = auto()
+    ATTN_V          : int = auto()
+    ATTN_QKV        : int = auto()
+    ATTN_OUT        : int = auto()
+    ATTN_NORM       : int = auto()
+    ATTN_NORM_2     : int = auto()
+    ATTN_ROT_EMBD   : int = auto()
+    FFN_GATE        : int = auto()
+    FFN_DOWN        : int = auto()
+    FFN_UP          : int = auto()
+    FFN_NORM        : int = auto()
+    ATTN_Q_NORM     : int = auto()
+    ATTN_K_NORM     : int = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -125,29 +127,31 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.PERSIMMON:      "persimmon",
     MODEL_ARCH.REFACT:         "refact",
     MODEL_ARCH.BERT:           "bert",
+    MODEL_ARCH.BLOOM:          "bloom",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
-    MODEL_TENSOR.TOKEN_EMBD:    "token_embd",
-    MODEL_TENSOR.TOKEN_TYPES:   "token_types",
-    MODEL_TENSOR.POS_EMBD:      "position_embd",
-    MODEL_TENSOR.OUTPUT_NORM:   "output_norm",
-    MODEL_TENSOR.OUTPUT:        "output",
-    MODEL_TENSOR.ROPE_FREQS:    "rope_freqs",
-    MODEL_TENSOR.ATTN_NORM:     "blk.{bid}.attn_norm",
-    MODEL_TENSOR.ATTN_NORM_2:   "blk.{bid}.attn_norm_2",
-    MODEL_TENSOR.ATTN_QKV:      "blk.{bid}.attn_qkv",
-    MODEL_TENSOR.ATTN_Q:        "blk.{bid}.attn_q",
-    MODEL_TENSOR.ATTN_K:        "blk.{bid}.attn_k",
-    MODEL_TENSOR.ATTN_V:        "blk.{bid}.attn_v",
-    MODEL_TENSOR.ATTN_OUT:      "blk.{bid}.attn_output",
-    MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
-    MODEL_TENSOR.ATTN_Q_NORM:   "blk.{bid}.attn_q_norm",
-    MODEL_TENSOR.ATTN_K_NORM:   "blk.{bid}.attn_k_norm",
-    MODEL_TENSOR.FFN_NORM:      "blk.{bid}.ffn_norm",
-    MODEL_TENSOR.FFN_GATE:      "blk.{bid}.ffn_gate",
-    MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
-    MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
+    MODEL_TENSOR.TOKEN_EMBD:      "token_embd",
+    MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
+    MODEL_TENSOR.TOKEN_TYPES:     "token_types",
+    MODEL_TENSOR.POS_EMBD:        "position_embd",
+    MODEL_TENSOR.OUTPUT_NORM:     "output_norm",
+    MODEL_TENSOR.OUTPUT:          "output",
+    MODEL_TENSOR.ROPE_FREQS:      "rope_freqs",
+    MODEL_TENSOR.ATTN_NORM:       "blk.{bid}.attn_norm",
+    MODEL_TENSOR.ATTN_NORM_2:     "blk.{bid}.attn_norm_2",
+    MODEL_TENSOR.ATTN_QKV:        "blk.{bid}.attn_qkv",
+    MODEL_TENSOR.ATTN_Q:          "blk.{bid}.attn_q",
+    MODEL_TENSOR.ATTN_K:          "blk.{bid}.attn_k",
+    MODEL_TENSOR.ATTN_V:          "blk.{bid}.attn_v",
+    MODEL_TENSOR.ATTN_OUT:        "blk.{bid}.attn_output",
+    MODEL_TENSOR.ATTN_ROT_EMBD:   "blk.{bid}.attn_rot_embd",
+    MODEL_TENSOR.ATTN_Q_NORM:     "blk.{bid}.attn_q_norm",
+    MODEL_TENSOR.ATTN_K_NORM:     "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.FFN_NORM:        "blk.{bid}.ffn_norm",
+    MODEL_TENSOR.FFN_GATE:        "blk.{bid}.ffn_gate",
+    MODEL_TENSOR.FFN_DOWN:        "blk.{bid}.ffn_down",
+    MODEL_TENSOR.FFN_UP:          "blk.{bid}.ffn_up",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -282,6 +286,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.BLOOM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
@@ -311,6 +327,7 @@ class TensorNameMap:
             "gpt_neox.embed_in",                        # gptneox
             "transformer.wte",                          # gpt2 gpt-j mpt refact
             "transformer.word_embeddings",              # falcon
+            "word_embeddings",                          # bloom
             "model.embed_tokens",                       # llama-hf
             "tok_embeddings",                           # llama-pth
             "embeddings.word_embeddings",               # bert
@@ -322,6 +339,11 @@ class TensorNameMap:
             "embeddings.token_type_embeddings",  # bert
         ),
 
+        # Normalization of token embeddings
+        MODEL_TENSOR.TOKEN_EMBD_NORM: (
+            "word_embeddings_layernorm",  # bloom
+        ),
+
         # Position embeddings
         MODEL_TENSOR.POS_EMBD: (
             "transformer.wpe",                 # gpt2
@@ -332,7 +354,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                # gptneox
             "lm_head",                  # gpt2 mpt falcon llama-hf baichuan
-            "output",                   # llama-pth
+            "output",                   # llama-pth bloom
             "word_embeddings_for_head", # persimmon
         ),
 
@@ -344,7 +366,7 @@ class TensorNameMap:
             "norm",                                   # llama-pth
             "embeddings.LayerNorm",                   # bert
             "transformer.norm_f",                     # mpt
-            "ln_f",                                   # refact
+            "ln_f",                                   # refact bloom
             "language_model.encoder.final_layernorm", # persimmon
         ),
 
@@ -361,6 +383,7 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_1",                            # gpt2 gpt-j refact
             "transformer.blocks.{bid}.norm_1",                     # mpt
             "transformer.h.{bid}.input_layernorm",                 # falcon7b
+            "h.{bid}.input_layernorm",                             # bloom
             "transformer.h.{bid}.ln_mlp",                          # falcon40b
             "model.layers.{bid}.input_layernorm",                  # llama-hf
             "layers.{bid}.attention_norm",                         # llama-pth
@@ -379,6 +402,7 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.c_attn",                                    # gpt2
             "transformer.blocks.{bid}.attn.Wqkv",                                 # mpt
             "transformer.h.{bid}.self_attention.query_key_value",                 # falcon
+            "h.{bid}.self_attention.query_key_value",                             # bloom
             "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
         ),
 
@@ -412,6 +436,7 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.c_proj",                         # gpt2 refact
             "transformer.blocks.{bid}.attn.out_proj",                  # mpt
             "transformer.h.{bid}.self_attention.dense",                # falcon
+            "h.{bid}.self_attention.dense",                            # bloom
             "model.layers.{bid}.self_attn.o_proj",                     # llama-hf
             "layers.{bid}.attention.wo",                               # llama-pth
             "encoder.layer.{bid}.attention.output.dense",              # bert
@@ -429,6 +454,7 @@ class TensorNameMap:
         MODEL_TENSOR.FFN_NORM: (
             "gpt_neox.layers.{bid}.post_attention_layernorm",               # gptneox
             "transformer.h.{bid}.ln_2",                                     # gpt2 refact
+            "h.{bid}.post_attention_layernorm",                             # bloom
             "transformer.blocks.{bid}.norm_2",                              # mpt
             "model.layers.{bid}.post_attention_layernorm",                  # llama-hf
             "layers.{bid}.ffn_norm",                                        # llama-pth
@@ -442,6 +468,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.c_fc",                          # gpt2
             "transformer.blocks.{bid}.ffn.up_proj",                  # mpt
             "transformer.h.{bid}.mlp.dense_h_to_4h",                 # falcon
+            "h.{bid}.mlp.dense_h_to_4h",                             # bloom
             "model.layers.{bid}.mlp.up_proj",                        # llama-hf refact
             "layers.{bid}.feed_forward.w3",                          # llama-pth
             "encoder.layer.{bid}.intermediate.dense",                # bert
@@ -461,6 +488,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.c_proj",                        # gpt2 refact
             "transformer.blocks.{bid}.ffn.down_proj",                # mpt
             "transformer.h.{bid}.mlp.dense_4h_to_h",                 # falcon
+            "h.{bid}.mlp.dense_4h_to_h",                             # bloom
             "model.layers.{bid}.mlp.down_proj",                      # llama-hf
             "layers.{bid}.feed_forward.w2",                          # llama-pth
             "encoder.layer.{bid}.output.dense",                      # bert
index 3b63b64010b0f190941fe77a7677e590fb9c6467..4653c80232c5cd1bedcb4ae3816c5aacfcefe251 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -188,6 +188,7 @@ enum llm_arch {
     LLM_ARCH_STARCODER,
     LLM_ARCH_PERSIMMON,
     LLM_ARCH_REFACT,
+    LLM_ARCH_BLOOM,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -201,7 +202,8 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAICHUAN,        "baichuan"  },
     { LLM_ARCH_STARCODER,       "starcoder" },
     { LLM_ARCH_PERSIMMON,       "persimmon" },
-    { LLM_ARCH_REFACT,          "refact" },
+    { LLM_ARCH_REFACT,          "refact"    },
+    { LLM_ARCH_BLOOM,           "bloom"     },
 };
 
 enum llm_kv {
@@ -304,6 +306,7 @@ struct LLM_KV {
 
 enum llm_tensor {
     LLM_TENSOR_TOKEN_EMBD,
+    LLM_TENSOR_TOKEN_EMBD_NORM,
     LLM_TENSOR_POS_EMBD,
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
@@ -466,6 +469,21 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_BLOOM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1207,6 +1225,8 @@ struct llama_model {
 
     struct ggml_tensor * tok_embeddings;
     struct ggml_tensor * pos_embeddings;
+    struct ggml_tensor * tok_norm;
+    struct ggml_tensor * tok_norm_b;
 
     struct ggml_tensor * output_norm;
     struct ggml_tensor * output_norm_b;
@@ -2056,13 +2076,13 @@ static void llm_load_hparams(
                 }
             } break;
         case LLM_ARCH_PERSIMMON:
-        {
-            GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
-            switch (hparams.n_layer) {
-                case 36: model.type = e_model::MODEL_8B; break;
-                default: model.type = e_model::MODEL_UNKNOWN;
-            }
-        } break;
+            {
+                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+                switch (hparams.n_layer) {
+                    case 36: model.type = e_model::MODEL_8B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_REFACT:
             {
                 GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
@@ -2071,6 +2091,19 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_BLOOM:
+            {
+                GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
+
+                switch (hparams.n_layer) {
+                    case 24: model.type = e_model::MODEL_1B; break;
+                    case 30:
+                        switch (hparams.n_embd) {
+                            case 2560: model.type = e_model::MODEL_3B; break;
+                            case 4096: model.type = e_model::MODEL_7B; break;
+                        } break;
+                }
+            } break;
         case LLM_ARCH_MPT:
             {
                 hparams.f_clamp_kqv = 0.0f;
@@ -2676,6 +2709,88 @@ static void llm_load_tensors(
                         layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),   {64}, backend);
                     }
                 } break;
+            case LLM_ARCH_BLOOM:
+                {
+                    // TODO: CPU-only for now
+
+                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.tok_norm       = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd},          GGML_BACKEND_CPU);
+                    model.tok_norm_b     = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd},          GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend_type backend_norm;
+                        ggml_backend_type backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+                            // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+#else
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
+                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd},          backend_norm);
+                        model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                            vram_weights += ggml_nbytes(model.output_norm_b);
+                        }
+                        if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                            vram_weights += ggml_nbytes(model.output);
+                        }
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
+
+                    model.layers.resize(n_layer);
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, backend);
+                        layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, backend);
+
+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+                        layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa},         backend_split);
+
+                        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},                backend_split);
+                        layer.bo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd},                        backend_split);
+
+                        layer.ffn_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
+                        layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, backend);
+
+                        layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
+                        layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd},       backend_split);
+
+                        layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+                        layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff},           backend_split);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
+                                ggml_nbytes(layer.wqkv)      + ggml_nbytes(layer.bqkv)        +
+                                ggml_nbytes(layer.wo)        + ggml_nbytes(layer.bo)          +
+                                ggml_nbytes(layer.ffn_norm)  + ggml_nbytes(layer.ffn_norm_b)  +
+                                ggml_nbytes(layer.w3)        + ggml_nbytes(layer.b3)          +
+                                ggml_nbytes(layer.w2)        + ggml_nbytes(layer.b2);
+                        }
+                    }
+                } break;
             case LLM_ARCH_MPT:
                 {
                     model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
@@ -4996,6 +5111,248 @@ static struct ggml_cgraph * llm_build_persimmon(
     return gf;
 }
 
+static struct ggml_cgraph * llm_build_bloom(
+         llama_context & lctx,
+     const llama_batch & batch) {
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    const auto & kv_self = lctx.kv_self;
+
+    GGML_ASSERT(!!kv_self.ctx);
+
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_ctx       = cparams.n_ctx;
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_head_kv   = hparams.n_head_kv;
+    const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    const float norm_eps = hparams.f_norm_eps;
+
+    const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    auto & buf_compute = lctx.buf_compute;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute.size,
+        /*.mem_buffer =*/ buf_compute.data,
+        /*.no_alloc   =*/ false,
+    };
+
+    params.no_alloc = true;
+
+    struct ggml_context * ctx0 = ggml_init(params);
+
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * cur;
+    struct ggml_tensor * token;
+    struct ggml_tensor * inpL;
+
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+
+        ggml_allocr_alloc(lctx.alloc, inp_tokens);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
+        }
+        ggml_set_name(inp_tokens, "inp_tokens");
+
+        token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+    } else {
+#ifdef GGML_USE_MPI
+        GGML_ASSERT(false && "not implemented");
+#endif
+
+        token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
+
+        ggml_allocr_alloc(lctx.alloc, token);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token));
+        }
+    }
+
+    // KQ_scale
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+    ggml_allocr_alloc(lctx.alloc, KQ_scale);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    ggml_set_name(KQ_mask, "KQ_mask");
+    ggml_allocr_alloc(lctx.alloc, KQ_mask);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        float * data = (float *) KQ_mask->data;
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+    }
+
+    // norm
+    {
+        inpL = ggml_norm(ctx0, token, norm_eps);
+        inpL = ggml_add(ctx0, ggml_mul(ctx0, inpL, model.tok_norm), model.tok_norm_b);
+    }
+
+    ggml_set_name(inpL, "inpL");
+
+    for (int il = 0; il < n_layer; ++il) {
+        {
+            // Norm
+            cur = ggml_norm(ctx0, inpL, norm_eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
+        }
+
+        {
+            // Self Attention
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
+
+            struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd);
+            struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd);
+            struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa));
+
+            struct ggml_tensor * Qcur = tmpq;
+            struct ggml_tensor * Kcur = tmpk;
+
+            // store key and value to memory
+            {
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
+                ggml_set_name(Vcur, "Vcur");
+
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+                ggml_set_name(k, "k");
+
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+            }
+
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        ggml_cpy(ctx0,
+                            Qcur,
+                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)),
+                        0, 2, 1, 3);
+            ggml_set_name(Q, "Q");
+
+            struct ggml_tensor * K =
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_embd_head, n_kv, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            ggml_set_name(K, "K");
+
+            // K * Q
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            ggml_set_name(KQ, "KQ");
+
+            // KQ_scaled = KQ / sqrt(n_embd_head)
+            // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1]
+            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            ggml_set_name(KQ_scaled, "KQ_scaled");
+
+            struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ kv_head, n_head, 8);
+            ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
+
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
+            ggml_set_name(KQ_masked, "KQ_masked");
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            ggml_set_name(KQ_soft_max, "KQ_soft_max");
+
+            // split cached V into n_head heads
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_kv, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*n_ctx,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+            ggml_set_name(V, "V");
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            ggml_set_name(KQV, "KQV");
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            ggml_set_name(KQV_merged, "KQV_merged");
+
+            // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
+            cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+            ggml_set_name(cur, "KQV_merged_contiguous");
+        }
+
+        // Projection
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
+
+        // Add the input
+        cur = ggml_add(ctx0, cur, inpL);
+
+        struct ggml_tensor * inpFF = cur;
+
+        // FF
+        {
+            // Norm
+            {
+                cur = ggml_norm(ctx0, inpFF, norm_eps);
+                cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
+            }
+
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
+
+            // GELU activation
+            cur = ggml_gelu(ctx0, cur);
+
+            // Projection
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
+        }
+
+        inpL = ggml_add(ctx0, cur, inpFF);
+    }
+
+    // Output Norm
+    {
+        cur = ggml_norm(ctx0, inpL, norm_eps);
+        cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
+    }
+    ggml_set_name(cur, "result_norm");
+
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    ggml_set_name(cur, "result_output");
+
+    ggml_build_forward_expand(gf, cur);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
 static struct ggml_cgraph * llm_build_mpt(
          llama_context & lctx,
      const llama_batch & batch) {
@@ -5025,9 +5382,6 @@ static struct ggml_cgraph * llm_build_mpt(
     const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx            : kv_self.n;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
 
-    //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n",
-    //        kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift);
-
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
@@ -5348,6 +5702,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm_build_refact(lctx, batch);
             } break;
+        case LLM_ARCH_BLOOM:
+            {
+                result = llm_build_bloom(lctx, batch);
+            } break;
         case LLM_ARCH_MPT:
             {
                 result = llm_build_mpt(lctx, batch);
@@ -7579,8 +7937,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         const std::string name = ggml_get_name(meta);
 
         // TODO: avoid hardcoded tensor names - use the TN_* constants
-        if (name.find("attn_v.weight") != std::string::npos ||
-            name.find("attn_qkv.weight") != std::string::npos) {
+        if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) {
             ++n_attention_wv;
         }
         else if (name.find("ffn_down.weight") != std::string::npos) {