]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Support all LLaMA models + change Q4_0 quantization storage
authorGeorgi Gerganov <redacted>
Sat, 11 Mar 2023 08:47:09 +0000 (10:47 +0200)
committerGeorgi Gerganov <redacted>
Sat, 11 Mar 2023 09:28:30 +0000 (11:28 +0200)
README.md
convert-pth-to-ggml.py
ggml.c
main.cpp
utils.cpp

index afa0e7e9089b424726facfecc840d1bdd95e0750..02b0c09298cff077d9149c4642e75459ab6338cf 100644 (file)
--- a/README.md
+++ b/README.md
@@ -17,12 +17,11 @@ The main goal is to run the model using 4-bit quantization on a MacBook.
 
 This was hacked in an evening - I have no idea if it works correctly.
 
-So far, I've tested just the 7B model.
-Here is a typical run:
+Here is a typical run using LLaMA-7B:
 
 ```java
-make -j && ./main -m ../LLaMA-4bit/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -t 8 -n 512
-I llama.cpp build info: 
+make -j && ./main -m ./models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -t 8 -n 512
+I llama.cpp build info:
 I UNAME_S:  Darwin
 I UNAME_P:  arm
 I UNAME_M:  arm64
@@ -34,7 +33,7 @@ I CXX:      Apple clang version 14.0.0 (clang-1400.0.29.202)
 
 make: Nothing to be done for `default'.
 main: seed = 1678486056
-llama_model_load: loading model from '../LLaMA-4bit/7B/ggml-model-q4_0.bin' - please wait ...
+llama_model_load: loading model from './models/7B/ggml-model-q4_0.bin' - please wait ...
 llama_model_load: n_vocab = 32000
 llama_model_load: n_ctx   = 512
 llama_model_load: n_embd  = 4096
@@ -110,6 +109,8 @@ https://user-images.githubusercontent.com/1991296/224442907-7693d4be-acaa-4e01-8
 
 ## Usage
 
+Here are the step for the LLaMA-7B model:
+
 ```bash
 # build this repo
 git clone https://github.com/ggerganov/llama.cpp
@@ -133,9 +134,40 @@ python3 convert-pth-to-ggml.py models/7B/ 1
 ./main -m ./models/7B/ggml-model-q4_0.bin -t 8 -n 128
 ```
 
+For the bigger models, there are a few extra quantization steps. For example, for LLaMA-13B, converting to FP16 format
+will create 2 ggml files, instead of one:
+
+```bash
+ggml-model-f16.bin
+ggml-model-f16.bin.1
+```
+
+You need to quantize each of them separately like this:
+
+```bash
+./quantize ./models/13B/ggml-model-f16.bin   ./models/13B/ggml-model-q4_0.bin 2
+./quantize ./models/13B/ggml-model-f16.bin.1 ./models/13B/ggml-model-q4_0.bin.1 2
+```
+
+Everything else is the same. Simply run:
+
+```bash
+./main -m ./models/13B/ggml-model-q4_0.bin -t 8 -n 128
+```
+
+The number of files generated for each model is as follows:
+
+```
+7B  -> 1 file
+13B -> 2 files
+33B -> 4 files
+65B -> 8 files
+```
+
+When running the larger models, make sure you have enough disk space to store all the intermediate files.
+
 ## Limitations
 
-- Currently, only LLaMA-7B is supported since I haven't figured out how to merge the tensors of the bigger models. However, in theory, you should be able to run 65B on a 64GB MacBook
 - Not sure if my tokenizer is correct. There are a few places where we might have a mistake:
   - https://github.com/ggerganov/llama.cpp/blob/26c084662903ddaca19bef982831bfb0856e8257/convert-pth-to-ggml.py#L79-L87
   - https://github.com/ggerganov/llama.cpp/blob/26c084662903ddaca19bef982831bfb0856e8257/utils.h#L65-L69
