llama : add AWQ for llama, llama2, mpt, and mistral models (#4593)
author    Nam D. Tran <redacted>
Wed, 27 Dec 2023 15:39:45 +0000 (22:39 +0700)
committer GitHub <redacted>
Wed, 27 Dec 2023 15:39:45 +0000 (17:39 +0200)
* update: awq support llama-7b model

* update: change order

* update: benchmark results for llama2-7b

* update: mistral 7b v1 benchmark

* update: support 4 models

* fix: Readme

* update: ready for PR

* update: readme

* fix: readme

* update: change order import

* black

* format code

* update: work for both mpt and awq-mpt

* update: readme

* Rename to llm_build_ffn_mpt_awq

* Formatted other files

* Fixed params count

* fix: remove code

* update: more detail for mpt

* fix: readme

* fix: readme

* update: change folder architecture

* fix: common.cpp

* fix: readme

* fix: remove ggml_repeat

* update: cicd

* update: cicd

* update: remove use_awq arg

* update: readme

* llama : adapt plamo to new ffn

ggml-ci

---------

Co-authored-by: Trần Đức Nam <redacted>
Co-authored-by: Le Hoang Anh <redacted>
Co-authored-by: Georgi Gerganov <redacted>
awq-py/README.md [new file with mode: 0644]
awq-py/awq/apply_awq.py [new file with mode: 0644]
awq-py/requirements.txt [new file with mode: 0644]
convert-hf-to-gguf.py
convert.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
llama.cpp

diff --git a/awq-py/README.md b/awq-py/README.md
new file mode 100644 (file)
index 0000000..59354f4
--- /dev/null
@@ -0,0 +1,116 @@
+# AWQ: Activation-aware Weight Quantization for LLMs, adapted for llama.cpp
+[[Paper](https://arxiv.org/abs/2306.00978)][[Original Repo](https://github.com/mit-han-lab/llm-awq)][[Easy-to-use Repo](https://github.com/casper-hansen/AutoAWQ)]
+
+**Supported models:**
+
+- [X] LLaMA
+- [x] LLaMA 2
+- [X] MPT
+- [X] Mistral AI v0.1
+- [ ] Bloom
+- [ ] Mixtral MoE
+
+**TODO:**
+- [x] Update version to work with both MPT and MPT-AWQ models
+- [ ] Add OPT model
+- [ ] Add Bloom model
+- [ ] Add Mixtral MoE
+- [ ] Support w3, w2
+
+
+## Contents
+
+- [Install](#install)
+- [Convert](#convert)
+- [Quantize](#quantize)
+- [Test](#test)
+- [Benchmark](#benchmark)
+- [Results](#results)
+
+## Install
+Install the requirements:
+```bash
+pip install -r requirements.txt
+```
+Get the pre-computed AWQ search results for multiple model families, including LLaMA, LLaMA 2, MPT, and OPT:
+```bash
+git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
+```
+
+## Convert
+Example conversion commands:
+```bash
+# For llama7b and llama2 models
+python convert.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --outfile models/llama_7b_fp16.gguf
+# For mistral and mpt models
+python convert-hf-to-gguf.py models/mpt-7b/ --awq-path awq_cache/mpt-7b-w4-g128.pt --outfile models/mpt_7b_fp16.gguf
+```
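+
+Both conversion scripts apply the AWQ scales through `awq.apply_awq.add_scale_weights` before converting. A minimal sketch of calling it directly from Python (run with the `awq-py` directory on `PYTHONPATH`; the paths below are placeholders):
+```python
+from awq.apply_awq import add_scale_weights
+
+# Apply the pre-computed AWQ scales and clipping bounds to a HF checkpoint and
+# save the resulting "weighted" model; convert that directory as usual afterwards.
+add_scale_weights(
+    "models/llama-7b",                 # original HF model directory
+    "awq_cache/llama-7b-w4-g128.pt",   # pre-computed AWQ search results
+    "models/llama-7b/weighted_model",  # where the scaled model is written
+)
+```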
+
+## Quantize
+```bash
+# We only benchmark and confirm the results on q4_0, q4_1, and q2_k types.
+./quantize models/llama_7b_fp16.gguf models/llama_7b_q4_0.gguf q4_0
+```
+
+## Test
+```bash
+# For all models.
+./build/bin/main -m models/llama_7b_q4_0.gguf -n 128 --prompt "Once upon a time"
+```
+
+## Benchmark
+The perplexity measurements in the tables below are computed against the `wikitext2` test dataset (https://paperswithcode.com/dataset/wikitext-2), with a context length of 512.
+```bash
+# For llama and llama2, and mistral models.
+./perplexity -m models/llama_7b_q4_0.gguf -f datasets/wikitext-2-raw/wiki.test.raw
+```
+
+## Results
+Results were collected with OpenBLAS (CPU) and cuBLAS (GPU) builds for a fair comparison.
+We tested our version with three llama.cpp quantization types: q4_0, q4_1, and q2_k.
+
+### Llama 7B (Build with OpenBLAS)
+
+| Model      | Measure      | F16    | Q4_0   | Q4_1   | Q2_K   |
+|-----------:|--------------|-------:|-------:|-------:|-------:|
+|Llama 7B    | perplexity   | 5.9066 | 6.1214 | 6.0643 | 6.5808 |
+|Llama 7B    | file size    |  12.9G |   3.5G |   3.9G |   2.7G |
+|Llama 7B    | bits/weight  |   16.0 |    4.5 |    5.0 |    2.6 |
+|AWQ-Llama 7B| perplexity   | 5.9175 | 6.0252 | 5.9987 | 6.3692 |
+|AWQ-Llama 7B| file size    |  12.9G |   3.5G |   3.9G |   2.7G |
+|AWQ-Llama 7B| bits/weight  |   16.0 |    4.5 |    5.0 |    2.6 |
+
+
+### Llama2 7B (Build with cuBLAS)
+
+| Model       | Measure      | F16    | Q4_0   | Q4_1   | Q2_K   |
+|------------:|--------------|-------:|-------:|-------:|-------:|
+|Llama2 7B    | perplexity   | 5.8664 | 6.0260 | 6.0656 | 6.4496 |
+|Llama2 7B    | file size    |  12.9G |   3.5G |   3.9G |   2.7G |
+|Llama2 7B    | bits/weight  |   16.0 |    4.5 |    5.0 |    2.6 |
+|AWQ-Llama2 7B| perplexity   | 5.8801 | 6.0054 | 5.9849 | 6.3650 |
+|AWQ-Llama2 7B| file size    |  12.9G |   3.5G |   3.9G |   2.7G |
+|AWQ-Llama2 7B| bits/weight  |   16.0 |    4.5 |    5.0 |    2.6 |
+
+
+### Mistral 7B v0.1 (Build with cuBLAS)
+
+| Model        | Measure      | F16    | Q4_0   | Q4_1   | Q2_K   |
+|-------------:|--------------|-------:|-------:|-------:|-------:|
+|Mistral 7B    | perplexity   | 5.6931 | 5.8202 | 5.8268 | 6.1645 |
+|Mistral 7B    | file size    |  14.5G |   4.1G |   4.5G |   3.1G |
+|Mistral 7B    | bits/weight  |   16.0 |    4.5 |    5.0 |    2.6 |
+|AWQ-Mistral 7B| perplexity   | 5.6934 | 5.8020 | 5.7691 | 6.0426 |
+|AWQ-Mistral 7B| file size    |  14.5G |   4.1G |   4.5G |   3.1G |
+|AWQ-Mistral 7B| bits/weight  |   16.0 |    4.5 |    5.0 |    2.6 |
+
+### MPT 7B (Build with OpenBLAS)
+
+| Model    | Measure      | F16    | Q4_0   | Q4_1   | Q2_K    |
+|---------:|--------------|-------:|-------:|-------:|--------:|
+|MPT 7B    | perplexity   | 8.4369 | 8.7956 | 8.6265 | 11.4913 |
+|MPT 7B    | file size    |  13.7G |   3.9G |   4.3G |   2.8G  |
+|MPT 7B    | bits/weight  |   16.0 |    4.5 |    5.0 |    2.6  |
+|AWQ-MPT 7B| perplexity   | 8.4944 | 8.7053 | 8.6750 | 10.2873 |
+|AWQ-MPT 7B| file size    |  13.7G |   3.9G |   4.3G |   2.8G  |
+|AWQ-MPT 7B| bits/weight  |   16.0 |    4.5 |    5.0 |    2.6  |
diff --git a/awq-py/awq/apply_awq.py b/awq-py/awq/apply_awq.py
new file mode 100644 (file)
index 0000000..11132c5
--- /dev/null
@@ -0,0 +1,254 @@
+"""
+Implements AWQ (Activation-aware Weight Quantization) for llama.cpp use cases.
+Original paper: https://arxiv.org/abs/2306.00978
+
+This code is based on versions of the AWQ implementation found in the following repositories:
+* https://github.com/mit-han-lab/llm-awq
+* https://github.com/casper-hansen/AutoAWQ
+"""
+
+import os
+import torch
+import torch.nn as nn
+
+from transformers import AutoModelForCausalLM, AutoConfig
+from transformers.models.bloom.modeling_bloom import BloomGelu
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from transformers.activations import GELUActivation
+
+
+class ScaledActivation(nn.Module):
+    """
+    ScaledActivation module wraps an existing activation function and applies a
+    scale factor to its output.
+
+    Args:
+        module (nn.Module): The activation function to be scaled.
+        scales (torch.Tensor): A tensor of size (num_features,) containing the initial
+            scale factors for each feature.
+
+    Returns:
+        torch.Tensor: The scaled output of the activation function.
+    """
+
+    def __init__(self, module, scales):
+        super().__init__()
+        self.act = module
+        self.scales = nn.Parameter(scales.data)
+
+    def forward(self, x):
+        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
+
+
+def set_op_by_name(layer, name, new_module):
+    """
+    Replace the submodule of `layer` identified by `name` (dot notation) with `new_module`.
+
+    Args:
+        layer (nn.Module): The layer in which to replace the submodule.
+        name (str): The path to the submodule to be replaced, using dot notation
+            to access nested modules.
+        new_module (nn.Module): The new module to replace the existing one.
+    """
+    levels = name.split(".")
+    if len(levels) > 1:
+        mod_ = layer
+        for l_idx in range(len(levels) - 1):
+            if levels[l_idx].isdigit():
+                mod_ = mod_[int(levels[l_idx])]
+            else:
+                mod_ = getattr(mod_, levels[l_idx])
+        setattr(mod_, levels[-1], new_module)
+    else:
+        setattr(layer, name, new_module)
+
+
+def get_op_by_name(module, op_name):
+    """
+    Retrieves a submodule within a given layer based on its name.
+
+    Args:
+        module (nn.Module): The layer containing the submodule to find.
+        op_name (str): The name of the submodule.
+
+    Returns:
+        nn.Module: The requested submodule found within the given layer.
+
+    Raises:
+        ValueError: If the specified submodule cannot be found within the layer.
+    """
+    for name, m in module.named_modules():
+        if name == op_name:
+            return m
+    raise ValueError(f"Cannot find op {op_name} in module {module}")
+
+
+@torch.no_grad()
+def scale_ln_fcs(ln, fcs, scales):
+    """
+    Scales the weights of a LayerNorm and a list of fully-connected layers proportionally.
+
+    Args:
+        ln (nn.LayerNorm): The LayerNorm module to be scaled.
+        fcs (List[nn.Linear]): A list of fully-connected layers to be scaled.
+        scales (torch.Tensor): A 1D tensor of size (num_features,).
+    """
+
+    if not isinstance(fcs, list):
+        fcs = [fcs]
+
+    scales = scales.to(ln.weight.device)
+
+    ln.weight.div_(scales)
+    if hasattr(ln, "bias") and ln.bias is not None:
+        ln.bias.div_(scales)
+
+    for fc in fcs:
+        fc.weight.mul_(scales.view(1, -1))
+
+    for p in ln.parameters():
+        assert torch.isnan(p).sum() == 0
+    for fc in fcs:
+        for p in fc.parameters():
+            assert torch.isnan(p).sum() == 0
+
+
+@torch.no_grad()
+def scale_fc_fc(fc1, fc2, scales):
+    """
+    Scales two consecutive fully-connected layers: the last output rows of `fc1` are divided by `scales`, and the input columns of `fc2` are multiplied by `scales`.
+
+    Args:
+        fc1 (nn.Linear): The first fully-connected layer to be scaled.
+        fc2 (nn.Linear): The second fully-connected layer to be scaled.
+        scales (torch.Tensor): A 1D tensor of size (num_features,).
+    """
+    assert isinstance(fc1, nn.Linear)
+    assert isinstance(fc2, nn.Linear)
+
+    scales = scales.to(fc1.weight.device)
+
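+    # only the last scales.size(0) output rows of fc1 are scaled, so fused projections
+    # (e.g. a combined qkv weight, where the scales apply to the final slice) are handled correctly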
+    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
+    if fc1.bias is not None:
+        fc1.bias.div_(scales.view(-1))
+
+    fc2.weight.mul_(scales.view(1, -1))
+
+    for p in fc1.parameters():
+        assert torch.isnan(p).sum() == 0
+    for p in fc2.parameters():
+        assert torch.isnan(p).sum() == 0
+
+
+@torch.no_grad()
+def scale_gelu_fc(gelu, fc, scales):
+    """
+    Scales the weight of the fully-connected layer that follows a GELU activation (the activation output itself is rescaled via `ScaledActivation`).
+
+    Args:
+        gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled.
+        fc (nn.Linear): The fully-connected layer to be scaled.
+        scales (torch.Tensor): A 1D tensor of size (num_features,).
+
+    Raises:
+        AssertionError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`,
+            or if the `fc` module is not of type `nn.Linear`.
+    """
+    assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation))
+    assert isinstance(fc, nn.Linear)
+
+    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
+
+    for p in fc.parameters():
+        assert torch.isnan(p).sum() == 0
+
+
+def apply_scale(module, scales_list, input_feat_dict=None):
+    """
+    Applies different scaling strategies to layers based on their type and hierarchy within a given module.
+
+    Args:
+        module (nn.Module): The module containing the layers to be scaled.
+        scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing:
+            * prev_op_name (str): The name of the preceding operation or module,
+                relative to which the layers to be scaled are located.
+            * layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation.
+            * scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
+        input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding
+            input features (optional).
+    """
+    for prev_op_name, layer_names, scales in scales_list:
+        prev_op = get_op_by_name(module, prev_op_name)
+        layers = [get_op_by_name(module, name) for name in layer_names]
+
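+        # temporarily move the involved modules to the GPU for the in-place scaling;
+        # they are moved back to the CPU at the end of each iteration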
+        prev_op.cuda()
+        for layer in layers:
+            layer.cuda()
+        scales.cuda()
+
+        if isinstance(prev_op, nn.Linear):
+            assert len(layers) == 1
+            scale_fc_fc(prev_op, layers[0], scales)
+        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) or "rmsnorm" in str(prev_op.__class__).lower():
+            scale_ln_fcs(prev_op, layers, scales)
+        elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
+            new_module = ScaledActivation(prev_op, scales)
+            set_op_by_name(module, prev_op_name, new_module)
+            scale_gelu_fc(prev_op, layers[0], scales)
+        else:
+            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")
+
+        # apply the scaling to input feat if given; prepare it for clipping
+        if input_feat_dict is not None:
+            for layer_name in layer_names:
+                inp = input_feat_dict[layer_name]
+                inp.div_(scales.view(1, -1).to(inp.device))
+
+        prev_op.cpu()
+        for layer in layers:
+            layer.cpu()
+        scales.cpu()
+
+
+@torch.no_grad()
+def apply_clip(module, clip_list):
+    """
+    Applies element-wise clipping to the weight of a specific layer within a given module.
+
+    Args:
+        module (nn.Module): The module containing the layer to be clipped.
+        clip_list (List[Tuple[str, torch.Tensor]]): A list of tuples containing:
+            * name (str): The name of the layer to be clipped, relative to the root of the module.
+            * max_val (torch.Tensor): A 1D or 2D tensor defining the upper bound for each element of the layer's weight.
+    """
+    for name, max_val in clip_list:
+        layer = get_op_by_name(module, name)
+        layer.cuda()
+        max_val = max_val.to(layer.weight.device)
+        org_shape = layer.weight.shape
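+        # reshape the weight to match max_val's (rows, groups) layout, clamp element-wise,
+        # then restore the original shape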
+        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
+        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
+        layer.weight.data = layer.weight.data.reshape(org_shape)
+        layer.cpu()
+
+
+def add_scale_weights(model_path, scale_path, tmp_path):
+    """
+    Adds pre-computed Activation-aware Weight Quantization (AWQ) results to a model,
+    including scaling factors and clipping bounds.
+
+    Args:
+        model_path (str): Path to the pre-trained model to be equipped with AWQ.
+        scale_path (str): Path to the AWQ scale factors (.pt file).
+        tmp_path (str): Path to the temporary directory where the equipped model will be saved.
+    """
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, config=config, trust_remote_code=True
+    )
+    model.eval()
+    awq_results = torch.load(str(scale_path), map_location="cpu")
+    apply_scale(model, awq_results["scale"])
+    apply_clip(model, awq_results["clip"])
+    model.save_pretrained(str(tmp_path))
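+    # copy the tokenizer files next to the scaled weights so the conversion scripts can load them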
+    os.system(f"cp {str(model_path)}/tokenizer* {str(tmp_path)}")
diff --git a/awq-py/requirements.txt b/awq-py/requirements.txt
new file mode 100644 (file)
index 0000000..5fe6043
--- /dev/null
@@ -0,0 +1,2 @@
+torch>=2.0.0
+transformers>=4.32.0
diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 303d08170ecb04644f699226d42708ec6dabb487..7dbc2814796cc32b1e188e8828d6057287c0db67 100755 (executable)
@@ -46,7 +46,7 @@ class Model:
         self.part_names = self._get_part_names()
         self.hparams = Model.load_hparams(self.dir_model)
         self.model_arch = self._get_model_architecture()
-        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess)
+        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
 
     def set_vocab(self):
         self._set_vocab_gpt2()
@@ -59,7 +59,7 @@ class Model:
                 from safetensors import safe_open
                 ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
             else:
-                ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
+                ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", weights_only=True))
 
             with ctx as model_part:
                 for name in model_part.keys():
@@ -464,7 +464,11 @@ class MPTModel(Model):
             data = data_torch.squeeze().numpy()
 
             # map tensor names
-            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
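+            # AWQ-scaled models carry extra per-layer activation scale tensors ("scales");
+            # map them to the "act.scales" name expected by the GGUF writer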
+            if "scales" in name:
+                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
+                new_name = new_name.replace("scales", "act.scales")
+            else:
+                new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
                 print(f"Can not map tensor {name!r}")
                 sys.exit()
@@ -1095,6 +1099,9 @@ def parse_args() -> argparse.Namespace:
         "--vocab-only", action="store_true",
         help="extract only the vocab",
     )
+    parser.add_argument(
+        "--awq-path", type=Path, default=None,
+        help="Path to scale awq cache file")
     parser.add_argument(
         "--outfile", type=Path,
         help="path to write to; default: based on input",
@@ -1115,6 +1122,20 @@ def parse_args() -> argparse.Namespace:
 args = parse_args()
 
 dir_model = args.model
+
+if args.awq_path:
+    sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
+    from awq.apply_awq import add_scale_weights
+    tmp_model_path = args.model / "weighted_model"
+    dir_model = tmp_model_path
+    if tmp_model_path.is_dir():
+        print(f"{tmp_model_path} exists as a weighted model.")
+    else:
+        tmp_model_path.mkdir(parents=True, exist_ok=True)
+        print("Saving new weighted model ...")
+        add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
+        print(f"Saved weighted model at {tmp_model_path}.")
+
 if not dir_model.is_dir():
     print(f'Error: {args.model} is not a directory', file=sys.stderr)
     sys.exit(1)
diff --git a/convert.py b/convert.py
index 1f0c4f2f4197751660887be054447e87fb87a6a4..c3f3fc0a1fcd39ba1076574ea03192b17af125aa 100755 (executable)
@@ -1187,6 +1187,7 @@ def main(args_in: list[str] | None = None) -> None:
         # We currently only support Q8_0 output on little endian systems.
         output_choices.append("q8_0")
     parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
+    parser.add_argument("--awq-path",    type=Path,              help="Path to scale awq cache file", default=None)
     parser.add_argument("--dump",        action="store_true",    help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true",    help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only",  action="store_true",    help="extract only the vocab")
@@ -1200,6 +1201,19 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
 
     args = parser.parse_args(args_in)
+    if args.awq_path:
+        sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
+        from awq.apply_awq import add_scale_weights
+        tmp_model_path = args.model / "weighted_model"
+        if tmp_model_path.is_dir():
+            print(f"{tmp_model_path} exists as a weighted model.")
+        else:
+            tmp_model_path.mkdir(parents=True, exist_ok=True)
+            print("Saving new weighted model ...")
+            add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
+            print(f"Saved weighted model at {tmp_model_path}.")
+        args.model = tmp_model_path
+
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
         do_dump_model(model_plus)
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 4cd87cdda8b7e2947691b76776b9eddd41a784c1..c9be21119824c33c6ccc6811576ba01b1f6abf5c 100644 (file)
@@ -120,6 +120,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE        = auto()
     FFN_DOWN        = auto()
     FFN_UP          = auto()
+    FFN_ACT         = auto()
     FFN_GATE_EXP    = auto()
     FFN_DOWN_EXP    = auto()
     FFN_UP_EXP      = auto()
@@ -169,6 +170,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_GATE:        "blk.{bid}.ffn_gate",
     MODEL_TENSOR.FFN_DOWN:        "blk.{bid}.ffn_down",
     MODEL_TENSOR.FFN_UP:          "blk.{bid}.ffn_up",
+    MODEL_TENSOR.FFN_ACT:         "blk.{bid}.ffn",
     MODEL_TENSOR.FFN_GATE_EXP:    "blk.{bid}.ffn_gate.{xid}",
     MODEL_TENSOR.FFN_DOWN_EXP:    "blk.{bid}.ffn_down.{xid}",
     MODEL_TENSOR.FFN_UP_EXP:      "blk.{bid}.ffn_up.{xid}",
@@ -269,6 +271,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_ACT,
     ],
     MODEL_ARCH.GPTJ: [
         MODEL_TENSOR.TOKEN_EMBD,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 446c6b6883be9c5b6ef4ba4c21247a08008d5af2..0b8f704174e595165f2a40cf73b82b977d388211 100644 (file)
@@ -188,6 +188,11 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
         ),
 
+        # AWQ-activation gate
+        MODEL_TENSOR.FFN_ACT: (
+            "transformer.blocks.{bid}.ffn.act",  # mpt
+        ),
+
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
             "model.layers.{bid}.mlp.gate_proj",           # llama-hf refact
diff --git a/llama.cpp b/llama.cpp
index 4aa59c4c0bd8831a15f8dd2d3c3455c0c9674f16..bf1b01a90dcbe1b759f2958b258b28ba167985ed 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -354,6 +354,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_GATE,
     LLM_TENSOR_FFN_DOWN,
     LLM_TENSOR_FFN_UP,
+    LLM_TENSOR_FFN_ACT,
     LLM_TENSOR_FFN_DOWN_EXP,
     LLM_TENSOR_FFN_GATE_EXP,
     LLM_TENSOR_FFN_UP_EXP,
@@ -473,6 +474,7 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_ACT,         "blk.%d.ffn.act" },
         },
     },
     {
@@ -1285,6 +1287,7 @@ struct llama_hparams {
     float f_clamp_kqv;
     float f_max_alibi_bias;
 
+
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only    != other.vocab_only)    return true;
         if (this->n_vocab       != other.n_vocab)       return true;
@@ -1388,6 +1391,7 @@ struct llama_layer {
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2
     struct ggml_tensor * ffn_up_b;   // b3
+    struct ggml_tensor * ffn_act;
 };
 
 struct llama_kv_cell {
@@ -3471,7 +3475,6 @@ static bool llm_load_tensors(
             case LLM_ARCH_MPT:
                 {
                     model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
-
                     // output
                     {
                         ggml_backend_type backend_norm;
@@ -3509,6 +3512,9 @@ static bool llm_load_tensors(
 
                         layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
                         layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                        // AWQ ScaleActivation layer
+                        layer.ffn_act = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, backend, false);
                     }
                 } break;
             case LLM_ARCH_STABLELM:
@@ -4039,6 +4045,7 @@ static struct ggml_tensor * llm_build_ffn(
          struct ggml_tensor * gate_b,
          struct ggml_tensor * down,
          struct ggml_tensor * down_b,
+         struct ggml_tensor * act_scales,
             llm_ffn_op_type   type_op,
           llm_ffn_gate_type   type_gate,
          const llm_build_cb & cb,
@@ -4083,6 +4090,10 @@ static struct ggml_tensor * llm_build_ffn(
             {
                 cur = ggml_gelu(ctx, cur);
                 cb(cur, "ffn_gelu", il);
+                if (act_scales != NULL) {
+                    cur = ggml_div(ctx, cur, act_scales);
+                    cb(cur, "ffn_act", il);
+                }
             } break;
         case LLM_FFN_RELU:
             {
@@ -4401,6 +4412,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             } else {
@@ -4580,6 +4592,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -4694,6 +4707,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -4798,6 +4812,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5002,6 +5017,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
                         LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5088,6 +5104,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5183,6 +5200,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5268,11 +5286,11 @@ struct llm_build_context {
                         NULL,
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
-
                 cur = llm_build_ffn(ctx0, cur,
                         model.layers[il].ffn_up,   NULL,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, NULL,
+                        model.layers[il].ffn_act,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5381,6 +5399,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5493,6 +5512,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5600,6 +5620,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
                         NULL,                      NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(ffn_output, "ffn_out", il);
             }
@@ -5703,6 +5724,7 @@ struct llm_build_context {
                         model.layers[il].ffn_up, NULL,
                         model.layers[il].ffn_gate, NULL,
                         model.layers[il].ffn_down, NULL,
+                        NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
@@ -5887,6 +5909,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "ffn_gate",                   OFFLOAD_FUNC     },
     { "ffn_gate_b",                 OFFLOAD_FUNC     },
     { "ffn_gate_par",               OFFLOAD_FUNC     },
+    { "ffn_act",                    OFFLOAD_FUNC     },
     { "ffn_down",                   OFFLOAD_FUNC     },
     { "ffn_down_b",                 OFFLOAD_FUNC     },
     { "ffn_out",                    OFFLOAD_FUNC     },