]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add LoRA support (#820)
authorslaren <redacted>
Mon, 17 Apr 2023 15:28:55 +0000 (17:28 +0200)
committerGitHub <redacted>
Mon, 17 Apr 2023 15:28:55 +0000 (17:28 +0200)
convert-lora-to-ggml.py [new file with mode: 0644]
examples/common.cpp
examples/common.h
examples/main/main.cpp
examples/perplexity/perplexity.cpp
ggml.c
ggml.h
llama.cpp
llama.h
llama_util.h

diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py
new file mode 100644 (file)
index 0000000..8a2085c
--- /dev/null
@@ -0,0 +1,124 @@
+import json
+import os
+import re
+import struct
+import sys
+from typing import Any, Dict, Sequence, TextIO
+
+import torch
+
+from convert import DATA_TYPE_TO_FTYPE, NUMPY_TYPE_TO_DATA_TYPE, DataType
+
+HF_SUBLAYER_TO_GGML = {
+    "self_attn.q_proj": "attention.wq",
+    "self_attn.k_proj": "attention.wk",
+    "self_attn.v_proj": "attention.wv",
+    "self_attn.o_proj": "attention.wo",
+    "mlp.gate_proj": "feed_forward.w1",
+    "mlp.down_proj": "feed_forward.w2",
+    "mlp.up_proj": "feed_forward.w3",
+    "input_layernorm": "attention_norm",
+    "post_attention_layernorm": "ffn_norm",
+    # "norm": "norm",
+    # "embed_tokens": "tok_embeddings",
+    # "lm_head": "output",
+}
+
+
+def translate_tensor_name(t: str) -> str:
+    match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
+    if match:
+        nn = match.group(1)
+        sub_layer = match.group(2)
+        lora_type = match.group(3)
+
+        sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
+        if sub_layer_renamed is None:
+            print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
+            sys.exit(1)
+
+        output_string = (
+            f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}"
+        )
+        return output_string
+    else:
+        print(f"Error: unrecognized tensor {t}")
+        sys.exit(1)
+
+
+def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:
+    fout.write(b"ggla"[::-1])  # magic (ggml lora)
+    fout.write(struct.pack("i", 1))  # file version
+    fout.write(struct.pack("ii", params["r"], params["lora_alpha"]))
+
+
+def write_tensor_header(
+    self, name: str, shape: Sequence[int], data_type: DataType
+) -> None:
+    sname = name.encode("utf-8")
+    fout.write(
+        struct.pack(
+            "iii",
+            len(shape),
+            len(sname),
+            DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]],
+        )
+    )
+    fout.write(struct.pack("i" * len(shape), *shape[::-1]))
+    fout.write(sname)
+    fout.seek((fout.tell() + 31) & -32)
+
+
+if len(sys.argv) != 2:
+    print(f"Usage: python {sys.argv[0]} <path>")
+    print(
+        "Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
+    )
+    sys.exit(1)
+
+input_json = os.path.join(sys.argv[1], "adapter_config.json")
+input_model = os.path.join(sys.argv[1], "adapter_model.bin")
+output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
+
+model = torch.load(input_model, map_location="cpu")
+
+with open(input_json, "r") as f:
+    params = json.load(f)
+
+if params["peft_type"] != "LORA":
+    print(f"Error: unsupported adapter type {params['peft_type']}, expected LORA")
+    sys.exit(1)
+
+if params["fan_in_fan_out"] == True:
+    print("Error: param fan_in_fan_out is not supported")
+    sys.exit(1)
+
+if params["bias"] is not None and params["bias"] != "none":
+    print("Error: param bias is not supported")
+    sys.exit(1)
+
+# TODO: these seem to be layers that have been trained but without lora.
+# doesn't seem widely used but eventually should be supported
+if params["modules_to_save"] is not None and len(params["modules_to_save"]) > 0:
+    print("Error: param modules_to_save is not supported")
+    sys.exit(1)
+
+with open(output_path, "wb") as fout:
+    fout.truncate()
+
+    write_file_header(fout, params)
+    for k, v in model.items():
+        if k.endswith("lora_A.weight"):
+            if v.dtype != torch.float16 and v.dtype != torch.float32:
+                v = v.float()
+            v = v.T
+        else:
+            v = v.float()
+
+        t = v.numpy()
+        tname = translate_tensor_name(k)
+        print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
+        write_tensor_header(fout, tname, t.shape, t.dtype)
+        t.tofile(fout)
+
+print(f"Converted {input_json} and {input_model} to {output_path}")
index 0772dbfe142ffe167f2c3e5ffa3815f6488af04e..a0b6f10ad8c8bf34e968757e5c895afb67c7f581 100644 (file)
@@ -139,6 +139,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "--lora") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter = argv[i];
+            params.use_mmap = false;
+        } else if (arg == "--lora-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_base = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--embedding") {
@@ -242,6 +255,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     }
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
+    fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
+    fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
index 1ea6f74451811130a378892f5f5a0ff6927633be..cbbc2dfab16de11969f67e0fcfd8cd15b7f5e62b 100644 (file)
@@ -31,11 +31,12 @@ struct gpt_params {
 
     std::string model  = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string input_prefix = ""; // string to prefix user inputs with
-
-
+    std::string input_prefix = "";       // string to prefix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
+    std::string lora_adapter = "";  // lora adapter path
+    std::string lora_base = "";     // base model path for the lora adapter
+
     bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
index 3e4b0034ee977bbe13e8fb02edb23f355be90ea8..b7b3c419655f663bba703d27b9c0f5295dd2d6aa 100644 (file)
@@ -114,6 +114,17 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");
index 19449e16e4d54785bd9d498f493077d8b4578028..80792ea0d95d0a7932e64492d8d457597b171cb0 100644 (file)
@@ -134,6 +134,17 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx,
+                                             params.lora_adapter.c_str(),
+                                             params.lora_base.empty() ? NULL : params.lora_base.c_str(),
+                                             params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");
diff --git a/ggml.c b/ggml.c
index 995a2faabac8481e5592fa00f2ea64e7ca776c65..acdba0333c8ea536a2f6467b7943f31a044a8759 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1420,6 +1420,34 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 #endif
 }
 
