]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
split: allow --split-max-size option (#6343)
authorXuan Son Nguyen <redacted>
Fri, 29 Mar 2024 21:34:44 +0000 (22:34 +0100)
committerGitHub <redacted>
Fri, 29 Mar 2024 21:34:44 +0000 (22:34 +0100)
* split by max size

* clean up arg parse

* split: ok

* add dry run option

* error on 0 tensors

* be positive

* remove next_metadata_size

examples/gguf-split/gguf-split.cpp

index b1af599923809835eced18a292457e5ab4aa6bf8..24acbf02a4eeddd361bbe03ca4aa6bfd4303889d 100644 (file)
@@ -28,9 +28,11 @@ enum split_operation : uint8_t {
 
 struct split_params {
     split_operation operation = SPLIT_OP_SPLIT;
+    size_t n_bytes_split = 0;
     int n_split_tensors = 128;
     std::string input;
     std::string output;
+    bool dry_run = false;
 };
 
 static void split_print_usage(const char * executable) {
@@ -41,15 +43,36 @@ static void split_print_usage(const char * executable) {
     printf("Apply a GGUF operation on IN to OUT.");
     printf("\n");
     printf("options:\n");
-    printf("  -h, --help            show this help message and exit\n");
-    printf("  --version             show version and build info\n");
-    printf("  --split               split GGUF to multiple GGUF (default)\n");
-    printf("  --split-max-tensors   max tensors in each split: default(%d)\n", default_params.n_split_tensors);
-    printf("  --merge               merge multiple GGUF to a single GGUF\n");
+    printf("  -h, --help              show this help message and exit\n");
+    printf("  --version               show version and build info\n");
+    printf("  --split                 split GGUF to multiple GGUF (enabled by default)\n");
+    printf("  --merge                 merge multiple GGUF to a single GGUF\n");
+    printf("  --split-max-tensors     max tensors in each split (default: %d)\n", default_params.n_split_tensors);
+    printf("  --split-max-size N(M|G) max size per split\n");
+    printf("  --dry-run               only print out a split plan and exit, without writing any new files\n");
     printf("\n");
 }
 
-static bool split_params_parse_ex(int argc, const char ** argv, split_params & params) {
+// return convert string, for example "128M" or "4G" to number of bytes
+static size_t split_str_to_n_bytes(std::string str) {
+    size_t n_bytes = 0;
+    int n;
+    if (str.back() == 'M') {
+        sscanf(str.c_str(), "%d", &n);
+        n_bytes = n * 1024 * 1024; // megabytes
+    } else if (str.back() == 'G') {
+        sscanf(str.c_str(), "%d", &n);
+        n_bytes = n * 1024 * 1024 * 1024; // gigabytes
+    } else {
+        throw std::invalid_argument("error: supported units are M (megabytes) or G (gigabytes), but got: " + std::string(1, str.back()));
+    }
+    if (n <= 0) {
+        throw std::invalid_argument("error: size must be a positive value");
+    }
+    return n_bytes;
+}
+
+static void split_params_parse_ex(int argc, const char ** argv, split_params & params) {
     std::string arg;
     const std::string arg_prefix = "--";
     bool invalid_param = false;
@@ -62,6 +85,8 @@ static bool split_params_parse_ex(int argc, const char ** argv, split_params & p
         }
 
         bool arg_found = false;
+        bool is_op_set = false;
+        bool is_mode_set = false;
         if (arg == "-h" || arg == "--help") {
             split_print_usage(argv[0]);
             exit(0);
@@ -71,23 +96,46 @@ static bool split_params_parse_ex(int argc, const char ** argv, split_params & p
             fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
             exit(0);
         }
+        if (arg == "--dry-run") {
+            arg_found = true;
+            params.dry_run = true;
+        }
 
+        if (is_op_set) {
+            throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
+        }
         if (arg == "--merge") {
             arg_found = true;
+            is_op_set = true;
             params.operation = SPLIT_OP_MERGE;
         }
         if (arg == "--split") {
             arg_found = true;
+            is_op_set = true;
             params.operation = SPLIT_OP_SPLIT;
         }
+
+        if (is_mode_set) {
+            throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
+        }
         if (arg == "--split-max-tensors") {
             if (++arg_idx >= argc) {
                 invalid_param = true;
                 break;
             }
             arg_found = true;
+            is_mode_set = true;
             params.n_split_tensors = atoi(argv[arg_idx]);
         }
+        if (arg == "--split-max-size") {
+            if (++arg_idx >= argc) {
+                invalid_param = true;
+                break;
+            }
+            arg_found = true;
+            is_mode_set = true;
+            params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]);
+        }
 
         if (!arg_found) {
             throw std::invalid_argument("error: unknown argument: " + arg);
@@ -99,24 +147,17 @@ static bool split_params_parse_ex(int argc, const char ** argv, split_params & p
     }
 
     if (argc - arg_idx < 2) {
-        printf("%s: bad arguments\n", argv[0]);
-        split_print_usage(argv[0]);
-        return false;
+        throw std::invalid_argument("error: bad arguments");
     }
 
     params.input = argv[arg_idx++];
     params.output = argv[arg_idx++];
-
-    return true;
 }
 
 static bool split_params_parse(int argc, const char ** argv, split_params & params) {
     bool result = true;
     try {
-        if (!split_params_parse_ex(argc, argv, params)) {
-            split_print_usage(argv[0]);
-            exit(EXIT_FAILURE);
-        }
+        split_params_parse_ex(argc, argv, params);
     }
     catch (const std::invalid_argument & ex) {
         fprintf(stderr, "%s\n", ex.what());
@@ -140,15 +181,11 @@ struct split_strategy {
     struct ggml_context * ctx_meta = NULL;
     const int n_tensors;
 
-    const int n_split;
-    int i_split = 0;
-
-    int i_tensor = 0;
-
-    std::vector<uint8_t> read_data;
+    // one ctx_out per one output file
+    std::vector<struct gguf_context *> ctx_outs;
 
-    struct gguf_context * ctx_out;
-    std::ofstream fout;
+    // temporary buffer for reading in tensor data
+    std::vector<uint8_t> read_buf;
 
     split_strategy(const split_params & params,
             std::ifstream & f_input,
@@ -158,79 +195,141 @@ struct split_strategy {
         f_input(f_input),
         ctx_gguf(ctx_gguf),
         ctx_meta(ctx_meta),
-        n_tensors(gguf_get_n_tensors(ctx_gguf)),
-        n_split(std::ceil(1. * n_tensors / params.n_split_tensors)) {
-        }
-
-    bool should_split() const {
-        return i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
-    }
-
-    void split_start() {
-        ctx_out = gguf_init_empty();
+        n_tensors(gguf_get_n_tensors(ctx_gguf)) {
+
+        // because we need to know list of tensors for each file in advance, we will build all the ctx_out for all output splits
+        int i_split = -1;
+        struct gguf_context * ctx_out = NULL;
+        auto new_ctx_out = [&]() {
+            i_split++;
+            if (ctx_out != NULL) {
+                if (gguf_get_n_tensors(ctx_out) == 0) {
+                    fprintf(stderr, "error: one of splits have 0 tensors. Maybe size or tensors limit is too small\n");
+                    exit(EXIT_FAILURE);
+                }
+                ctx_outs.push_back(ctx_out);
+            }
+            ctx_out = gguf_init_empty();
+            // Save all metadata in first split only
+            if (i_split == 0) {
+                gguf_set_kv(ctx_out, ctx_gguf);
+            }
+            gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_NO, i_split);
+            gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_COUNT, 0); // placeholder
+            gguf_set_val_i32(ctx_out, LLM_KV_SPLIT_TENSORS_COUNT, n_tensors);
+        };
 
-        // Save all metadata in first split only
-        if (i_split == 0) {
-            gguf_set_kv(ctx_out, ctx_gguf);
-        }
-        gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_NO, i_split);
-        gguf_set_val_u16(ctx_out, LLM_KV_SPLIT_COUNT, n_split);
-        gguf_set_val_i32(ctx_out, LLM_KV_SPLIT_TENSORS_COUNT, n_tensors);
-
-        // populate the original tensors, so we get an initial metadata
-        for (int i = i_split * params.n_split_tensors; i < n_tensors && i < (i_split + 1) * params.n_split_tensors; ++i) {
-            struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_gguf, i));
-            gguf_add_tensor(ctx_out, meta);
+        // initialize ctx_out for the first split
+        new_ctx_out();
+
+        // process tensors one by one
+        size_t curr_tensors_size = 0; // current size by counting only tensors size (without metadata)
+        for (int i = 0; i < n_tensors; ++i) {
+            struct ggml_tensor * t = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_gguf, i));
+            // calculate the "imaginary" size = the current size + next tensor size
+            size_t n_bytes = GGML_PAD(ggml_nbytes(t), GGUF_DEFAULT_ALIGNMENT);
+            size_t next_tensors_size = curr_tensors_size + n_bytes;
+            if (should_split(i, next_tensors_size)) {
+                new_ctx_out();
+                curr_tensors_size = n_bytes;
+            } else {
+                curr_tensors_size = next_tensors_size;
+            }
+            gguf_add_tensor(ctx_out, t);
         }
 
-        char split_path[PATH_MAX] = {0};
-        llama_split_path(split_path, sizeof(split_path), params.output.c_str(), i_split, n_split);
+        // push the last ctx_out
+        ctx_outs.push_back(ctx_out);
 
-        fprintf(stderr, "%s: %s ...", __func__, split_path);
-        fout = std::ofstream(split_path, std::ios::binary);
-        fout.exceptions(std::ofstream::failbit); // fail fast on write errors
-
-        auto meta_size = gguf_get_meta_size(ctx_out);
-
-        // placeholder for the meta data
-        ::zeros(fout, meta_size);
-
-        i_split++;
+        // set the correct n_split for all ctx_out
+        for (auto & ctx : ctx_outs) {
+            gguf_set_val_u16(ctx, LLM_KV_SPLIT_COUNT, ctx_outs.size());
+        }
     }
 
-    void next_tensor() {
-        const char * t_name = gguf_get_tensor_name(ctx_gguf, i_tensor);
-        struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
-        auto n_bytes = ggml_nbytes(t);
-
-        if (read_data.size() < n_bytes) {
-            read_data.resize(n_bytes);
+    ~split_strategy() {
+        for (auto & ctx_out : ctx_outs) {
+            gguf_free(ctx_out);
         }
+    }
 
-        auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor);
-        f_input.seekg(offset);
-        f_input.read((char *)read_data.data(), n_bytes);
-
-        t->data = read_data.data();
-
-        // write tensor data + padding
-        fout.write((const char *)t->data, n_bytes);
-        zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);
+    bool should_split(int i_tensor, size_t next_size) {
+        if (params.n_bytes_split > 0) {
+            // split by max size per file
+            return next_size > params.n_bytes_split;
+        } else {
+            // split by number of tensors per file
+            return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
+        }
+    }
 
-        i_tensor++;
+    void print_info() {
+        printf("n_split: %ld\n", ctx_outs.size());
+        int i_split = 0;
+        for (auto & ctx_out : ctx_outs) {
+            // re-calculate the real gguf size for each split (= metadata size + total size of all tensors)
+            size_t total_size = gguf_get_meta_size(ctx_out);
+            for (int i = 0; i < gguf_get_n_tensors(ctx_out); ++i) {
+                struct ggml_tensor * t = ggml_get_tensor(ctx_meta, gguf_get_tensor_name(ctx_out, i));
+                total_size += ggml_nbytes(t);
+            }
+            total_size = total_size / 1024 / 1024; // convert to megabytes
+            printf("split %05d: n_tensors = %d, total_size = %ldM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size);
+            i_split++;
+        }
     }
 
-    void split_end() {
-        // go back to beginning of file and write the updated metadata
-        fout.seekp(0);
-        std::vector<uint8_t> data(gguf_get_meta_size(ctx_out));
-        gguf_get_meta_data(ctx_out, data.data());
-        fout.write((const char *)data.data(), data.size());
+    void write() {
+        int i_split = 0;
+        int n_split = ctx_outs.size();
+        for (auto & ctx_out : ctx_outs) {
+            // construct file path
+            char split_path[PATH_MAX] = {0};
+            llama_split_path(split_path, sizeof(split_path), params.output.c_str(), i_split, n_split);
+
+            // open the output file
+            printf("Writing file %s ... ", split_path);
+            fflush(stdout);
+            std::ofstream fout = std::ofstream(split_path, std::ios::binary);
+            fout.exceptions(std::ofstream::failbit); // fail fast on write errors
+
+            // write metadata
+            std::vector<uint8_t> data(gguf_get_meta_size(ctx_out));
+            gguf_get_meta_data(ctx_out, data.data());
+            fout.write((const char *)data.data(), data.size());
+
+            // write tensors
+            for (int i = 0; i < gguf_get_n_tensors(ctx_out); ++i) {
+                // read tensor meta and prepare buffer
+                const char * t_name = gguf_get_tensor_name(ctx_out, i);
+                struct ggml_tensor * t = ggml_get_tensor(ctx_meta, t_name);
+                auto n_bytes = ggml_nbytes(t);
+                read_buf.resize(n_bytes);
+
+                // calculate offset
+                auto i_tensor_in = gguf_find_tensor(ctx_gguf, t_name); // idx of tensor in the input file
+                auto offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i_tensor_in);
+
+                // copy tensor from input to output file
+                copy_file_to_file(f_input, fout, offset, n_bytes);
+                zeros(fout, GGML_PAD(n_bytes, GGUF_DEFAULT_ALIGNMENT) - n_bytes);
+            }
 
-        fout.close();
-        gguf_free(ctx_out);
+            printf("done\n");
+            // close the file
+            fout.close();
+            i_split++;
+        }
+    }
 
-        fprintf(stderr, "\033[3Ddone\n");
+    void copy_file_to_file(std::ifstream & f_in, std::ofstream & f_out, const size_t in_offset, const size_t len) {
+        // TODO: detect OS and use copy_file_range() here for better performance
+        if (read_buf.size() < len) {
+            read_buf.resize(len);
+        }
+        f_in.seekg(in_offset);
+        f_in.read((char *)read_buf.data(), len);
+        f_out.write((const char *)read_buf.data(), len);
     }
 };
 
@@ -254,32 +353,22 @@ static void gguf_split(const split_params & split_params) {
         exit(EXIT_FAILURE);
     }
 
+    // prepare the strategy
     split_strategy strategy(split_params, f_input, ctx_gguf, ctx_meta);
+    int n_split = strategy.ctx_outs.size();
+    strategy.print_info();
 
-    char first_split_path[PATH_MAX] = {0};
-    llama_split_path(first_split_path, sizeof(first_split_path),
-                     split_params.output.c_str(), strategy.i_split, strategy.n_split);
-    fprintf(stderr, "%s: %s -> %s (%d tensors per file)\n",
-            __func__, split_params.input.c_str(),
-            first_split_path,
-            split_params.n_split_tensors);
-
-    strategy.split_start();
-
-    while (strategy.i_tensor < strategy.n_tensors) {
-        strategy.next_tensor();
-        if (strategy.should_split()) {
-            strategy.split_end();
-            strategy.split_start();
-        }
+    if (!split_params.dry_run) {
+        // write all output splits
+        strategy.write();
     }
-    strategy.split_end();
 
+    // done, clean up
     gguf_free(ctx_gguf);
     f_input.close();
 
     fprintf(stderr, "%s: %d gguf split written with a total of %d tensors.\n",
-            __func__, strategy.n_split, strategy.n_tensors);
+            __func__, n_split, strategy.n_tensors);
 }
 
 static void gguf_merge(const split_params & split_params) {
@@ -448,10 +537,6 @@ static void gguf_merge(const split_params & split_params) {
 }
 
 int main(int argc, const char ** argv) {
-    if (argc < 3) {
-        split_print_usage(argv[0]);
-    }
-
     split_params params;
     split_params_parse(argc, argv, params);