index bd0a9d0898566d3125b3a2e3fc844fd22c99f4c9..fc217c7ec3b328a613be93efa57fc793f282e641 100644 (file)
@@ -33,12 +33,23 @@ if len(sys.argv) < 3:
 
 # output in the same directory as the model
 dir_model = sys.argv[1]
-fname_out = sys.argv[1] + "/ggml-model.bin"
 
 fname_hparams   = sys.argv[1] + "/params.json"
-fname_model     = sys.argv[1] + "/consolidated.00.pth"
 fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
 
+def get_n_parts(dim):
+    if dim == 4096:
+        return 1
+    elif dim == 5120:
+        return 2
+    elif dim == 6656:
+        return 4
+    elif dim == 8192:
+        return 8
+    else:
+        print("Invalid dim: " + str(dim))
+        sys.exit(1)
+
 # possible data types
 #   ftype == 0 -> float32
 #   ftype == 1 -> float16
@@ -61,76 +72,91 @@ tokenizer = SentencePieceProcessor(fname_tokenizer)
 
 hparams.update({"vocab_size": tokenizer.vocab_size()})
 
+n_parts = get_n_parts(hparams["dim"])
+
 print(hparams)
+print('n_parts = ', n_parts)
 
-model = torch.load(fname_model, map_location="cpu")
-
-fout = open(fname_out, "wb")
-
-fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
-fout.write(struct.pack("i", hparams["vocab_size"]))
-fout.write(struct.pack("i", hparams["dim"]))
-fout.write(struct.pack("i", hparams["multiple_of"]))
-fout.write(struct.pack("i", hparams["n_heads"]))
-fout.write(struct.pack("i", hparams["n_layers"]))
-fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
-fout.write(struct.pack("i", ftype))
-
-# Is this correct??
-for i in range(32000):
-    # TODO: this is probably wrong - not sure how this tokenizer works
-    text = tokenizer.decode([29889, i]).encode('utf-8')
-    # remove the first byte (it's always '.')
-    text = text[1:]
-    fout.write(struct.pack("i", len(text)))
-    fout.write(text)
-
-for k, v in model.items():
-    name = k
-    shape = v.shape
-
-    # skip layers.X.attention.inner_attention.rope.freqs
-    if name[-5:] == "freqs":
-        continue
-
-    print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
-
-    #data = tf.train.load_variable(dir_model, name).squeeze()
-    data = v.numpy().squeeze()
-    n_dims = len(data.shape);
-
-    # for efficiency - transpose some matrices
-    # "model/h.*/attn/c_attn/w"
-    # "model/h.*/attn/c_proj/w"
-    # "model/h.*/mlp/c_fc/w"
-    # "model/h.*/mlp/c_proj/w"
-    #if name[-14:] == "/attn/c_attn/w" or \
-    #   name[-14:] == "/attn/c_proj/w" or \
-    #   name[-11:] == "/mlp/c_fc/w" or \
-    #   name[-13:] == "/mlp/c_proj/w":
-    #    print("  Transposing")
-    #    data = data.transpose()
-
-    dshape = data.shape
-
-    # default type is fp16
-    ftype_cur = 1
-    if ftype == 0 or n_dims == 1:
-        print("  Converting to float32")
-        data = data.astype(np.float32)
-        ftype_cur = 0
-
-    # header
-    str = name.encode('utf-8')
-    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
-    for i in range(n_dims):
-        fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-    fout.write(str);
-
-    # data
-    data.tofile(fout)
-
-fout.close()
-
-print("Done. Output file: " + fname_out)
-print("")
+for p in range(n_parts):
+    print('Processing part ', p)
+
+    #fname_model = sys.argv[1] + "/consolidated.00.pth"
+    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
+    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
+    if (p > 0):
+        fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
+
+    model = torch.load(fname_model, map_location="cpu")
+
+    fout = open(fname_out, "wb")
+
+    fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
+    fout.write(struct.pack("i", hparams["vocab_size"]))
+    fout.write(struct.pack("i", hparams["dim"]))
+    fout.write(struct.pack("i", hparams["multiple_of"]))
+    fout.write(struct.pack("i", hparams["n_heads"]))
+    fout.write(struct.pack("i", hparams["n_layers"]))
+    fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
+    fout.write(struct.pack("i", ftype))
+
+    # Is this correct??
+    for i in range(32000):
+        # TODO: this is probably wrong - not sure how this tokenizer works
+        text = tokenizer.decode([29889, i]).encode('utf-8')
+        # remove the first byte (it's always '.')
+        text = text[1:]
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+
+    for k, v in model.items():
+        name = k
+        shape = v.shape
+
+        # skip layers.X.attention.inner_attention.rope.freqs
+        if name[-5:] == "freqs":
+            continue
+
+        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
+
+        #data = tf.train.load_variable(dir_model, name).squeeze()
+        data = v.numpy().squeeze()
+        n_dims = len(data.shape);
+
+        # for efficiency - transpose some matrices
+        # "model/h.*/attn/c_attn/w"
+        # "model/h.*/attn/c_proj/w"
+        # "model/h.*/mlp/c_fc/w"
+        # "model/h.*/mlp/c_proj/w"
+        #if name[-14:] == "/attn/c_attn/w" or \
+        #   name[-14:] == "/attn/c_proj/w" or \
+        #   name[-11:] == "/mlp/c_fc/w" or \
+        #   name[-13:] == "/mlp/c_proj/w":
+        #    print("  Transposing")
+        #    data = data.transpose()
+
+        dshape = data.shape
+
+        # default type is fp16
+        ftype_cur = 1
+        if ftype == 0 or n_dims == 1:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+
+        # header
+        sname = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
+        fout.write(sname);
+
+        # data
+        data.tofile(fout)
+
+    # I hope this deallocates the memory ..
+    model = None
+
+    fout.close()
+
+    print("Done. Output file: " + fname_out + ", (part ", p, ")")
+    print("")
diff --git a/ggml.c b/ggml.c
index ee3b0af0291a682884e6419e1c32d2b3e8fe5ea7..bb714e2bc270d7f35a7457dfc5b03ea75baecb54 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -366,9 +366,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
     assert(k % QK == 0);
 
     const int nb = k / QK;