+static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .quantize_row_q_dot       = quantize_row_q4_1,
+        .vec_dot_q                = ggml_vec_dot_q4_1,
+    },
+    // TODO: GGML_TYPE_Q8_0
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
+
 //
 // simd mappings
 //
@@ -5588,6 +5616,26 @@ static void ggml_compute_forward_dup_f16(
                         }
                     }
                 }
+            } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+                size_t id = 0;
+                uint8_t * dst_ptr = (uint8_t *) dst->data;
+                size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                float * src0_f32 = (float *) params->wdata;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            // convert to f32 and quantize
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                            }
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += dst_row_size;
+                        }
+                    }
+                }
             } else {
                 GGML_ASSERT(false); // TODO: implement
             }
@@ -5780,6 +5828,21 @@ static void ggml_compute_forward_dup_f32(
                         }
                     }
                 }
+            } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
+                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+                size_t id = 0;
+                uint8_t * dst_ptr = (uint8_t *) dst->data;
+                size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            id += dst_row_size;
+                        }
+                    }
+                }
             } else {
                 GGML_ASSERT(false); // TODO: implement
             }
@@ -5968,6 +6031,212 @@ static void ggml_compute_forward_add_f32(
     }
 }
 
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    //const int64_t ne0  = dst->ne[0];
+    //const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0  = dst->nb[0];
+    const int nb1  = dst->nb[1];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float*) params->wdata + ne00 * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5978,6 +6247,23 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
         default:
             {
                 GGML_ASSERT(false);
@@ -7257,30 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_0,
-        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot       = quantize_row_q4_1,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
-    },
-    // TODO: GGML_TYPE_Q8_0
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10137,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             struct ggml_tensor * node = cgraph->nodes[i];
 
             switch (node->op) {
+                case GGML_OP_CPY:
                 case GGML_OP_DUP:
                     {
                         node->n_tasks = 1;
+
+                        size_t cur = 0;
+                        if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
+                        }
+
+                        work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_ADD:
                     {
                         node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
+                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
+                        }
+
+                        work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_SUB:
                 case GGML_OP_MUL:
@@ -10224,7 +10502,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
-                case GGML_OP_CPY:
                 case GGML_OP_CONT:
                 case GGML_OP_RESHAPE:
                 case GGML_OP_VIEW:
diff --git a/ggml.h b/ggml.h
index e693754c917fc3c202388ec987a71363302d63cc..59de0cb12572ab3819db575657cb41b53503f7b2 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -430,6 +430,12 @@ struct ggml_tensor * ggml_add(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b);
 
+
+struct ggml_tensor * ggml_add_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b);
+
 struct ggml_tensor * ggml_sub(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
index cdd8bd16bc039996d18efeea46fc73a16e59f5a9..db71c03bdd4a1fddbaaec6ebb5eedd8cb204f08e 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1,6 +1,8 @@
 // Defines fileno on msys:
 #ifndef _GNU_SOURCE
 #define _GNU_SOURCE
+#include <cstdint>
+#include <cstdio>
 #endif
 
 #include "llama_util.h"
@@ -633,6 +635,7 @@ struct llama_model_loader {
             throw format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s",
                          name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
         }
+
         return get_tensor_for(lt);
     }
 
