]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add enum llama_ftype, sync ggml_type to model files (#709)
authorStephan Walter <redacted>
Tue, 11 Apr 2023 15:03:51 +0000 (15:03 +0000)
committerGitHub <redacted>
Tue, 11 Apr 2023 15:03:51 +0000 (15:03 +0000)
examples/quantize/quantize.cpp
ggml.c
ggml.h
llama.cpp
llama.h

index 680757c6bf35616429ebeedc657f52b36a2a33b5..5c9e2ad9420b3a776c2008842415f064391d035c 100644 (file)
@@ -5,15 +5,15 @@
 #include <string>
 
 // usage:
-//  ./llama-quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
+//  ./quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type
 //
 int main(int argc, char ** argv) {
     ggml_time_init();
 
     if (argc != 4) {
         fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
-        fprintf(stderr, "  type = 2 - q4_0\n");
-        fprintf(stderr, "  type = 3 - q4_1\n");
+        fprintf(stderr, "  type = %d - q4_0\n", LLAMA_FTYPE_MOSTLY_Q4_0);
+        fprintf(stderr, "  type = %d - q4_1\n", LLAMA_FTYPE_MOSTLY_Q4_1);
         return 1;
     }
 
@@ -27,7 +27,7 @@ int main(int argc, char ** argv) {
     const std::string fname_inp = argv[1];
     const std::string fname_out = argv[2];
 
-    const int itype = atoi(argv[3]);
+    const enum llama_ftype ftype = (enum llama_ftype)atoi(argv[3]);
 
     const int64_t t_main_start_us = ggml_time_us();
 
@@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype)) {
+        if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype)) {
             fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
             return 1;
         }
diff --git a/ggml.c b/ggml.c
index 897b67d930614986e7593e7c73cb505ae9d82816..31947c4c10a91f46611a569fa4b443bca4cdc1a1 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2560,29 +2560,26 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 //
 
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
-    QK,
-    QK,
-    1,
-    1,
-    1,
-    1,
-    1,
+    [GGML_TYPE_F32]  = 1,
+    [GGML_TYPE_F16]  = 1,
+    [GGML_TYPE_Q4_0] = QK,
+    [GGML_TYPE_Q4_1] = QK,
+    [GGML_TYPE_I8]   = 1,
+    [GGML_TYPE_I16]  = 1,
+    [GGML_TYPE_I32]  = 1,
 };