+    const size_t bs = sizeof(float) + QK/2;
 
-    float   * restrict pd = (float *)   (y);
-    uint8_t * restrict pb = (uint8_t *) (pd + nb);
+    uint8_t * restrict pd = (uint8_t *) (y + 0*bs);
+    uint8_t * restrict pb = (uint8_t *) (y + 0*bs + sizeof(float));
 
     uint8_t pp[QK/2];
 
@@ -395,7 +396,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         const float d = amax / ((1 << 3) - 1);
         const float id = d ? 1.0/d : 0.0;
 
-        pd[i] = d;
+        *(float *)pd = d;
+        pd += bs;
 
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
@@ -406,7 +408,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
             pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
         }
 
-        memcpy(pb + i*16, pp, sizeof(pp));
+        memcpy(pb, pp, sizeof(pp));
+        pb += bs;
     }
 #else
 #error "not implemented for QK"
@@ -434,7 +437,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         const float d = amax / ((1 << 3) - 1);
         const float id = d ? 1.0/d : 0.0;
 
-        pd[i] = d;
+        *(float *)pd = d;
+        pd += bs;
 
         for (int l = 0; l < 8; l++) {
             const v128_t v  = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
@@ -445,7 +449,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
             pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
         }
 
-        memcpy(pb + i*16, pp, sizeof(pp));
+        memcpy(pb, pp, sizeof(pp));
+        pb += bs;
     }
 #else
 #error "not implemented for QK"
@@ -463,7 +468,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         const float d = amax / ((1 << 3) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        pd[i] = d;
+        *(float *)pd = d;
+        pd += bs;
 
         for (int l = 0; l < QK; l += 2) {
             const float v0 = x[i*QK + l + 0]*id;
@@ -478,7 +484,8 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
             pp[l/2] = vi0 | (vi1 << 4);
         }
 
-        memcpy(pb + i*QK/2, pp, sizeof(pp));
+        memcpy(pb, pp, sizeof(pp));
+        pb += bs;
     }
 #endif
 }