@@ -1774,6 +1777,254 @@ int llama_model_quantize(
     }
 }
 
+int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
+    fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
+
+    auto & model = ctx->model;
+
+    const int64_t t_start_lora_us = ggml_time_us();
+
+    auto fin = std::ifstream(path_lora, std::ios::binary);
+    if (!fin) {
+        fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora);
+        return 1;
+    }
+
+    // verify magic and version
+    {
+        uint32_t magic;
+        fin.read((char *) &magic, sizeof(magic));
+        if (magic != 'ggla') {
+            fprintf(stderr, "%s: bad file magic\n", __func__);
+            return 1;
+        }
+        uint32_t format_version;
+        fin.read((char *) &format_version, sizeof(format_version));
+
+        if (format_version != 1) {
+            fprintf(stderr, "%s: unsupported file version\n", __func__ );
+            return 1;
+        }
+    }
+
+    int32_t lora_r;
+    int32_t lora_alpha;
+    fin.read((char *) &lora_r, sizeof(lora_r));
+    fin.read((char *) &lora_alpha, sizeof(lora_alpha));
+    float scaling = (float)lora_alpha / (float)lora_r;
+
+    fprintf(stderr, "%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
+
+
+    // create a temporary ggml context to store the lora tensors
+    // todo: calculate size from biggest possible tensor
+    std::vector<uint8_t> lora_buf(1024ull * 1024ull * 1024ull);
+    struct ggml_init_params params;
+    params.mem_size   = lora_buf.size();
+    params.mem_buffer = lora_buf.data();
+    params.no_alloc   = false;
+
+    ggml_context * lora_ctx = ggml_init(params);
+    std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
+
+    // create a name -> tensor map of the model to accelerate lookups
+    std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
+    for (auto & kv: model.tensors_by_name) {
+        model_tensors.insert(kv);
+    }
+
+
+    // load base model
+    std::unique_ptr<llama_model_loader> model_loader;
+    ggml_context * base_ctx = NULL;
+    llama_buffer base_buf;
+    if (path_base_model) {
+        fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model);
+        model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false));
+
+        size_t ctx_size, mmapped_size;
+        model_loader->calc_sizes(&ctx_size, &mmapped_size);
+        base_buf.resize(ctx_size);
+
+        ggml_init_params base_params;
+        base_params.mem_size   = base_buf.size;
+        base_params.mem_buffer = base_buf.addr;
+        base_params.no_alloc   = model_loader->use_mmap;
+
+        base_ctx = ggml_init(base_params);
+
+        model_loader->ggml_ctx = base_ctx;
+
+        // maybe this should in llama_model_loader
+        if (model_loader->use_mmap) {
+            model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ false));
+        }
+    }
+
+    // read tensors and apply
+    bool warned = false;
+    int n_tensors = 0;
+    while (true) {
+        int32_t n_dims;
+        int32_t length;
+        int32_t ftype;
+
+        fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+        fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+        fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+        if (fin.eof()) {
+            break;
+        }
+
+        int32_t ne[2] = { 1, 1 };
+        for (int i = 0; i < n_dims; ++i) {
+            fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+        }
+
+        std::string name(length, 0);
+        fin.read(&name[0], length);
+
+        // check for lora suffix and get the type of tensor
+        const std::string lora_suffix = ".lora";
+        size_t pos = name.rfind(lora_suffix);
+        if (pos == std::string::npos) {
+            fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
+            return 1;
+        }
+
+        std::string lora_type = name.substr(pos + lora_suffix.length());
+        std::string base_name = name;
+        base_name.erase(pos);
+        // fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str());
+
+        if (model_tensors.find(base_name.data()) == model_tensors.end()) {
+            fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
+            return 1;
+        }
+
+        // create ggml tensor
+        ggml_type wtype;
+        switch (ftype) {
+            case 0: wtype = GGML_TYPE_F32;  break;
+            case 1: wtype = GGML_TYPE_F16;  break;
+            default:
+                    {
+                        fprintf(stderr, "%s: invalid tensor data type '%d'\n",
+                                __func__, ftype);
+                        return false;
+                    }
+        }
+        ggml_tensor* lora_tensor;
+        if (n_dims == 2) {
+            lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
+        }
+        else {
+            fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
+            return 1;
+        }
+
+        // load tensor data
+        size_t offset = fin.tellg();
+        size_t tensor_data_size = ggml_nbytes(lora_tensor);
+        offset = (offset + 31) & -32;
+        fin.seekg(offset);
+        fin.read((char*)lora_tensor->data, tensor_data_size);
+
+        lora_tensors[name] = lora_tensor;
+
+        // check if we have both A and B tensors and apply
+        if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
+            lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
+
+            ggml_tensor * dest_t = model_tensors[base_name];
+            ggml_tensor * base_t;
+            if (model_loader) {
+                // load from base model
+                if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) {
+                    fprintf(stderr, "%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
+                    return 1;
+                }
+                size_t idx = model_loader->tensors_map.name_to_idx[base_name];
+                llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
+                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
+                lt.data = (uint8_t *) lt.ggml_tensor->data;
+                model_loader->load_data_for(lt);
+                lt.ggml_tensor->data = lt.data;
+            }
+            else {
+                base_t = dest_t;
+            }
+
+            if (base_t->type == GGML_TYPE_Q4_0 || base_t->type == GGML_TYPE_Q4_1) {
+                if (!warned) {
+                    fprintf(stderr, "%s: warning: using a lora adapter with a quantized model may result in poor quality, "
+                                    "use a f16 or f32 base model with --lora-base\n", __func__);
+                    warned = true;
+                }
+            }
+
+            ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
+            ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
+
+            if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
+                fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
+                               " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
+                return 1;
+            }
+
+            // w = w + BA*s
+            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
+
+            if (scaling != 1.0f) {
+                ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
+                BA = ggml_scale(lora_ctx, BA, scale_tensor);
+            }
+
+            ggml_tensor * r;
+            if (base_t == dest_t) {
+                r = ggml_add_inplace(lora_ctx, dest_t, BA);
+            }
+            else {
+                r = ggml_add(lora_ctx, base_t, BA);
+                r = ggml_cpy(lora_ctx, r, dest_t);
+            }
+
+            struct ggml_cgraph gf = ggml_build_forward(r);
+            gf.n_threads = n_threads;
+            ggml_graph_compute(lora_ctx, &gf);
+
+            // we won't need these tensors again, reset the context to save memory
+            ggml_free(lora_ctx);
+            lora_ctx = ggml_init(params);
+            lora_tensors.clear();
+
+            n_tensors++;
+            if (n_tensors % 4 == 0)
+                fprintf(stderr, ".");
+        }
+    }
+
+    // TODO: this should be in a destructor, it will leak on failure
+    ggml_free(lora_ctx);
+    if (base_ctx) {
+        ggml_free(base_ctx);
+    }
+
+    const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
+    fprintf(stderr, " done (%.2f ms)\n", t_lora_us / 1000.0);
+
+    return 0;
+}
+
+int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
+    try {
+        return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads);
+    } catch (const std::string & err) {
+        fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str());
+        return 1;
+    }
+}
+
 // Returns the KV cache that will contain the context for the
 // ongoing prediction with the model.
 const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