-
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
-    sizeof(block_q4_0),
-    sizeof(block_q4_1),
-    sizeof(int8_t ),
-    sizeof(int16_t),
-    sizeof(int32_t),
-    sizeof(ggml_fp16_t),
-    sizeof(float  ),
+    [GGML_TYPE_F32]  = sizeof(float),
+    [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
+    [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
+    [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_I8]   = sizeof(int8_t),
+    [GGML_TYPE_I16]  = sizeof(int16_t),
+    [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-
-// don't forget to update the array above when adding new types
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
+static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
diff --git a/ggml.h b/ggml.h
index a5245a8ae6256c0bee449cff0d9112e24df152dc..7d8b7a1829dd0110aa9182866dbe19830658d6f5 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -198,13 +198,14 @@ struct ggml_object;
 struct ggml_context;
 
 enum ggml_type {
-    GGML_TYPE_Q4_0,
-    GGML_TYPE_Q4_1,
+    // explicitly numbered values are used in llama.cpp files
+    GGML_TYPE_F32  = 0,
+    GGML_TYPE_F16  = 1,
+    GGML_TYPE_Q4_0 = 2,
+    GGML_TYPE_Q4_1 = 3,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
-    GGML_TYPE_F16,
-    GGML_TYPE_F32,
     GGML_TYPE_COUNT,
 };
 
index 54ba01eefbade0633c8e3be493e9837737cd3a4d..653558be94885eb1d58c81aee7b84845afdddbd6 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -82,7 +82,7 @@ struct llama_hparams {
     uint32_t n_head  = 32;
     uint32_t n_layer = 32;
     uint32_t n_rot   = 64;
-    uint32_t f16     = 1;
+    enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
 
     bool operator!=(const llama_hparams & other) const {
         return memcmp(this, &other, sizeof(llama_hparams));
@@ -432,7 +432,7 @@ struct llama_file_loader {
         hparams.n_head = file.read_u32();
         hparams.n_layer = file.read_u32();
         hparams.n_rot = file.read_u32();
-        hparams.f16 = file.read_u32();
+        hparams.ftype = (enum llama_ftype) file.read_u32();
     }
     void read_vocab() {
         vocab.id_to_token.resize(hparams.n_vocab);
@@ -458,20 +458,21 @@ struct llama_file_loader {
             llama_load_tensor_shard shard;
             uint32_t n_dims = file.read_u32();
             uint32_t name_len = file.read_u32();
-            uint32_t ftype = file.read_u32();
+            shard.type = (enum ggml_type) file.read_u32();
             shard.ne.resize(n_dims);
             file.read_raw(shard.ne.data(), sizeof(shard.ne[0]) * n_dims);
             std::string name = file.read_string(name_len);
             if (n_dims < 1 || n_dims > 2) {
                 throw format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims);
             }
-            switch (ftype) {
-                case 0: shard.type = GGML_TYPE_F32; break;
-                case 1: shard.type = GGML_TYPE_F16; break;
-                case 2: shard.type = GGML_TYPE_Q4_0; break;
-                case 3: shard.type = GGML_TYPE_Q4_1; break;
+            switch (shard.type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                    break;
                 default: {
-                    throw format("unrecognized ftype %u\n", ftype);
+                    throw format("unrecognized tensor type %u\n", shard.type);
                 }
             }
 
@@ -502,18 +503,18 @@ struct llama_file_loader {
 struct llama_file_saver {
     llama_file file;
     llama_file_loader * any_file_loader;
-    llama_file_saver(const char * fname, llama_file_loader * any_file_loader, uint32_t new_f16)
+    llama_file_saver(const char * fname, llama_file_loader * any_file_loader, enum llama_ftype new_ftype)
         : file(fname, "wb"), any_file_loader(any_file_loader) {
         fprintf(stderr, "llama.cpp: saving model to %s\n", fname);
         write_magic();
-        write_hparams(new_f16);
+        write_hparams(new_ftype);
         write_vocab();
     }
     void write_magic() {
         file.write_u32('ggjt'); // magic
         file.write_u32(1); // version
     }
-    void write_hparams(uint32_t new_f16) {
+    void write_hparams(enum llama_ftype new_ftype) {
         const llama_hparams & hparams = any_file_loader->hparams;
         file.write_u32(hparams.n_vocab);
         file.write_u32(hparams.n_embd);
@@ -521,7 +522,7 @@ struct llama_file_saver {
         file.write_u32(hparams.n_head);
         file.write_u32(hparams.n_layer);
         file.write_u32(hparams.n_rot);
-        file.write_u32(new_f16);
+        file.write_u32(new_ftype);
     }
     void write_vocab() {
         if (any_file_loader->file_version == LLAMA_FILE_VERSION_GGML) {
@@ -536,17 +537,17 @@ struct llama_file_saver {
         }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
-        uint32_t ftype;
         switch (new_type) {
-            case GGML_TYPE_F32:  ftype = 0; break;
-            case GGML_TYPE_F16:  ftype = 1; break;
-            case GGML_TYPE_Q4_0: ftype = 2; break;
-            case GGML_TYPE_Q4_1: ftype = 3; break;
+            case GGML_TYPE_F32:
+            case GGML_TYPE_F16:
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+                break;
             default: LLAMA_ASSERT(false);
         }
         file.write_u32((uint32_t) tensor.ne.size());
         file.write_u32((uint32_t) tensor.name.size());
-        file.write_u32(ftype);
+        file.write_u32(new_type);
         file.write_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * tensor.ne.size());
         file.write_raw(tensor.name.data(), tensor.name.size());
         file.seek(-file.tell() & 31, SEEK_CUR);
@@ -820,6 +821,16 @@ static const char *llama_file_version_name(llama_file_version version) {
     }
 }
 
+static const char *llama_ftype_name(enum llama_ftype ftype) {
+    switch (ftype) {
+        case LLAMA_FTYPE_ALL_F32:     return "all F32";
+        case LLAMA_FTYPE_MOSTLY_F16:  return "mostly F16";
+        case LLAMA_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0";
+        case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1";
+        default: LLAMA_ASSERT(false);
+    }
+}
+
 static const char *llama_model_type_name(e_model type) {
     switch (type) {
         case MODEL_7B: return "7B";
@@ -872,7 +883,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_head     = %u\n",  __func__, hparams.n_head);
         fprintf(stderr, "%s: n_layer    = %u\n",  __func__, hparams.n_layer);
         fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_rot);
-        fprintf(stderr, "%s: f16        = %u\n",  __func__, hparams.f16);
+        fprintf(stderr, "%s: ftype      = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
         fprintf(stderr, "%s: n_ff       = %u\n",  __func__, n_ff);
         fprintf(stderr, "%s: n_parts    = %zu\n", __func__, ml->file_loaders.size());
         fprintf(stderr, "%s: model size = %s\n",  __func__, llama_model_type_name(model.type));
@@ -1544,17 +1555,17 @@ static llama_vocab::id llama_sample_top_p_top_k(
 // quantization
 //
 
-static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
+static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype ftype) {
     ggml_type quantized_type;
-    switch (itype) {
-        case 2: quantized_type = GGML_TYPE_Q4_0; break;
-        case 3: quantized_type = GGML_TYPE_Q4_1; break;
-        default: throw format("invalid quantization type %d\n", itype);
+    switch (ftype) {
+        case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break;
+        default: throw format("invalid output file type %d\n", ftype);
     };
 
     std::unique_ptr<llama_model_loader> model_loader(new llama_model_loader(fname_inp.c_str(), /*use_mmap*/ false,
                                                                             /*vocab_only*/ false));
-    llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), (uint32_t) itype);
+    llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), ftype);
 
     size_t total_size_org = 0;
     size_t total_size_new = 0;
@@ -1745,9 +1756,9 @@ void llama_free(struct llama_context * ctx) {
 int llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
-               int   itype) {
+  enum llama_ftype   ftype) {
     try {
-        llama_model_quantize_internal(fname_inp, fname_out, itype);
+        llama_model_quantize_internal(fname_inp, fname_out, ftype);
         return 0;
     } catch (const std::string & err) {
         fprintf(stderr, "%s: failed to quantize: %s\n", __func__, err.c_str());
diff --git a/llama.h b/llama.h
index 42c364c6b342e671dcdac642b13458029b2643a7..8a0d50fb80cec630e15bd5f4296f3e87b0eca66e 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -65,6 +65,14 @@ extern "C" {
         void * progress_callback_user_data;
     };
 
+    // model file types
+    enum llama_ftype {
+        LLAMA_FTYPE_ALL_F32     = 0,
+        LLAMA_FTYPE_MOSTLY_F16  = 1,  // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0 = 2,  // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_1 = 3,  // except 1d tensors
+    };
+
     LLAMA_API struct llama_context_params llama_context_default_params();
 
     LLAMA_API bool llama_mmap_supported();
@@ -85,7 +93,7 @@ extern "C" {
     LLAMA_API int llama_model_quantize(
             const char * fname_inp,
             const char * fname_out,
-                   int   itype);
+      enum llama_ftype   ftype);
 
     // Returns the KV cache that will contain the context for the
     // ongoing prediction with the model.