@@ -535,15 +542,16 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
     assert(k % QK == 0);
 
     const int nb = k / QK;
+    const size_t bs = sizeof(float) + QK/2;
 
-    const float   * restrict pd = (const float *)   (x);
-    const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
+    const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs);
+    const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float));
 
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d = pd[i];
+        const float d = *(const float *) (pd + i*bs);
 
-        const uint8_t * restrict pp = pb + i*QK/2;
+        const uint8_t * restrict pp = pb + i*bs;
 
         for (int l = 0; l < QK; l += 2) {
             const uint8_t vi = pp[l/2];
@@ -554,6 +562,8 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
             const float v0 = (vi0 - 8)*d;
             const float v1 = (vi1 - 8)*d;
 
+            //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
+
             y[i*QK + l + 0] = v0;
             y[i*QK + l + 1] = v1;
 
@@ -1179,11 +1189,13 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     assert(n % QK == 0);
     assert(nb % 2 == 0);
 
-    const float * restrict pd0 = (const float *) x;
-    const float * restrict pd1 = (const float *) y;
+    const size_t bs = sizeof(float) + QK/2;
 
-    const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
-    const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
+    const uint8_t * restrict pd0 = (const uint8_t *) (x + 0*bs);
+    const uint8_t * restrict pd1 = (const uint8_t *) (y + 0*bs);
+
+    const uint8_t * restrict pb0 = (const uint8_t *) (x + 0*bs + sizeof(float));
+    const uint8_t * restrict pb1 = (const uint8_t *) (y + 0*bs + sizeof(float));
 
     float sumf = 0.0;
 
@@ -1193,23 +1205,23 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     float sum1 = 0.0f;
 
     for (int i = 0; i < nb; i += 2) {
-        const float d0_0 = pd0[i + 0];
-        const float d1_0 = pd1[i + 0];
-        const float d0_1 = pd0[i + 1];
-        const float d1_1 = pd1[i + 1];
+        const float d0_0 = *(const float *) (pd0 + i*bs);
+        const float d1_0 = *(const float *) (pd1 + i*bs);
+        const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
+        const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
 
         //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
 
-        const uint8_t * restrict p0 = pb0 + i*16;
-        const uint8_t * restrict p1 = pb1 + i*16;
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
         const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(p0);
         const uint8x16_t v1_0 = vld1q_u8(p1);
-        const uint8x16_t v0_1 = vld1q_u8(p0 + 16);
-        const uint8x16_t v1_1 = vld1q_u8(p1 + 16);
+        const uint8x16_t v0_1 = vld1q_u8(p0 + bs);
+        const uint8x16_t v1_1 = vld1q_u8(p1 + bs);
 
         // 4-bit -> 8-bit
         const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
@@ -1280,21 +1292,21 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     float sum1 = 0.0f;
 
     for (int i = 0; i < nb; i += 2) {
-        const float d0_0 = pd0[i + 0];
-        const float d0_1 = pd0[i + 1];
-        const float d1_0 = pd1[i + 0];
-        const float d1_1 = pd1[i + 1];
+        const float d0_0 = *(const float *) (pd0 + i*bs);
+        const float d1_0 = *(const float *) (pd1 + i*bs);
+        const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
+        const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
 
-        const uint8_t * restrict p0 = pb0 + i*16;
-        const uint8_t * restrict p1 = pb1 + i*16;
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
 
         const v128_t m4b = wasm_u8x16_splat(0xf);
         const v128_t s8b = wasm_i8x16_splat(0x8);
 
         const v128_t v0_0 = wasm_v128_load(p0);
-        const v128_t v0_1 = wasm_v128_load(p0 + 16);
+        const v128_t v0_1 = wasm_v128_load(p0 + bs);
         const v128_t v1_0 = wasm_v128_load(p1);
-        const v128_t v1_1 = wasm_v128_load(p1 + 16);
+        const v128_t v1_1 = wasm_v128_load(p1 + bs);
 
         // 4-bit -> 8-bit
         const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -1363,11 +1375,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d0 = pd0[i];
-        const float d1 = pd1[i];
+        const float d0 = *(const float *) (pd0 + i*bs);
+        const float d1 = *(const float *) (pd1 + i*bs);
 
-        const uint8_t * restrict p0 = pb0 + i*QK/2;
-        const uint8_t * restrict p1 = pb1 + i*QK/2;
+        const uint8_t * restrict p0 = pb0 + i*bs;
+        const uint8_t * restrict p1 = pb1 + i*bs;
 
         for (int j = 0; j < QK/2; j++) {
             const uint8_t v0 = p0[j];
@@ -1552,16 +1564,17 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
     assert(n % QK == 0);
 
     const int nb = n / QK;
+    const size_t bs = sizeof(float) + QK/2;
 
-    const float   * restrict pd = (const float *)   (x);
-    const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
+    const uint8_t * restrict pd = (const uint8_t *) (x + 0*bs);
+    const uint8_t * restrict pb = (const uint8_t *) (x + 0*bs + sizeof(float));
 
 #if __ARM_NEON
 #if QK == 32
     for (int i = 0; i < nb; ++i) {
-        const float d0 = pd[i]*v;
+        const float d0 = v*(*(const float *) (pd + i*bs));
 
-        const uint8_t * restrict pp = pb + i*16;
+        const uint8_t * restrict pp = pb + i*bs;
 
         const uint8x8_t m4b = vdup_n_u8(0xf);
         const int8x8_t  s8b = vdup_n_s8(0x8);
@@ -1615,9 +1628,9 @@ inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * res
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d = pd[i];
+        const float d = *(const float *) (pd + i*bs);
 
-        const uint8_t * restrict pp = pb + i*QK/2;
+        const uint8_t * restrict pp = pb + i*bs;
 
         for (int l = 0; l < QK; l += 2) {
             const uint8_t vi = pp[l/2];
index eca71408305edf02b5bc5900079567a607a9f444..d28fc916bac156743e66b1f5dd77e280722e6028 100644 (file)
--- a/main.cpp
+++ b/main.cpp
 #include <string>
 #include <vector>
 
+// determine number of model parts based on the dimension
+static const std::map<int, int> LLAMA_N_PARTS = {
+    { 4096, 1 },
+    { 5120, 2 },
+    { 6656, 4 },
+    { 8192, 8 },
+};
+
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     int32_t n_vocab = 32000;
@@ -82,6 +90,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     }
 
     int n_ff = 0;
+    int n_parts = 0;
 
     // load hparams
     {
@@ -99,6 +108,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         hparams.n_ctx = n_ctx;
 
         n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+        n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
 
         printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
         printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
@@ -109,6 +119,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         printf("%s: n_rot   = %d\n", __func__, hparams.n_rot);
         printf("%s: f16     = %d\n", __func__, hparams.f16);
         printf("%s: n_ff    = %d\n", __func__, n_ff);
+        printf("%s: n_parts = %d\n", __func__, n_parts);
     }
 
     // load vocab
@@ -220,7 +231,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
 
         model.layers.resize(n_layer);
 
-        model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
+        model.tok_embeddings = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
 
         model.norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
         model.output = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
@@ -234,14 +245,14 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         for (int i = 0; i < n_layer; ++i) {
             auto & layer = model.layers[i];
 
-            layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
 
-            layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_embd);
-            layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_embd);
-            layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_embd);
-            layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_embd);
+            layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
+            layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
 
-            layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,   n_embd);
+            layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
 
             layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd,   n_ff);
             layer.w2 = ggml_new_tensor_2d(ctx, wtype,   n_ff, n_embd);