diff --git a/llama.h b/llama.h
index 19221759376858f65b3de856d22f795dc2f2620c..c35193a8a80de466b763a89a64e62675937e5b33 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -96,6 +96,18 @@ extern "C" {
             const char * fname_out,
       enum llama_ftype   ftype);
 
+    // Apply a LoRA adapter to a loaded model
+    // path_base_model is the path to a higher quality model to use as a base for
+    // the layers modified by the adapter. Can be NULL to use the current loaded model.
+    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
+    // will be applied on top of the previous one
+    // Returns 0 on success
+    LLAMA_API int llama_apply_lora_from_file(
+            struct llama_context * ctx,
+                      const char * path_lora,
+                      const char * path_base_model,
+                             int   n_threads);
+
     // Returns the KV cache that will contain the context for the
     // ongoing prediction with the model.
     LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
index c92c0cc7148d86ffb7817bc06fad43ad29225ed2..ee9b2a6f366b18590e43237c62a7cc93fc4682ba 100755 (executable)
@@ -168,7 +168,7 @@ struct llama_mmap {
 #ifdef _POSIX_MAPPED_FILES
     static constexpr bool SUPPORTED = true;
 
-    llama_mmap(struct llama_file * file) {
+    llama_mmap(struct llama_file * file, bool prefetch = true) {
         size = file->size;
         int fd = fileno(file->fp);
         int flags = MAP_SHARED;
@@ -180,10 +180,12 @@ struct llama_mmap {
             throw format("mmap failed: %s", strerror(errno));
         }
 
-        // Advise the kernel to preload the mapped memory
-        if (madvise(addr, file->size, MADV_WILLNEED)) {
-            fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
-                    strerror(errno));
+        if (prefetch) {
+            // Advise the kernel to preload the mapped memory
+            if (madvise(addr, file->size, MADV_WILLNEED)) {
+                fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
+                        strerror(errno));
+            }
         }
     }
 
@@ -193,7 +195,7 @@ struct llama_mmap {
 #elif defined(_WIN32)
     static constexpr bool SUPPORTED = true;
 
-    llama_mmap(struct llama_file * file) {
+    llama_mmap(struct llama_file * file, bool prefetch = true) {
         size = file->size;
 
         HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
@@ -215,13 +217,15 @@ struct llama_mmap {
         }
 
         #if _WIN32_WINNT >= _WIN32_WINNT_WIN8
-        // Advise the kernel to preload the mapped memory
-        WIN32_MEMORY_RANGE_ENTRY range;
-        range.VirtualAddress = addr;
-        range.NumberOfBytes = (SIZE_T)size;
-        if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
-            fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
+        if (prefetch) {
+            // Advise the kernel to preload the mapped memory
+            WIN32_MEMORY_RANGE_ENTRY range;
+            range.VirtualAddress = addr;
+            range.NumberOfBytes = (SIZE_T)size;
+            if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
+                fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
+                        llama_format_win_err(GetLastError()).c_str());
+            }
         }
         #else
         #pragma message("warning: You are building for pre-Windows 8; prefetch not supported")