]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gguf-split : improve --split and --merge logic (#9619)
authorZhenwei Jin <redacted>
Wed, 2 Oct 2024 07:21:57 +0000 (15:21 +0800)
committerGitHub <redacted>
Wed, 2 Oct 2024 07:21:57 +0000 (10:21 +0300)
* make sure params --split and --merge are not specified at same time

* update gguf-split params parse logic

* Update examples/gguf-split/gguf-split.cpp

Co-authored-by: slaren <redacted>
---------

Co-authored-by: Xuan Son Nguyen <redacted>
Co-authored-by: slaren <redacted>
examples/gguf-split/gguf-split.cpp

index 82c239b8336be8a79606416fcd8ac8f9efe36e46..7e62657e118a46a0b49c73b433243db08639f58b 100644 (file)
 #endif
 
 enum split_operation : uint8_t {
-    SPLIT_OP_SPLIT,
-    SPLIT_OP_MERGE,
+    OP_NONE,
+    OP_SPLIT,
+    OP_MERGE,
+};
+
+enum split_mode : uint8_t {
+    MODE_NONE,
+    MODE_TENSOR,
+    MODE_SIZE,
 };
 
 struct split_params {
-    split_operation operation = SPLIT_OP_SPLIT;
+    split_operation operation = OP_NONE;
+    split_mode mode = MODE_NONE;
     size_t n_bytes_split = 0;
     int n_split_tensors = 128;
     std::string input;
@@ -87,59 +95,52 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
         }
 
         bool arg_found = false;
-        bool is_op_set = false;
-        bool is_mode_set = false;
         if (arg == "-h" || arg == "--help") {
             split_print_usage(argv[0]);
             exit(0);
-        }
-        if (arg == "--version") {
+        } else if (arg == "--version") {
             fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
             fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
             exit(0);
-        }
-        if (arg == "--dry-run") {
+        } else if (arg == "--dry-run") {
             arg_found = true;
             params.dry_run = true;
-        }
-        if (arg == "--no-tensor-first-split") {
+        } else if (arg == "--no-tensor-first-split") {
             arg_found = true;
             params.no_tensor_first_split = true;
-        }
-
-        if (is_op_set) {
-            throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
-        }
-        if (arg == "--merge") {
+        } else if (arg == "--merge") {
             arg_found = true;
-            is_op_set = true;
-            params.operation = SPLIT_OP_MERGE;
-        }
-        if (arg == "--split") {
+            if (params.operation != OP_NONE && params.operation != OP_MERGE) {
+                throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
+            }
+            params.operation = OP_MERGE;
+        } else if (arg == "--split") {
             arg_found = true;
-            is_op_set = true;
-            params.operation = SPLIT_OP_SPLIT;
-        }
-
-        if (is_mode_set) {
-            throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
-        }
-        if (arg == "--split-max-tensors") {
+            if (params.operation != OP_NONE && params.operation != OP_SPLIT) {
+                throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
+            }
+            params.operation = OP_SPLIT;
+        } else if (arg == "--split-max-tensors") {
             if (++arg_idx >= argc) {
                 invalid_param = true;
                 break;
             }
             arg_found = true;
-            is_mode_set = true;
+            if (params.mode != MODE_NONE && params.mode != MODE_TENSOR) {
+                throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
+            }
+            params.mode = MODE_TENSOR;
             params.n_split_tensors = atoi(argv[arg_idx]);
-        }
-        if (arg == "--split-max-size") {
+        } else if (arg == "--split-max-size") {
             if (++arg_idx >= argc) {
                 invalid_param = true;
                 break;
             }
             arg_found = true;
-            is_mode_set = true;
+            if (params.mode != MODE_NONE && params.mode != MODE_SIZE) {
+                throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
+            }
+            params.mode = MODE_SIZE;
             params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]);
         }
 
@@ -148,6 +149,15 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
         }
     }
 
+    // the operation is split if not specified
+    if (params.operation == OP_NONE) {
+        params.operation = OP_SPLIT;
+    }
+    // the split mode is by tensor if not specified
+    if (params.mode == MODE_NONE) {
+        params.mode = MODE_TENSOR;
+    }
+
     if (invalid_param) {
         throw std::invalid_argument("error: invalid parameter for argument: " + arg);
     }
@@ -265,13 +275,15 @@ struct split_strategy {
     }
 
     bool should_split(int i_tensor, size_t next_size) {
-        if (params.n_bytes_split > 0) {
+        if (params.mode == MODE_SIZE) {
             // split by max size per file
             return next_size > params.n_bytes_split;
-        } else {
+        } else if (params.mode == MODE_TENSOR) {
             // split by number of tensors per file
             return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
         }
+        // should never happen
+        GGML_ABORT("invalid mode");
     }
 
     void print_info() {
@@ -559,9 +571,9 @@ int main(int argc, const char ** argv) {
     split_params_parse(argc, argv, params);
 
     switch (params.operation) {
-        case SPLIT_OP_SPLIT: gguf_split(params);
+        case OP_SPLIT: gguf_split(params);
             break;
-        case SPLIT_OP_MERGE: gguf_merge(params);
+        case OP_MERGE: gguf_merge(params);
             break;
         default: split_print_usage(argv[0]);
             exit(EXIT_FAILURE);