@@ -282,94 +293,208 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
     }
 
-    // load weights
-    {
-        int n_tensors = 0;
-        size_t total_size = 0;
+    const size_t file_offset = fin.tellg();
 
-        printf("%s: ", __func__);
+    fin.close();
 
-        while (true) {
-            int32_t n_dims;
-            int32_t length;
-            int32_t ftype;
+    std::vector<uint8_t> tmp;
 
-            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
-            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
-            fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+    for (int i = 0; i < n_parts; ++i) {
+        const int part_id = i;
+        //const int part_id = n_parts - i - 1;
 
-            if (fin.eof()) {
-                break;
-            }
+        std::string fname_part = fname;
+        if (i > 0) {
+            fname_part += "." + std::to_string(i);
+        }
 
-            int32_t nelements = 1;
-            int32_t ne[2] = { 1, 1 };
-            for (int i = 0; i < n_dims; ++i) {
-                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
-                nelements *= ne[i];
-            }
+        printf("%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
 
-            std::string name(length, 0);
-            fin.read(&name[0], length);
+        fin = std::ifstream(fname_part, std::ios::binary);
+        fin.seekg(file_offset);
 
-            if (model.tensors.find(name.data()) == model.tensors.end()) {
-                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
-                return false;
-            }
+        // load weights
+        {
+            int n_tensors = 0;
+            size_t total_size = 0;
 
-            auto tensor = model.tensors[name.data()];
-            if (ggml_nelements(tensor) != nelements) {
-                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                return false;
-            }
+            printf("%s: ", __func__);
 
-            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
-                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                        __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
-                return false;
-            }
+            while (true) {
+                int32_t n_dims;
+                int32_t length;
+                int32_t ftype;
 
-            if (0) {
-                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
-                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ftype_str[ftype], ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
-            }
+                fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
+                fin.read(reinterpret_cast<char *>(&length), sizeof(length));
+                fin.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
+
+                if (fin.eof()) {
+                    break;
+                }
+
+                int32_t nelements = 1;
+                int32_t ne[2] = { 1, 1 };
+                for (int i = 0; i < n_dims; ++i) {
+                    fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
+                    nelements *= ne[i];
+                }
 
-            size_t bpe = 0;
+                std::string name(length, 0);
+                fin.read(&name[0], length);
 
-            switch (ftype) {
-                case 0: bpe = ggml_type_size(GGML_TYPE_F32);  break;
-                case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;
-                case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
-                case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
-                default:
-                        {
-                            fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
+                if (model.tensors.find(name.data()) == model.tensors.end()) {
+                    fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                    return false;
+                }
+
+                // split_type = 0: split by columns
+                // split_type = 1: split by rows
+                int split_type = 0;
+
+                // split_type = 0:
+                // regex:
+                //   - tok_embeddings.*
+                //   - layers.*.attention.wo.weight
+                //   - layers.*.feed_forward.w2.weight
+
+                // split_type = 1:
+                // regex:
+                //   - output.*
+                //   - layers.*.attention.wq.weight
+                //   - layers.*.attention.wk.weight
+                //   - layers.*.attention.wv.weight
+                //   - layers.*.feed_forward.w1.weight
+                //   - layers.*.feed_forward.w3.weight
+                if (name.find("tok_embeddings") != std::string::npos) {
+                    split_type = 0;
+                } else if (name.find("layers") != std::string::npos) {
+                    if (name.find("attention.wo.weight") != std::string::npos) {
+                        split_type = 0;
+                    } else if (name.find("feed_forward.w2.weight") != std::string::npos) {
+                        split_type = 0;
+                    } else {
+                        split_type = 1;
+                    }
+                } else if (name.find("output") != std::string::npos) {
+                    split_type = 1;
+                }
+
+                auto tensor = model.tensors[name.data()];
+
+                if (n_dims == 1) {
+                    if (ggml_nelements(tensor) != nelements) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                        return false;
+                    }
+                } else {
+                    if (ggml_nelements(tensor)/n_parts != nelements) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                        return false;
+                    }
+                }
+
+                if (n_dims == 1) {
+                    if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                                __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
+                        return false;
+                    }
+                } else {
+                    if (split_type == 0) {
+                        if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
+                            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                                    __func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
+                            return false;
+                        }
+                    } else {
+                        if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
+                            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                                    __func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
                             return false;
                         }
-            };
+                    }
+                }
 
-            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                return false;
-            }
+                if (0) {
+                    static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
+                    printf("%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
+                }
+
+                size_t bpe = 0;
+
+                switch (ftype) {
+                    case 0: bpe = ggml_type_size(GGML_TYPE_F32);  break;
+                    case 1: bpe = ggml_type_size(GGML_TYPE_F16);  break;
+                    case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
+                    case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
+                    default:
+                            {
+                                fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
+                                return false;
+                            }
+                };
+
+                if (n_dims == 1 || n_parts == 1) {
+                    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                                __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                        return false;
+                    }
+
+                    if (part_id == 0) {
+                        fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+                    } else {
+                        fin.seekg(ggml_nbytes(tensor), std::ios::cur);
+                    }
+
+                    total_size += ggml_nbytes(tensor);
+                } else {
+                    if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
+                        fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                                __func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
+                        return false;
+                    }
+
+                    if (split_type == 0) {
+                        const int np0 = ne[0];
+
+                        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
+                        assert(row_size == tensor->nb[1]);
+
+                        for (int i1 = 0; i1 < ne[1]; ++i1) {
+                            const size_t offset_row = i1*row_size;
+                            const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
+                            fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
+                        }
+                    } else {
+                        const int np1 = ne[1];
 
-            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+                        const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
 
-            //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
-            total_size += ggml_nbytes(tensor);
-            if (++n_tensors % 8 == 0) {
-                printf(".");
-                fflush(stdout);
+                        for (int i1 = 0; i1 < ne[1]; ++i1) {
+                            const size_t offset_row = (i1 + part_id*np1)*row_size;
+                            fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
+                        }
+                    }
+
+                    total_size += ggml_nbytes(tensor)/n_parts;
+                }
+
+                //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
+                if (++n_tensors % 8 == 0) {
+                    printf(".");
+                    fflush(stdout);
+                }
             }
