]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
quantize : improve type name parsing (#9570)
authorslaren <redacted>
Fri, 20 Sep 2024 18:55:36 +0000 (20:55 +0200)
committerGitHub <redacted>
Fri, 20 Sep 2024 18:55:36 +0000 (20:55 +0200)
quantize : do not ignore invalid types in arg parsing

quantize : ignore case of type and ftype arguments

examples/quantize/quantize.cpp

index a23bfb86b350fa18bf7f02e7eea9500065068b9c..b989932107dba56623bfcb78c3154c1a4fda41ad 100644 (file)
@@ -63,6 +63,16 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET    = "quantize.imatrix
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES  = "quantize.imatrix.entries_count";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS   = "quantize.imatrix.chunks_count";
 
+static bool striequals(const char * a, const char * b) {
+    while (*a && *b) {
+        if (std::tolower(*a) != std::tolower(*b)) {
+            return false;
+        }
+        a++; b++;
+    }
+    return *a == *b;
+}
+
 static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
     std::string ftype_str;
 
@@ -70,7 +80,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
         ftype_str.push_back(std::toupper(ch));
     }
     for (auto & it : QUANT_OPTIONS) {
-        if (it.name == ftype_str) {
+        if (striequals(it.name.c_str(), ftype_str.c_str())) {
             ftype = it.ftype;
             ftype_str_out = it.name;
             return true;
@@ -225,15 +235,15 @@ static int prepare_imatrix(const std::string & imatrix_file,
 }
 
 static ggml_type parse_ggml_type(const char * arg) {
-    ggml_type result = GGML_TYPE_COUNT;
-    for (int j = 0; j < GGML_TYPE_COUNT; ++j) {
-        auto type = ggml_type(j);
+    for (int i = 0; i < GGML_TYPE_COUNT; ++i) {
+        auto type = (ggml_type)i;
         const auto * name = ggml_type_name(type);
-        if (name && strcmp(arg, name) == 0) {
-            result = type; break;
+        if (name && striequals(name, arg)) {
+            return type;
         }
     }
-    return result;
+    fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
+    return GGML_TYPE_COUNT;
 }
 
 int main(int argc, char ** argv) {
@@ -254,12 +264,18 @@ int main(int argc, char ** argv) {
         } else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) {
             if (arg_idx < argc-1) {
                 params.output_tensor_type = parse_ggml_type(argv[++arg_idx]);
+                if (params.output_tensor_type == GGML_TYPE_COUNT) {
+                    usage(argv[0]);
+                }
             } else {
                 usage(argv[0]);
             }
         } else if (strcmp(argv[arg_idx], "--token-embedding-type") == 0) {
             if (arg_idx < argc-1) {
                 params.token_embedding_type = parse_ggml_type(argv[++arg_idx]);
+                if (params.token_embedding_type == GGML_TYPE_COUNT) {
+                    usage(argv[0]);
+                }
             } else {
                 usage(argv[0]);
             }