-        }
 
-        printf(" done\n");
+            printf(" done\n");
 
-        printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
-    }
+            printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
+        }
 
-    fin.close();
+        fin.close();
+    }
 
     return true;
 }
index 6bd1fc02dac7cf7935f8594afa0dbb6fc50c32ba..abb34756ac026a60890bc2fe4299cd6df9ef3f83 100644 (file)
--- a/utils.cpp
+++ b/utils.cpp
@@ -448,7 +448,8 @@ gpt_vocab::id llama_sample_top_p(
 
 size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
     const int nb = k / qk;
-    const size_t row_size = nb*(sizeof(float) + sizeof(uint8_t)*qk/2);
+    const size_t bs = (sizeof(float) + sizeof(uint8_t)*qk/2);
+    const size_t row_size = nb*bs;
 
     assert(k % qk == 0);
 
@@ -457,8 +458,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
     char * pdst = (char *) dst;
 
     for (int j = 0; j < n; j += k) {
-        float   * pd = (float *)   (pdst + (j/k)*row_size);
-        uint8_t * pb = (uint8_t *) (pd + nb);
+        uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
+        uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
 
         for (int i = 0; i < nb; i++) {
             float amax = 0.0f; // absolute max
@@ -472,7 +473,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
                 const float d = amax / ((1 << 3) - 1);
                 const float id = d ? 1.0f/d : 0.0f;
 
-                pd[i] = d;
+                *(float *) pd = d;
+                pd += bs;
 
                 for (int l = 0; l < qk; l += 2) {
                     const float v0 = (src[j + i*qk + l + 0])*id;
@@ -490,7 +492,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
                     pp[l/2] = vi0 | (vi1 << 4);
                 }
 
-                memcpy(pb + i*qk/2, pp, sizeof(pp));
+                memcpy(pb, pp, sizeof(pp));
+                pb += bs;
             }
         }
     }