]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
GGUF: C++ refactor, backend support, misc fixes (#11030)
authorJohannes Gäßler <redacted>
Tue, 7 Jan 2025 17:01:58 +0000 (18:01 +0100)
committerGitHub <redacted>
Tue, 7 Jan 2025 17:01:58 +0000 (18:01 +0100)
* GGUF: C++ refactor, backend support, misc fixes

remove ggml_tensor.backend

update CODEOWNERS [no ci]

remove gguf_get_data from API

revise GGUF API data types

21 files changed:
CODEOWNERS
common/common.cpp
examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
examples/cvector-generator/cvector-generator.cpp
examples/export-lora/export-lora.cpp
examples/gguf-hash/gguf-hash.cpp
examples/gguf-split/gguf-split.cpp
examples/gguf/gguf.cpp
examples/llava/clip.cpp
ggml/CMakeLists.txt
ggml/include/ggml-cpp.h
ggml/include/ggml.h
ggml/include/gguf.h [new file with mode: 0644]
ggml/src/CMakeLists.txt
ggml/src/ggml-impl.h
ggml/src/ggml.c
ggml/src/gguf.cpp [new file with mode: 0644]
src/llama-impl.cpp
src/llama-model-loader.cpp
src/llama-quant.cpp
tests/test-gguf.cpp

index c9fa34761f5754a8d301e8241441dfe1106a1f94..72d594b46e9111f27ee64515d971bbf016a70a82 100644 (file)
@@ -3,3 +3,9 @@
 /ci/ @ggerganov
 /.devops/*.Dockerfile @ngxson
 /examples/server/ @ngxson
+/ggml/src/ggml-cuda/fattn* @JohannesGaessler
+/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
+/ggml/src/ggml-cuda/mmv.* @JohannesGaessler
+/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
+/ggml/src/ggml-opt.cpp @JohannesGaessler
+/ggml/src/gguf.cpp @JohannesGaessler
index 4fd36105e42e428a486111d7823b8da22839b47a..86e4e1e24edf969e0c39eaa242268b5a6a19324d 100644 (file)
@@ -2,6 +2,9 @@
 #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
 #endif
 
+#include "ggml.h"
+#include "gguf.h"
+
 #include "common.h"
 #include "log.h"
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
index 9c3a0c367ab8e6dbf305c8229bc1092c5bb14de6..1256abb172298510ddb1e9badec728b57b12318b 100644 (file)
@@ -1,4 +1,6 @@
 #include "ggml.h"
+#include "gguf.h"
+
 #include "llama.h"
 #include "common.h"
 #include "log.h"
index 7c9f5022850571701f9c924961649d43f668e6ee..e899c107819f311139d6fa89d3fb78e636197ad8 100644 (file)
@@ -1,7 +1,9 @@
+#include "ggml.h"
+#include "gguf.h"
+
 #include "arg.h"
 #include "common.h"
 #include "llama.h"
-#include "ggml.h"
 #include "pca.hpp"
 #include "mean.hpp"
 
index 058b5cc860213ba579ce16142c0e07e8e7d2bd9d..d5dcd20a0ae4a32699807f7c227cae6ebe7e623d 100644 (file)
@@ -1,7 +1,9 @@
-#include "arg.h"
-#include "common.h"
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "gguf.h"
+
+#include "arg.h"
+#include "common.h"
 
 #include <map>
 #include <vector>
index e96c75117f53350a0f1e545a39aebb2c5c004fa6..9523ec122f57304f7a439534c49b24c3fb64893e 100644 (file)
@@ -1,4 +1,5 @@
 #include "ggml.h"
+#include "gguf.h"
 
 #include <cstdlib>   /* abort() */
 #include <cstddef>
index 9e3d44984a06fe5bc97c6e6dbe573898baed214a..ef3ceb686f697060e4cd5337e9bc9cb869a2b9b5 100644 (file)
@@ -1,16 +1,18 @@
+#include "ggml.h"
+#include "gguf.h"
 #include "llama.h"
 #include "common.h"
 
 #include <algorithm>
+#include <cinttypes>
+#include <climits>
+#include <cstdio>
 #include <cstdlib>
+#include <stdexcept>
+#include <cstring>
 #include <fstream>
 #include <string>
 #include <vector>
-#include <climits>
-
-#include <cstdio>
-#include <cstring>
-#include <stdexcept>
 
 #if defined(_WIN32)
     #include <windows.h>
@@ -296,7 +298,7 @@ struct split_strategy {
                 total_size += ggml_nbytes(t);
             }
             total_size = total_size / 1000 / 1000; // convert to megabytes
-            printf("split %05d: n_tensors = %d, total_size = %zuM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size);
+            printf("split %05d: n_tensors = %" PRIi64 ", total_size = %zuM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size);
             i_split++;
         }
     }
index 7498f85efc4f90cc3bc819659ee555a81af117e3..f31989c8c55c6a5b6b2a7369e47bc48d4677d3ea 100644 (file)
@@ -1,10 +1,9 @@
 #include "ggml.h"
+#include "gguf.h"
 
 #include <cstdio>
-#include <cinttypes>
 #include <string>
 #include <sstream>
-#include <fstream>
 #include <vector>
 
 #undef MIN
@@ -135,9 +134,10 @@ static bool gguf_ex_read_0(const std::string & fname) {
 
         for (int i = 0; i < n_tensors; ++i) {
             const char * name   = gguf_get_tensor_name  (ctx, i);
+            const size_t size   = gguf_get_tensor_size  (ctx, i);
             const size_t offset = gguf_get_tensor_offset(ctx, i);
 
-            printf("%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset);
+            printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset);
         }
     }
 
@@ -182,9 +182,10 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) {
 
         for (int i = 0; i < n_tensors; ++i) {
             const char * name   = gguf_get_tensor_name  (ctx, i);
+            const size_t size   = gguf_get_tensor_size  (ctx, i);
             const size_t offset = gguf_get_tensor_offset(ctx, i);
 
-            printf("%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset);
+            printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset);
         }
     }
 
@@ -199,7 +200,8 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) {
 
             struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
 
-            printf("%s: tensor[%d]: n_dims = %d, name = %s, data = %p\n", __func__, i, ggml_n_dims(cur), cur->name, cur->data);
+            printf("%s: tensor[%d]: n_dims = %d, ne = (%d, %d, %d, %d), name = %s, data = %p\n",
+                __func__, i, ggml_n_dims(cur), int(cur->ne[0]), int(cur->ne[1]), int(cur->ne[2]), int(cur->ne[3]), cur->name, cur->data);
 
             // print first 10 elements
             const float * data = (const float *) cur->data;
@@ -215,7 +217,7 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) {
                 const float * data = (const float *) cur->data;
                 for (int j = 0; j < ggml_nelements(cur); ++j) {
                     if (data[j] != 100 + i) {
-                        fprintf(stderr, "%s: tensor[%d]: data[%d] = %f\n", __func__, i, j, data[j]);
+                        fprintf(stderr, "%s: tensor[%d], data[%d]: found %f, expected %f\n", __func__, i, j, data[j], float(100 + i));
                         gguf_free(ctx);
                         return false;
                     }
@@ -245,6 +247,8 @@ int main(int argc, char ** argv) {
         check_data = false;
     }
 
+    srand(123456);
+
     const std::string fname(argv[1]);
     const std::string mode (argv[2]);
 
index 3cd0d2fa8590cdfc15e086bd6f5fa2ac7c6d5e7a..7a8a3156bfdeffda5172b4859a47445cec2e4e1c 100644 (file)
@@ -7,6 +7,7 @@
 #include "ggml-cpu.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "gguf.h"
 
 //#ifdef GGML_USE_CUDA
 //#include "ggml-cuda.h"
@@ -262,7 +263,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
             {
                 const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
                 int arr_n = gguf_get_arr_n(ctx_gguf, i);
-                const void * data = gguf_get_arr_data(ctx_gguf, i);
+                const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i);
                 std::stringstream ss;
                 ss << "[";
                 for (int j = 0; j < arr_n; j++) {
@@ -2734,7 +2735,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
         total_size_org += orig_size;
         total_size_new += new_size;
         gguf_set_tensor_type(ctx_out, name.c_str(), new_type);
-        gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size);
+        GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size);
+        gguf_set_tensor_data(ctx_out, name.c_str(), new_data);
         fout.write((const char *)new_data, new_size);
         size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size;
         for (size_t j = 0; j < pad; ++j) {
index 3935065332ea1e45a91b29cc4da055c5f48e4516..fe8acc8038b5b098768f25d1e4565e8ced1ec9e8 100644 (file)
@@ -243,7 +243,8 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-metal.h
     include/ggml-rpc.h
     include/ggml-sycl.h
-    include/ggml-vulkan.h)
+    include/ggml-vulkan.h
+    include/gguf.h)
 
 set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
 #if (GGML_METAL)
index 219361af43e06cf0303554c9e6b1238bb5373e28..a12342c25debea5bf028c8c84fbe3a8ad7928780 100644 (file)
@@ -7,6 +7,7 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "gguf.h"
 #include <memory>
 
 // Smart pointers for ggml types
index c714fc8c837bba1c615f3b4f8f914cd6a1a57672..8630d92c5c6a496150fbf87225dc5a5b0f16a641 100644 (file)
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
 
-#define GGUF_MAGIC "GGUF"
-
-#define GGUF_VERSION 3
-
-#define GGUF_DEFAULT_ALIGNMENT 32
-
 #define GGML_UNUSED(x) (void)(x)
 
 #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -403,12 +397,6 @@ extern "C" {
         GGML_PREC_F32,
     };
 
-    enum ggml_backend_type {
-        GGML_BACKEND_TYPE_CPU = 0,
-        GGML_BACKEND_TYPE_GPU = 10,
-        GGML_BACKEND_TYPE_GPU_SPLIT = 20,
-    };
-
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN        = -1,
@@ -587,8 +575,6 @@ extern "C" {
     struct ggml_tensor {
         enum ggml_type type;
 
-        GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
-
         struct ggml_backend_buffer * buffer;
 
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -2111,132 +2097,6 @@ extern "C" {
                    int64_t   n_per_row,
                const float * imatrix);
 
-    //
-    // gguf
-    //
-
-    enum gguf_type {
-        GGUF_TYPE_UINT8   = 0,
-        GGUF_TYPE_INT8    = 1,
-        GGUF_TYPE_UINT16  = 2,
-        GGUF_TYPE_INT16   = 3,
-        GGUF_TYPE_UINT32  = 4,
-        GGUF_TYPE_INT32   = 5,
-        GGUF_TYPE_FLOAT32 = 6,
-        GGUF_TYPE_BOOL    = 7,
-        GGUF_TYPE_STRING  = 8,
-        GGUF_TYPE_ARRAY   = 9,
-        GGUF_TYPE_UINT64  = 10,
-        GGUF_TYPE_INT64   = 11,
-        GGUF_TYPE_FLOAT64 = 12,
-        GGUF_TYPE_COUNT,       // marks the end of the enum
-    };
-
-    struct gguf_context;
-
-    struct gguf_init_params {
-        bool no_alloc;
-
-        // if not NULL, create a ggml_context and allocate the tensor data in it
-        struct ggml_context ** ctx;
-    };
-
-    GGML_API struct gguf_context * gguf_init_empty(void);
-    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
-    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
-    GGML_API void gguf_free(struct gguf_context * ctx);
-
-    GGML_API const char * gguf_type_name(enum gguf_type type);
-
-    GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
-    GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
-
-    GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
-    GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
-    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
-
-    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
-    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
-
-    // will abort if the wrong type is used for the key
-    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
-    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
-    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
-    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
-    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
-    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
-    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
-    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
-    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
-    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
-    GGML_API int            gguf_get_n_tensors    (const struct gguf_context * ctx);
-    GGML_API int            gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
-    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
-    GGML_API char *         gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
-    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int i);
-
-    // removes key if it exists
-    GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
-
-    // overrides existing values or adds a new one
-    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
-    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
-    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
-    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
-    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
-    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
-    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
-    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
-    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
-    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
-    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
-    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
-    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
-    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
-
-    // set or add KV pairs from another context
-    GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
-
-    // manage tensor info
-    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
-    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
-    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
-
-    // writing gguf files can be done in 2 ways:
-    //
-    // - write the entire gguf_context to a binary file in a single pass:
-    //
-    //   gguf_write_to_file(ctx, fname);
-    //
-    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
-    //
-    //   FILE * f = fopen(fname, "wb");
-    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
-    //   fwrite(f, ...);
-    //   void * data = gguf_meta_get_meta_data(ctx);
-    //   fseek(f, 0, SEEK_SET);
-    //   fwrite(f, data, gguf_get_meta_size(ctx));
-    //   free(data);
-    //   fclose(f);
-    //
-
-    // write the entire context to a binary file
-    GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
-
-    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
-    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
-    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
-
 #ifdef __cplusplus
     // restrict not standard in C++
 #    if defined(__GNUC__)
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
new file mode 100644 (file)
index 0000000..79ee202
--- /dev/null
@@ -0,0 +1,202 @@
+// This file contains functionality related to "GGUF" files, the binary file format used by ggml.
+// GGUF files have the following structure:
+//
+// 1. File magic "GGUF" (4 bytes).
+// 2. File version (uint32_t).
+// 3. Number of ggml tensors in file (int64_t).
+// 4. Number of key-value-pairs in file (int64_t).
+// 5. For each KV pair:
+//   1. The key (string).
+//   2. The value type (gguf_type).
+//   3a. If the value type is GGUF_TYPE_ARRAY:
+//     1. The type of the array (gguf_type).
+//     2. The number of elements in the array (uint64_t).
+//     3. The binary representation of each element in the array.
+//   3b. Otherwise:
+//     1. The binary representation of the value.
+// 6. For each ggml tensor:
+//   1. The tensor name (string).
+//   2. The number of dimensions of the tensor (uint32_t).
+//   3. For each dimension:
+//     1. The size of the tensor in the dimension (int64_t).
+//   4. The tensor data type (ggml_type).
+//   5. The tensor data offset in the tensor data binary blob (uint64_t).
+// 7. The tensor data binary blob (optional, aligned).
+//
+// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator.
+// All enums are stored as int32_t.
+// All bool values are stored as int8_t.
+// If the special key "general.alignment" (uint32_t) is defined it is used for alignment,
+//   otherwise GGUF_DEFAULT_ALIGNMENT is used.
+//
+// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
+
+#pragma once
+
+#include "ggml.h"
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#define GGUF_MAGIC   "GGUF"
+#define GGUF_VERSION 3
+
+#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment"
+
+#define GGUF_DEFAULT_ALIGNMENT 32
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    // types that can be stored as GGUF KV data
+    enum gguf_type {
+        GGUF_TYPE_UINT8   = 0,
+        GGUF_TYPE_INT8    = 1,
+        GGUF_TYPE_UINT16  = 2,
+        GGUF_TYPE_INT16   = 3,
+        GGUF_TYPE_UINT32  = 4,
+        GGUF_TYPE_INT32   = 5,
+        GGUF_TYPE_FLOAT32 = 6,
+        GGUF_TYPE_BOOL    = 7,
+        GGUF_TYPE_STRING  = 8,
+        GGUF_TYPE_ARRAY   = 9,
+        GGUF_TYPE_UINT64  = 10,
+        GGUF_TYPE_INT64   = 11,
+        GGUF_TYPE_FLOAT64 = 12,
+        GGUF_TYPE_COUNT,       // marks the end of the enum
+    };
+
+    struct gguf_context;
+
+    struct gguf_init_params {
+        bool no_alloc;
+
+        // if not NULL, create a ggml_context and allocate the tensor data in it
+        struct ggml_context ** ctx;
+    };
+
+    GGML_API struct gguf_context * gguf_init_empty(void);
+    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
+    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
+
+    GGML_API void gguf_free(struct gguf_context * ctx);
+
+    GGML_API const char * gguf_type_name(enum gguf_type type);
+
+    GGML_API uint32_t gguf_get_version    (const struct gguf_context * ctx);
+    GGML_API size_t   gguf_get_alignment  (const struct gguf_context * ctx);
+    GGML_API size_t   gguf_get_data_offset(const struct gguf_context * ctx);
+
+    GGML_API int64_t      gguf_get_n_kv(const struct gguf_context * ctx);
+    GGML_API int64_t      gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id);
+
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id);
+
+    // will abort if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API size_t       gguf_get_arr_n   (const struct gguf_context * ctx, int64_t key_id);
+
+    // get raw pointer to the first element of the array with the given key_id
+    // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+
+    // get ith C string from array with given key_id
+    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
+
+    GGML_API int64_t        gguf_get_n_tensors    (const struct gguf_context * ctx);
+    GGML_API int64_t        gguf_find_tensor      (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found
+    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API const char *   gguf_get_tensor_name  (const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API size_t         gguf_get_tensor_size  (const struct gguf_context * ctx, int64_t tensor_id);
+
+    // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist)
+    GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key);
+
+    // overrides an existing KV pair or adds a new one, the new KV pair is always at the back
+    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t      val);
+    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t       val);
+    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t     val);
+    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t      val);
+    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t     val);
+    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t      val);
+    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float        val);
+    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t     val);
+    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t      val);
+    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double       val);
+    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool         val);
+    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
+
+    // creates a new array with n elements of the given type and copies the corresponding number of bytes from data
+    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n);
+
+    // creates a new array with n strings and copies the corresponding strings from data
+    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n);
+
+    // set or add KV pairs from another context
+    GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src);
+
+    // add tensor to GGUF context, tensor name must be unique
+    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
+
+    // after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated
+    //   in such a way that the tensor data remains as one contiguous block (except for padding)
+    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
+
+    // assumes that at least gguf_get_tensor_size bytes can be read from data
+    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data);
+
+    // writing gguf files can be done in 3 ways:
+    //
+    // - write the entire gguf_context to a binary file in a single pass:
+    //
+    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ false);
+    //
+    // - write only the meta data to a file, then re-open the file and append the tensor data:
+    //
+    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ true);
+    //   FILE * f = fopen(fname, "ab");
+    //   fwrite(f, ...); // write tensor data
+    //   fclose(f);
+    //
+    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   const size_t size_meta = gguf_get_meta_size(ctx);
+    //   fseek(f, size_meta, SEEK_SET);
+    //   fwrite(f, ...); // write tensor data
+    //   void * data = malloc(size_meta);
+    //   gguf_get_meta_data(ctx, data);
+    //   rewind(f);
+    //   fwrite(data, 1, data, f);
+    //   free(data);
+    //   fclose(f);
+    //
+
+    // write the entire context to a binary file
+    GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+
+    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
+
+    // writes the meta data to pointer "data"
+    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+
+#ifdef  __cplusplus
+}
+#endif
index 84101c32c2b50171a166ae652a6b8fe78557c2e9..ae1cd23376ccf477a8f270998f03521b244c4152 100644 (file)
@@ -208,6 +208,7 @@ add_library(ggml-base
             ../include/ggml-backend.h
             ../include/ggml-cpp.h
             ../include/ggml-opt.h
+            ../include/gguf.h
             ggml.c
             ggml-alloc.c
             ggml-backend.cpp
@@ -215,7 +216,8 @@ add_library(ggml-base
             ggml-threading.cpp
             ggml-threading.h
             ggml-quants.c
-            ggml-quants.h)
+            ggml-quants.h
+            gguf.cpp)
 
 target_include_directories(ggml-base PRIVATE .)
 
index 549772c57c90a62085395b1e27cfac99b9f5a779..eab017889c9198adb2f3ed5b7c661431149b561c 100644 (file)
@@ -3,6 +3,8 @@
 // GGML internal header
 
 #include "ggml.h"
+#include "gguf.h"
+
 #include <assert.h>
 #include <math.h>
 #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
@@ -551,22 +553,15 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
-// expose GGUF internals for test code
-
-GGML_API size_t gguf_type_size(enum gguf_type type);
-
-GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
-
-struct gguf_buf {
-    void * data;
-    size_t size;
-    size_t offset;
-};
-GGML_API struct gguf_buf gguf_buf_init(size_t size);
-GGML_API void gguf_buf_free(struct gguf_buf buf);
-
-GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta);
-
 #ifdef __cplusplus
 }
 #endif
+
+#ifdef __cplusplus
+#include <vector>
+
+// expose GGUF internals for test code
+GGML_API size_t gguf_type_size(enum gguf_type type);
+GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
+GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta);
+#endif // __cplusplus
index 2bbe5f48257b2f5c3fe37c4a14cb263eac028c09..90abc6ad45233aab20ada03b8e7a9c6736e506c1 100644 (file)
@@ -1588,15 +1588,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
 
     struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
 
-#ifdef __clang__
-    // temporary until ggml_tensor::backend is removed
-    #pragma clang diagnostic push
-    #pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
-        /*.backend      =*/ GGML_BACKEND_TYPE_CPU,
         /*.buffer       =*/ NULL,
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
@@ -1612,10 +1605,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.padding      =*/ { 0 },
     };
 
-#ifdef __clang__
-    #pragma clang diagnostic pop
-#endif
-
     // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
     //GGML_ASSERT_ALIGNED(result->data);
 
@@ -6417,1271 +6406,6 @@ size_t ggml_quantize_chunk(
 
 ////////////////////////////////////////////////////////////////////////////////
 
-struct gguf_str {
-    uint64_t n;  // GGUFv2
-    char * data;
-};
-
-static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
-    [GGUF_TYPE_UINT8]   = sizeof(uint8_t),
-    [GGUF_TYPE_INT8]    = sizeof(int8_t),
-    [GGUF_TYPE_UINT16]  = sizeof(uint16_t),
-    [GGUF_TYPE_INT16]   = sizeof(int16_t),
-    [GGUF_TYPE_UINT32]  = sizeof(uint32_t),
-    [GGUF_TYPE_INT32]   = sizeof(int32_t),
-    [GGUF_TYPE_FLOAT32] = sizeof(float),
-    [GGUF_TYPE_BOOL]    = sizeof(bool),
-    [GGUF_TYPE_STRING]  = sizeof(struct gguf_str),
-    [GGUF_TYPE_UINT64]  = sizeof(uint64_t),
-    [GGUF_TYPE_INT64]   = sizeof(int64_t),
-    [GGUF_TYPE_FLOAT64] = sizeof(double),
-    [GGUF_TYPE_ARRAY]   = 0, // undefined
-};
-static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
-static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
-    [GGUF_TYPE_UINT8]   = "u8",
-    [GGUF_TYPE_INT8]    = "i8",
-    [GGUF_TYPE_UINT16]  = "u16",
-    [GGUF_TYPE_INT16]   = "i16",
-    [GGUF_TYPE_UINT32]  = "u32",
-    [GGUF_TYPE_INT32]   = "i32",
-    [GGUF_TYPE_FLOAT32] = "f32",
-    [GGUF_TYPE_BOOL]    = "bool",
-    [GGUF_TYPE_STRING]  = "str",
-    [GGUF_TYPE_ARRAY]   = "arr",
-    [GGUF_TYPE_UINT64]  = "u64",
-    [GGUF_TYPE_INT64]   = "i64",
-    [GGUF_TYPE_FLOAT64] = "f64",
-};
-static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
-union gguf_value {
-    uint8_t  uint8;
-    int8_t   int8;
-    uint16_t uint16;
-    int16_t  int16;
-    uint32_t uint32;
-    int32_t  int32;
-    float    float32;
-    uint64_t uint64;
-    int64_t  int64;
-    double   float64;
-    bool     bool_;
-
-    struct gguf_str str;
-
-    struct {
-        enum gguf_type type;
-
-        uint64_t n;  // GGUFv2
-        void * data;
-    } arr;
-};
-
-struct gguf_kv {
-    struct gguf_str key;
-
-    enum  gguf_type  type;
-    union gguf_value value;
-};
-
-struct gguf_header {
-    char magic[4];
-
-    uint32_t version;
-    uint64_t n_tensors; // GGUFv2
-    uint64_t n_kv;      // GGUFv2
-};
-
-struct gguf_tensor_info {
-    struct gguf_str name;
-
-    uint32_t n_dims;
-    uint64_t ne[GGML_MAX_DIMS];
-
-    enum ggml_type type;
-
-    uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
-
-    // for writing API
-    const void * data;
-    size_t size;
-};
-
-struct gguf_context {
-    struct gguf_header header;
-
-    struct gguf_kv          * kv;
-    struct gguf_tensor_info * infos;
-
-    size_t alignment;
-    size_t offset;    // offset of `data` from beginning of file
-    size_t size;      // size of `data` in bytes
-
-    //uint8_t * padding;
-    void * data;
-};
-
-size_t gguf_type_size(enum gguf_type type) {
-    GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
-    return GGUF_TYPE_SIZE[type];
-}
-
-static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
-    if (info->n_dims > GGML_MAX_DIMS) {
-        fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
-        return false;
-    }
-
-    if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
-        fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
-        return false;
-    }
-
-    if (strlen(info->name.data) >= GGML_MAX_NAME) {
-        fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
-        return false;
-    }
-
-    for (uint32_t i = 0; i < info->n_dims; ++i) {
-        if (info->ne[i] <= 0) {
-            fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
-            return false;
-        }
-    }
-
-    // prevent overflow for total number of elements
-    if (INT64_MAX/info->ne[1] <= info->ne[0]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
-        return false;
-    }
-
-    if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
-        return false;
-    }
-
-    if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
-        return false;
-    }
-
-    return true;
-}
-
-static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
-    const size_t n = fread(dst, 1, size, file);
-    *offset += n;
-    return n == size;
-}
-
-static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
-    p->n    = 0;
-    p->data = NULL;
-
-    bool ok = true;
-
-    ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset);
-
-    // early exit if string length is invalid, prevents from integer overflow
-    if (p->n == SIZE_MAX) {
-        fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
-        return false;
-    }
-
-    p->data = calloc(p->n + 1, 1);
-    if (!p->data) {
-        fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
-        return false;
-    }
-
-    ok = ok && gguf_fread_el(file,  p->data, p->n, offset);
-
-    return ok;
-}
-
-static void gguf_free_kv(struct gguf_kv * kv) {
-    if (kv->key.data) {
-        GGML_FREE(kv->key.data);
-    }
-
-    if (kv->type == GGUF_TYPE_STRING) {
-        if (kv->value.str.data) {
-            GGML_FREE(kv->value.str.data);
-        }
-    }
-
-    if (kv->type == GGUF_TYPE_ARRAY) {
-        if (kv->value.arr.data) {
-            if (kv->value.arr.type == GGUF_TYPE_STRING) {
-                for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
-                    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
-                    if (str->data) {
-                        GGML_FREE(str->data);
-                    }
-                }
-            }
-            GGML_FREE(kv->value.arr.data);
-        }
-    }
-}
-
-struct gguf_context * gguf_init_empty(void) {
-    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
-    if (!ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
-        return NULL;
-    }
-
-    memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
-    ctx->header.version   = GGUF_VERSION;
-    ctx->header.n_tensors = 0;
-    ctx->header.n_kv      = 0;
-
-    ctx->kv    = NULL;
-    ctx->infos = NULL;
-
-    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-    ctx->offset    = 0;
-    ctx->size      = 0;
-
-    ctx->data = NULL;
-
-    return ctx;
-}
-
-struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
-    // offset from start of file
-    size_t offset = 0;
-
-    char magic[4];
-
-    // check the magic before making allocations
-    {
-        gguf_fread_el(file, &magic, sizeof(magic), &offset);
-
-        for (uint32_t i = 0; i < sizeof(magic); i++) {
-            if (magic[i] != GGUF_MAGIC[i]) {
-                fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
-                return NULL;
-            }
-        }
-    }
-
-    bool ok = true;
-
-    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
-    if (!ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
-        return NULL;
-    }
-
-    // read the header
-    {
-        strncpy(ctx->header.magic, magic, 4);
-
-        ctx->kv    = NULL;
-        ctx->infos = NULL;
-        ctx->data  = NULL;
-
-        ok = ok && gguf_fread_el(file, &ctx->header.version,   sizeof(ctx->header.version),   &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
-
-        if (ctx->header.version == 1) {
-            fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        // sanity-checks to prevent from integer/buffer overflows
-
-        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info));
-        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead());
-        ok = ok && (ctx->header.n_kv      < (SIZE_MAX/2)/sizeof(struct gguf_kv));
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read header\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-    }
-
-    // read the kv pairs
-    {
-        const uint64_t n_kv = ctx->header.n_kv;
-
-        if (n_kv > 0) {
-            ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
-            if (!ctx->kv) {
-                fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
-                gguf_free(ctx);
-                return NULL;
-            }
-        }
-
-        for (uint64_t i = 0; i < n_kv; ++i) {
-            struct gguf_kv * kv = &ctx->kv[i];
-
-            //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
-
-            ok = ok && gguf_fread_str(file, &kv->key,                    &offset);
-            ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
-
-            //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
-
-            switch (kv->type) {
-                case GGUF_TYPE_UINT8:   ok = ok && gguf_fread_el (file, &kv->value.uint8,   sizeof(kv->value.uint8),   &offset); break;
-                case GGUF_TYPE_INT8:    ok = ok && gguf_fread_el (file, &kv->value.int8,    sizeof(kv->value.int8),    &offset); break;
-                case GGUF_TYPE_UINT16:  ok = ok && gguf_fread_el (file, &kv->value.uint16,  sizeof(kv->value.uint16),  &offset); break;
-                case GGUF_TYPE_INT16:   ok = ok && gguf_fread_el (file, &kv->value.int16,   sizeof(kv->value.int16),   &offset); break;
-                case GGUF_TYPE_UINT32:  ok = ok && gguf_fread_el (file, &kv->value.uint32,  sizeof(kv->value.uint32),  &offset); break;
-                case GGUF_TYPE_INT32:   ok = ok && gguf_fread_el (file, &kv->value.int32,   sizeof(kv->value.int32),   &offset); break;
-                case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
-                case GGUF_TYPE_UINT64:  ok = ok && gguf_fread_el (file, &kv->value.uint64,  sizeof(kv->value.uint64),  &offset); break;
-                case GGUF_TYPE_INT64:   ok = ok && gguf_fread_el (file, &kv->value.int64,   sizeof(kv->value.int64),   &offset); break;
-                case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
-                case GGUF_TYPE_BOOL:    ok = ok && gguf_fread_el (file, &kv->value.bool_,   sizeof(kv->value.bool_),   &offset); break;
-                case GGUF_TYPE_STRING:  ok = ok && gguf_fread_str(file, &kv->value.str,                                &offset); break;
-                case GGUF_TYPE_ARRAY:
-                    {
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.n,    sizeof(kv->value.arr.n),    &offset);
-
-                        switch (kv->value.arr.type) {
-                            case GGUF_TYPE_UINT8:
-                            case GGUF_TYPE_INT8:
-                            case GGUF_TYPE_UINT16:
-                            case GGUF_TYPE_INT16:
-                            case GGUF_TYPE_UINT32:
-                            case GGUF_TYPE_INT32:
-                            case GGUF_TYPE_FLOAT32:
-                            case GGUF_TYPE_UINT64:
-                            case GGUF_TYPE_INT64:
-                            case GGUF_TYPE_FLOAT64:
-                            case GGUF_TYPE_BOOL:
-                                {
-                                    // prevent from integer overflow in the malloc below
-                                    if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
-                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
-                                    if (!kv->value.arr.data) {
-                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
-                                } break;
-                            case GGUF_TYPE_STRING:
-                                {
-                                    // prevent from integer overflow in the malloc below
-                                    if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
-                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
-                                    if (!kv->value.arr.data) {
-                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
-                                        ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
-                                    }
-                                } break;
-                            case GGUF_TYPE_ARRAY:
-                            default:
-                                {
-                                    fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
-                                    ok = false;
-                                } break;
-                        }
-                    } break;
-                default:
-                    {
-                        fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
-                        ok = false;
-                    } break;
-            }
-
-            if (!ok) {
-                break;
-            }
-        }
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-    }
-
-    // read the tensor infos
-    if (ctx->header.n_tensors > 0) {
-        ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
-        if (!ctx->infos) {
-            fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            for (int j = 0; j < GGML_MAX_DIMS; ++j) {
-                info->ne[j] = 1;
-            }
-
-            ok = ok && gguf_fread_str(file, &info->name,                          &offset);
-            ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims),  &offset);
-
-            ok = ok && (info->n_dims <= GGML_MAX_DIMS);
-
-            for (uint32_t j = 0; j < info->n_dims; ++j) {
-                ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
-            }
-
-            ok = ok && gguf_fread_el (file, &info->type,   sizeof(info->type),    &offset);
-            ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset),  &offset);
-
-            ok = ok && gguf_tensor_info_sanitize(info);
-
-            // make sure there is no duplicated tensor names
-            for (uint64_t j = 0; j < i && ok; ++j) {
-                if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
-                    fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
-                    ok = false;
-                }
-            }
-
-            if (!ok) {
-                fprintf(stderr, "%s: failed to read tensor info\n", __func__);
-                gguf_free(ctx);
-                return NULL;
-            }
-        }
-    }
-
-    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-
-    int alignment_idx = gguf_find_key(ctx, "general.alignment");
-    if (alignment_idx != -1) {
-        ctx->alignment = gguf_get_val_u32(ctx, alignment_idx);
-    }
-
-    // we require the data section to be aligned, so take into account any padding
-    {
-        const size_t offset_pad = offset % ctx->alignment;
-
-        if (offset_pad != 0) {
-            offset += ctx->alignment - offset_pad;
-            fseek(file, offset, SEEK_SET);
-        }
-    }
-
-    // store the current file offset - this is where the data section starts
-    ctx->offset = offset;
-
-    // compute the total size of the data section, taking into account the alignment
-    {
-        ctx->size = 0;
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            const int64_t ne =
-                (int64_t) info->ne[0] *
-                (int64_t) info->ne[1] *
-                (int64_t) info->ne[2] *
-                (int64_t) info->ne[3];
-
-            if (ggml_blck_size(info->type) == 0 ) {
-                // this tensor type support have been removed:
-                fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
-                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type));
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            if (ne % ggml_blck_size(info->type) != 0) {
-                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
-                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            const size_t size_cur = ggml_row_size(info->type, ne);
-
-            ctx->size += GGML_PAD(size_cur, ctx->alignment);
-        }
-    }
-
-    // load the tensor data only if requested
-    if (params.ctx != NULL) {
-        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
-        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
-        // the ggml_tensor structs to the appropriate locations in the binary blob
-
-        // compute the exact size needed for the new ggml_context
-        const size_t mem_size =
-            params.no_alloc ?
-            (ctx->header.n_tensors    )*ggml_tensor_overhead() :
-            (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
-
-        struct ggml_init_params pdata = {
-            .mem_size   = mem_size,
-            .mem_buffer = NULL,
-            .no_alloc   = params.no_alloc,
-        };
-
-        *params.ctx = ggml_init(pdata);
-        if (*params.ctx == NULL) {
-            fprintf(stderr, "%s: failed to initialize context\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        struct ggml_context * ctx_data = *params.ctx;
-
-        struct ggml_tensor * data = NULL;
-
-        if (!params.no_alloc) {
-            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
-
-            ok = ok && data != NULL;
-
-            // read the binary blob with the tensor data
-            ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset);
-
-            if (!ok) {
-                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
-                ggml_free(ctx_data);
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            ctx->data = data->data;
-        }
-
-        ggml_set_no_alloc(ctx_data, true);
-
-        // create the tensors
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            const int64_t ne[GGML_MAX_DIMS] = {
-                ctx->infos[i].ne[0],
-                ctx->infos[i].ne[1],
-                ctx->infos[i].ne[2],
-                ctx->infos[i].ne[3],
-            };
-
-            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
-
-            ok = ok && cur != NULL;
-
-            if (!ok) {
-                break;
-            }
-
-            ggml_set_name(cur, ctx->infos[i].name.data);
-
-            // point the data member to the appropriate location in the binary blob using the tensor infos
-            if (!params.no_alloc) {
-              //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
-                cur->data = (char *) data->data + ctx->infos[i].offset;               // offset from data
-            }
-        }
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
-            ggml_free(ctx_data);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        ggml_set_no_alloc(ctx_data, params.no_alloc);
-    }
-
-    return ctx;
-}
-
-struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
-    FILE * file = ggml_fopen(fname, "rb");
-    if (!file) {
-        fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
-        return NULL;
-    }
-
-    struct gguf_context * result = gguf_init_from_file_impl(file, params);
-    fclose(file);
-    return result;
-}
-
-void gguf_free(struct gguf_context * ctx) {
-    if (ctx == NULL) {
-        return;
-    }
-
-    if (ctx->kv) {
-        // free string memory - not great..
-        for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
-            gguf_free_kv(&ctx->kv[i]);
-        }
-
-        GGML_FREE(ctx->kv);
-    }
-
-    if (ctx->infos) {
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            if (info->name.data) {
-                GGML_FREE(info->name.data);
-            }
-        }
-
-        GGML_FREE(ctx->infos);
-    }
-
-    GGML_FREE(ctx);
-}
-
-const char * gguf_type_name(enum gguf_type type) {
-    return GGUF_TYPE_NAME[type];
-}
-
-int gguf_get_version(const struct gguf_context * ctx) {
-    return ctx->header.version;
-}
-
-size_t gguf_get_alignment(const struct gguf_context * ctx) {
-    return ctx->alignment;
-}
-
-size_t gguf_get_data_offset(const struct gguf_context * ctx) {
-    return ctx->offset;
-}
-
-void * gguf_get_data(const struct gguf_context * ctx) {
-    return ctx->data;
-}
-
-int gguf_get_n_kv(const struct gguf_context * ctx) {
-    return ctx->header.n_kv;
-}
-
-int gguf_find_key(const struct gguf_context * ctx, const char * key) {
-    // return -1 if key not found
-    int keyfound = -1;
-
-    const int n_kv = gguf_get_n_kv(ctx);
-
-    for (int i = 0; i < n_kv; ++i) {
-        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
-            keyfound = i;
-            break;
-        }
-    }
-
-    return keyfound;
-}
-
-const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].key.data;
-}
-
-enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].type;
-}
-
-enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.type;
-}
-
-const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.data;
-}
-
-const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    struct gguf_kv * kv = &ctx->kv[key_id];
-    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
-    return str->data;
-}
-
-int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.n;
-}
-
-uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
-    return ctx->kv[key_id].value.uint8;
-}
-
-int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
-    return ctx->kv[key_id].value.int8;
-}
-
-uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
-    return ctx->kv[key_id].value.uint16;
-}
-
-int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
-    return ctx->kv[key_id].value.int16;
-}
-
-uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
-    return ctx->kv[key_id].value.uint32;
-}
-
-int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
-    return ctx->kv[key_id].value.int32;
-}
-
-float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
-    return ctx->kv[key_id].value.float32;
-}
-
-uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
-    return ctx->kv[key_id].value.uint64;
-}
-
-int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
-    return ctx->kv[key_id].value.int64;
-}
-
-double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
-    return ctx->kv[key_id].value.float64;
-}
-
-bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
-    return ctx->kv[key_id].value.bool_;
-}
-
-const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
-    return ctx->kv[key_id].value.str.data;
-}
-
-const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
-    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
-    return &ctx->kv[key_id].value;
-}
-
-int gguf_get_n_tensors(const struct gguf_context * ctx) {
-    return ctx->header.n_tensors;
-}
-
-int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
-    // return -1 if tensor not found
-    int tensorfound = -1;
-
-    const int n_tensors = gguf_get_n_tensors(ctx);
-
-    for (int i = 0; i < n_tensors; ++i) {
-        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
-            tensorfound = i;
-            break;
-        }
-    }
-
-    return tensorfound;
-}
-
-size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].offset;
-}
-
-char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].name.data;
-}
-
-enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].type;
-}
-
-// returns the index
-static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
-    const int idx = gguf_find_key(ctx, key);
-    if (idx >= 0) {
-        return idx;
-    }
-
-    const int n_kv = gguf_get_n_kv(ctx);
-
-    ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
-    ctx->kv[n_kv].key.n    = strlen(key);
-    ctx->kv[n_kv].key.data = strdup(key);
-    ctx->header.n_kv++;
-
-    return n_kv;
-}
-
-void gguf_remove_key(struct gguf_context * ctx, const char * key) {
-    const int idx = gguf_find_key(ctx, key);
-    if (idx >= 0) {
-        const int n_kv = gguf_get_n_kv(ctx);
-        gguf_free_kv(&ctx->kv[idx]);
-        for (int i = idx; i < n_kv-1; ++i) {
-            ctx->kv[i] = ctx->kv[i+1];
-        }
-        ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv));
-        ctx->header.n_kv--;
-    }
-}
-
-void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_UINT8;
-    ctx->kv[idx].value.uint8 = val;
-}
-
-void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type       = GGUF_TYPE_INT8;
-    ctx->kv[idx].value.int8 = val;
-}
-
-void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT16;
-    ctx->kv[idx].value.uint16 = val;
-}
-
-void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT16;
-    ctx->kv[idx].value.int16 = val;
-}
-
-void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT32;
-    ctx->kv[idx].value.uint32 = val;
-}
-
-void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT32;
-    ctx->kv[idx].value.int32 = val;
-}
-
-void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type          = GGUF_TYPE_FLOAT32;
-    ctx->kv[idx].value.float32 = val;
-}
-
-void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT64;
-    ctx->kv[idx].value.uint64 = val;
-}
-
-void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT64;
-    ctx->kv[idx].value.int64 = val;
-}
-
-void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type          = GGUF_TYPE_FLOAT64;
-    ctx->kv[idx].value.float64 = val;
-}
-
-void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_BOOL;
-    ctx->kv[idx].value.bool_ = val;
-}
-
-void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.str.n    = strlen(val);
-    ctx->kv[idx].value.str.data = strdup(val);
-}
-
-void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
-    ctx->kv[idx].value.arr.type = type;
-    ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
-    memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
-}
-
-void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
-    ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
-    for (int i = 0; i < n; i++) {
-        struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
-        str->n    = strlen(data[i]);
-        str->data = strdup(data[i]);
-    }
-}
-
-// set or add KV pairs from another context
-void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
-    for (uint32_t i = 0; i < src->header.n_kv; i++) {
-        switch (src->kv[i].type) {
-            case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, src->kv[i].key.data, src->kv[i].value.uint8);    break;
-            case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, src->kv[i].key.data, src->kv[i].value.int8);     break;
-            case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16);   break;
-            case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16);    break;
-            case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32);   break;
-            case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32);    break;
-            case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32);  break;
-            case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64);   break;
-            case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64);    break;
-            case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64);  break;
-            case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_);    break;
-            case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
-            case GGUF_TYPE_ARRAY:
-                {
-                    if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
-                        const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
-                        for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
-                            data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
-                        }
-                        gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
-                        GGML_FREE((void *)data);
-                    } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
-                        GGML_ABORT("nested arrays not supported");
-                    } else {
-                        gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
-                    }
-                } break;
-            default: GGML_ABORT("invalid type");
-        }
-    }
-}
-
-void gguf_add_tensor(
-             struct gguf_context * ctx,
-        const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor);
-    if (gguf_find_tensor(ctx, tensor->name) != -1) {
-        GGML_ABORT("duplicated tensor name");
-    }
-
-    const int idx = ctx->header.n_tensors;
-    ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
-
-    ctx->infos[idx].name.n    = strlen(tensor->name);
-    ctx->infos[idx].name.data = strdup(tensor->name);
-
-    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
-        ctx->infos[idx].ne[i] = 1;
-    }
-
-    ctx->infos[idx].n_dims = ggml_n_dims(tensor);
-    for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
-        ctx->infos[idx].ne[i] = tensor->ne[i];
-    }
-
-    ctx->infos[idx].type   = tensor->type;
-    ctx->infos[idx].offset = 0;
-    ctx->infos[idx].data   = tensor->data;
-    ctx->infos[idx].size   = ggml_nbytes(tensor);
-
-    if (ctx->header.n_tensors > 0) {
-        ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
-    }
-
-    ctx->header.n_tensors++;
-}
-
-void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
-    const int idx = gguf_find_tensor(ctx, name);
-    if (idx < 0) {
-        GGML_ABORT("tensor not found");
-    }
-
-    ctx->infos[idx].type = type;
-}
-
-void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) {
-    const int idx = gguf_find_tensor(ctx, name);
-    if (idx < 0) {
-        GGML_ABORT("tensor not found");
-    }
-
-    ctx->infos[idx].data = data;
-    ctx->infos[idx].size = size;
-
-    // update offsets
-    for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
-        ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
-    }
-}
-
-//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) {
-//    fwrite(&val->n,   sizeof(val->n),    1, file);
-//    fwrite(val->data, sizeof(char), val->n, file);
-//}
-//
-//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) {
-//    fwrite(val, sizeof(char), size, file);
-//}
-
-struct gguf_buf gguf_buf_init(size_t size) {
-    struct gguf_buf buf = {
-        /*buf.data   =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
-        /*buf.size   =*/ size,
-        /*buf.offset =*/ 0,
-    };
-
-    return buf;
-}
-
-void gguf_buf_free(struct gguf_buf buf) {
-    if (buf.data) {
-        GGML_FREE(buf.data);
-    }
-}
-
-static void gguf_buf_grow(struct gguf_buf * buf, size_t size) {
-    if (buf->offset + size > buf->size) {
-        buf->size = 1.5*(buf->offset + size);
-        if (buf->data) {
-            buf->data = realloc(buf->data, buf->size);
-        }
-    }
-}
-
-static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) {
-    gguf_buf_grow(buf, sizeof(val->n) + val->n);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
-    }
-    buf->offset += sizeof(val->n);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, val->data, val->n);
-    }
-    buf->offset += val->n;
-}
-
-static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) {
-    gguf_buf_grow(buf, el_size);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, val, el_size);
-    }
-    buf->offset += el_size;
-}
-
-void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
-    // write header
-    gguf_bwrite_el(buf, &ctx->header.magic,     sizeof(ctx->header.magic));
-    gguf_bwrite_el(buf, &ctx->header.version,   sizeof(ctx->header.version));
-    gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
-    gguf_bwrite_el(buf, &ctx->header.n_kv,      sizeof(ctx->header.n_kv));
-
-    // write key-value pairs
-    for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
-        struct gguf_kv * kv = &ctx->kv[i];
-
-        gguf_bwrite_str(buf, &kv->key);
-        gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
-
-        switch (kv->type) {
-            case GGUF_TYPE_UINT8:   gguf_bwrite_el( buf, &kv->value.uint8,   sizeof(kv->value.uint8)  ); break;
-            case GGUF_TYPE_INT8:    gguf_bwrite_el (buf, &kv->value.int8,    sizeof(kv->value.int8)   ); break;
-            case GGUF_TYPE_UINT16:  gguf_bwrite_el (buf, &kv->value.uint16,  sizeof(kv->value.uint16) ); break;
-            case GGUF_TYPE_INT16:   gguf_bwrite_el (buf, &kv->value.int16,   sizeof(kv->value.int16)  ); break;
-            case GGUF_TYPE_UINT32:  gguf_bwrite_el (buf, &kv->value.uint32,  sizeof(kv->value.uint32) ); break;
-            case GGUF_TYPE_INT32:   gguf_bwrite_el (buf, &kv->value.int32,   sizeof(kv->value.int32)  ); break;
-            case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
-            case GGUF_TYPE_UINT64:  gguf_bwrite_el (buf, &kv->value.uint64,  sizeof(kv->value.uint64) ); break;
-            case GGUF_TYPE_INT64:   gguf_bwrite_el (buf, &kv->value.int64,   sizeof(kv->value.int64)  ); break;
-            case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
-            case GGUF_TYPE_BOOL:    gguf_bwrite_el (buf, &kv->value.bool_,   sizeof(kv->value.bool_)  ); break;
-            case GGUF_TYPE_STRING:  gguf_bwrite_str(buf, &kv->value.str                               ); break;
-            case GGUF_TYPE_ARRAY:
-                {
-                    gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
-                    gguf_bwrite_el(buf, &kv->value.arr.n,    sizeof(kv->value.arr.n)   );
-
-                    switch (kv->value.arr.type) {
-                        case GGUF_TYPE_UINT8:
-                        case GGUF_TYPE_INT8:
-                        case GGUF_TYPE_UINT16:
-                        case GGUF_TYPE_INT16:
-                        case GGUF_TYPE_UINT32:
-                        case GGUF_TYPE_INT32:
-                        case GGUF_TYPE_FLOAT32:
-                        case GGUF_TYPE_UINT64:
-                        case GGUF_TYPE_INT64:
-                        case GGUF_TYPE_FLOAT64:
-                        case GGUF_TYPE_BOOL:
-                            {
-                                gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type));
-                            } break;
-                        case GGUF_TYPE_STRING:
-                            {
-                                for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
-                                    gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]);
-                                }
-                            } break;
-                        case GGUF_TYPE_ARRAY:
-                        default: GGML_ABORT("invalid type");
-                    }
-                } break;
-            default: GGML_ABORT("invalid type");
-        }
-    }
-
-    // write tensor infos
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
-        struct gguf_tensor_info * info = &ctx->infos[i];
-
-        gguf_bwrite_str(buf, &info->name);
-        gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
-        for (uint32_t j = 0; j < info->n_dims; ++j) {
-            gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
-        }
-        gguf_bwrite_el(buf, &info->type,   sizeof(info->type));
-        gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
-    }
-
-    // we require the data section to be aligned, so take into account any padding
-    {
-        const size_t offset     = buf->offset;
-        const size_t offset_pad = GGML_PAD(offset, ctx->alignment);
-
-        if (offset_pad != offset) {
-            uint8_t pad = 0;
-            for (size_t i = 0; i < offset_pad - offset; ++i) {
-                gguf_bwrite_el(buf, &pad, sizeof(pad));
-            }
-        }
-    }
-
-    if (only_meta) {
-        return;
-    }
-
-    size_t offset = 0;
-
-    // write tensor data
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
-        struct gguf_tensor_info * info = &ctx->infos[i];
-
-        const size_t size     = info->size;
-        const size_t size_pad = GGML_PAD(size, ctx->alignment);
-
-        gguf_bwrite_el(buf, info->data, size);
-
-        if (size_pad != size) {
-            uint8_t pad = 0;
-            for (size_t j = 0; j < size_pad - size; ++j) {
-                gguf_bwrite_el(buf, &pad, sizeof(pad));
-            }
-        }
-
-        GGML_ASSERT(offset == info->offset);
-
-        offset += size_pad;
-    }
-}
-
-void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
-    FILE * file = ggml_fopen(fname, "wb");
-    if (!file) {
-        GGML_ABORT("failed to open file for writing");
-    }
-
-    struct gguf_buf buf = gguf_buf_init(16*1024);
-
-    gguf_write_to_buf(ctx, &buf, only_meta);
-
-    fwrite(buf.data, 1, buf.offset, file);
-
-    gguf_buf_free(buf);
-
-    fclose(file);
-}
-
-size_t gguf_get_meta_size(const struct gguf_context * ctx) {
-    // no allocs - only compute size
-    struct gguf_buf buf = gguf_buf_init(0);
-
-    gguf_write_to_buf(ctx, &buf, true);
-
-    return buf.offset;
-}
-
-void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
-    struct gguf_buf buf = gguf_buf_init(16*1024);
-
-    gguf_write_to_buf(ctx, &buf, true);
-
-    memcpy(data, buf.data, buf.offset);
-
-    gguf_buf_free(buf);
-}
-
 void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
     g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
     g_logger_state.log_callback_user_data = user_data;
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
new file mode 100644 (file)
index 0000000..655ed60
--- /dev/null
@@ -0,0 +1,1325 @@
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "gguf.h"
+
+#include <cinttypes>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <map>
+#include <new>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+template <typename T>
+struct type_to_gguf_type;
+
+template <>
+struct type_to_gguf_type<uint8_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT8;
+};
+
+template <>
+struct type_to_gguf_type<int8_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT8;
+};
+
+template <>
+struct type_to_gguf_type<uint16_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT16;
+};
+
+template <>
+struct type_to_gguf_type<int16_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT16;
+};
+
+template <>
+struct type_to_gguf_type<uint32_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT32;
+};
+
+template <>
+struct type_to_gguf_type<int32_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT32;
+};
+
+template <>
+struct type_to_gguf_type<float> {
+    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT32;
+};
+
+template <>
+struct type_to_gguf_type<bool> {
+    static constexpr enum gguf_type value = GGUF_TYPE_BOOL;
+};
+
+template <>
+struct type_to_gguf_type<std::string> {
+    static constexpr enum gguf_type value = GGUF_TYPE_STRING;
+};
+
+template <>
+struct type_to_gguf_type<uint64_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT64;
+};
+
+template <>
+struct type_to_gguf_type<int64_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT64;
+};
+
+template <>
+struct type_to_gguf_type<double> {
+    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT64;
+};
+
+static const std::map<gguf_type, size_t> GGUF_TYPE_SIZE = {
+    {GGUF_TYPE_UINT8,   sizeof(uint8_t)},
+    {GGUF_TYPE_INT8,    sizeof(int8_t)},
+    {GGUF_TYPE_UINT16,  sizeof(uint16_t)},
+    {GGUF_TYPE_INT16,   sizeof(int16_t)},
+    {GGUF_TYPE_UINT32,  sizeof(uint32_t)},
+    {GGUF_TYPE_INT32,   sizeof(int32_t)},
+    {GGUF_TYPE_FLOAT32, sizeof(float)},
+    {GGUF_TYPE_BOOL,    sizeof(int8_t)},
+    {GGUF_TYPE_STRING,  0}, // undefined
+    {GGUF_TYPE_ARRAY,   0}, // undefined
+    {GGUF_TYPE_UINT64,  sizeof(uint64_t)},
+    {GGUF_TYPE_INT64,   sizeof(int64_t)},
+    {GGUF_TYPE_FLOAT64, sizeof(double)},
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+static const std::map<gguf_type, const char *> GGUF_TYPE_NAME = {
+    {GGUF_TYPE_UINT8,   "u8"},
+    {GGUF_TYPE_INT8,    "i8"},
+    {GGUF_TYPE_UINT16,  "u16"},
+    {GGUF_TYPE_INT16,   "i16"},
+    {GGUF_TYPE_UINT32,  "u32"},
+    {GGUF_TYPE_INT32,   "i32"},
+    {GGUF_TYPE_FLOAT32, "f32"},
+    {GGUF_TYPE_BOOL,    "bool"},
+    {GGUF_TYPE_STRING,  "str"},
+    {GGUF_TYPE_ARRAY,   "arr"},
+    {GGUF_TYPE_UINT64,  "u64"},
+    {GGUF_TYPE_INT64,   "i64"},
+    {GGUF_TYPE_FLOAT64, "f64"},
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+size_t gguf_type_size(enum gguf_type type) {
+    auto it = GGUF_TYPE_SIZE.find(type);
+    return it == GGUF_TYPE_SIZE.end() ? 0 : it->second;
+}
+
+struct gguf_kv {
+    std::string key;
+
+    bool is_array;
+    enum gguf_type type;
+
+    std::vector<int8_t>      data;
+    std::vector<std::string> data_string;
+
+    template <typename T>
+    gguf_kv(const std::string & key, const T value)
+            : key(key), is_array(false), type(type_to_gguf_type<T>::value) {
+        GGML_ASSERT(!key.empty());
+        data.resize(sizeof(T));
+        memcpy(data.data(), &value, sizeof(T));
+    }
+
+    template <typename T>
+    gguf_kv(const std::string & key, const std::vector<T> & value)
+            : key(key), is_array(true), type(type_to_gguf_type<T>::value) {
+        GGML_ASSERT(!key.empty());
+        data.resize(value.size()*sizeof(T));
+        for (size_t i = 0; i < value.size(); ++i) {
+            const T tmp = value[i];
+            memcpy(data.data() + i*sizeof(T), &tmp, sizeof(T));
+        }
+    }
+
+    gguf_kv(const std::string & key, const std::string & value)
+            : key(key), is_array(false), type(GGUF_TYPE_STRING) {
+        GGML_ASSERT(!key.empty());
+        data_string.push_back(value);
+    }
+
+    gguf_kv(const std::string & key, const std::vector<std::string> & value)
+            : key(key), is_array(true), type(GGUF_TYPE_STRING) {
+        GGML_ASSERT(!key.empty());
+        data_string = value;
+    }
+
+    const std::string & get_key() const {
+        return key;
+    }
+
+    const enum gguf_type & get_type() const {
+        return type;
+    }
+
+    size_t get_ne() const {
+        if (type == GGUF_TYPE_STRING) {
+            const size_t ne = data_string.size();
+            GGML_ASSERT(is_array || ne == 1);
+            return ne;
+        }
+        const size_t type_size = gguf_type_size(type);
+        GGML_ASSERT(data.size() % type_size == 0);
+        const size_t ne = data.size() / type_size;
+        GGML_ASSERT(is_array || ne == 1);
+        return ne;
+    }
+
+    template <typename T>
+    const T & get_val(const size_t i = 0) const {
+        GGML_ASSERT(type_to_gguf_type<T>::value == type);
+        if constexpr (std::is_same<T, std::string>::value) {
+            GGML_ASSERT(data_string.size() >= i+1);
+            return data_string[i];
+        }
+        const size_t type_size = gguf_type_size(type);
+        GGML_ASSERT(data.size() % type_size == 0);
+        GGML_ASSERT(data.size() >= (i+1)*type_size);
+        return reinterpret_cast<const T *>(data.data())[i];
+    }
+
+    void cast(const enum gguf_type new_type) {
+        const size_t new_type_size = gguf_type_size(new_type);
+        GGML_ASSERT(data.size() % new_type_size == 0);
+        type = new_type;
+    }
+};
+
+struct gguf_tensor_info {
+    struct ggml_tensor t; // for holding the equivalent info
+    uint64_t offset;      // offset from start of `data`, must be a multiple of `ALIGNMENT`
+};
+
+struct gguf_context {
+    uint32_t version = GGUF_VERSION;
+
+    std::vector<struct gguf_kv> kv;
+    std::vector<struct gguf_tensor_info> info;
+
+    size_t alignment = GGUF_DEFAULT_ALIGNMENT;
+    size_t offset    = 0; // offset of `data` from beginning of file
+    size_t size      = 0; // size of `data` in bytes
+
+    void * data = nullptr;
+};
+
+struct gguf_reader {
+    FILE * file;
+
+    gguf_reader(FILE * file) : file(file) {}
+
+    template <typename T>
+    bool read(T & dst) const {
+        return fread(&dst, 1, sizeof(dst), file) == sizeof(dst);
+    }
+
+    template <typename T>
+    bool read(std::vector<T> & dst, const size_t n) const {
+        dst.resize(n);
+        for (size_t i = 0; i < dst.size(); ++i) {
+            if constexpr (std::is_same<T, bool>::value) {
+                bool tmp;
+                if (!read(tmp)) {
+                    return false;
+                }
+                dst[i] = tmp;
+            } else {
+                if (!read(dst[i])) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    bool read(bool & dst) const {
+        int8_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = tmp != 0;
+        return true;
+    }
+
+    bool read(enum ggml_type & dst) const {
+        int32_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = ggml_type(tmp);
+        return true;
+    }
+
+    bool read(enum gguf_type & dst) const {
+        int32_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = gguf_type(tmp);
+        return true;
+    }
+
+    bool read(std::string & dst) const {
+        uint64_t size = -1;
+        if (!read(size)) {
+            return false;
+        }
+        dst.resize(size);
+        return fread(dst.data(), 1, dst.length(), file) == dst.length();
+    }
+
+    bool read(void * dst, const size_t size) const {
+        return fread(dst, 1, size, file) == size;
+    }
+};
+
+struct gguf_context * gguf_init_empty(void) {
+    return new gguf_context;
+}
+
+template<typename T>
+bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct gguf_kv> & kv, const std::string & key, const bool is_array, const size_t n) {
+    if (is_array) {
+        std::vector<T> value;
+        try {
+            if (!gr.read(value, n)) {
+                return false;
+            }
+        } catch (std::length_error &) {
+            fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
+            return false;
+        } catch (std::bad_alloc &) {
+            fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
+            return false;
+        }
+        kv.emplace_back(key, value);
+    } else {
+        T value;
+        if (!gr.read(value)) {
+            return false;
+        }
+        kv.emplace_back(key, value);
+    }
+    return true;
+}
+
+struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
+    const struct gguf_reader gr(file);
+    struct gguf_context * ctx = new gguf_context;
+
+    bool ok = true;
+
+    // file magic
+    {
+        std::vector<char> magic;
+        ok = ok && gr.read(magic, 4);
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read magic\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        for (uint32_t i = 0; i < magic.size(); i++) {
+            if (magic[i] != GGUF_MAGIC[i]) {
+                fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
+                gguf_free(ctx);
+                return nullptr;
+            }
+        }
+    }
+
+    // header
+    int64_t n_kv      = 0;
+    int64_t n_tensors = 0;
+
+    if (ok && gr.read(ctx->version)) {
+        if (ctx->version == 1) {
+            fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
+            ok = false;
+        }
+        if (ctx->version > GGUF_VERSION) {
+            fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
+                __func__, ctx->version, GGUF_VERSION);
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (ok && gr.read(n_tensors)) {
+        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
+        if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
+            fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
+                __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (ok && gr.read(n_kv)) {
+        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
+        if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
+            fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
+                    __func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (!ok) {
+        fprintf(stderr, "%s: failed to read header\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+
+    // KV pairs
+    {
+        for (int64_t i = 0; ok && i < n_kv; ++i) {
+            std::string key;
+            gguf_type   type     = gguf_type(-1);
+            bool        is_array = false;
+            uint64_t    n        = 1;
+
+            try {
+                ok = ok && gr.read(key);
+            } catch (std::length_error &) {
+                fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
+                ok = false;
+            } catch (std::bad_alloc &) {
+                fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
+                ok = false;
+            }
+            for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
+                if (key == ctx->kv[j].key) {
+                    fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
+                    ok = false;
+                }
+            }
+            if (!ok) {
+                break;
+            }
+
+            ok = ok && gr.read(type);
+            if (type == GGUF_TYPE_ARRAY) {
+                is_array = true;
+                ok = ok && gr.read(type);
+                ok = ok && gr.read(n);
+            }
+            if (!ok) {
+                break;
+            }
+
+            switch (type) {
+                case GGUF_TYPE_UINT8:   ok = ok && gguf_read_emplace_helper<uint8_t>    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT8:    ok = ok && gguf_read_emplace_helper<int8_t>     (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT16:  ok = ok && gguf_read_emplace_helper<uint16_t>   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT16:   ok = ok && gguf_read_emplace_helper<int16_t>    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT32:  ok = ok && gguf_read_emplace_helper<uint32_t>   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT32:   ok = ok && gguf_read_emplace_helper<int32_t>    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_FLOAT32: ok = ok && gguf_read_emplace_helper<float>      (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_BOOL:    ok = ok && gguf_read_emplace_helper<bool>       (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_STRING:  ok = ok && gguf_read_emplace_helper<std::string>(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT64:  ok = ok && gguf_read_emplace_helper<uint64_t>   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT64:   ok = ok && gguf_read_emplace_helper<int64_t>    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_FLOAT64: ok = ok && gguf_read_emplace_helper<double>     (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_ARRAY:
+                default:
+                    {
+                        fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
+                        ok = false;
+                    } break;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+        GGML_ASSERT(int64_t(ctx->kv.size()) == n_kv);
+
+        const int alignment_idx = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT);
+        ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
+
+        if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
+            fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
+            gguf_free(ctx);
+            return nullptr;
+        }
+    }
+
+    // read the tensor info
+    for (int64_t i = 0; ok && i < n_tensors; ++i) {
+        struct gguf_tensor_info info;
+
+        // tensor name
+        {
+            std::string name;
+            try {
+                ok = ok && gr.read(name);
+            } catch (std::length_error &) {
+                fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
+                ok = false;
+            } catch (std::bad_alloc &) {
+                fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
+                ok = false;
+            }
+            if (name.length() >= GGML_MAX_NAME) {
+                fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
+                ok = false;
+                break;
+            }
+            ggml_set_name(&info.t, name.c_str());
+
+            // make sure there are no duplicate tensor names
+            for (int64_t j = 0; ok && j < i; ++j) {
+                if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
+                    fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
+                    ok = false;
+                    break;
+                }
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor shape
+        {
+            uint32_t n_dims = -1;
+            ok = ok && gr.read(n_dims);
+            if (n_dims > GGML_MAX_DIMS) {
+                fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
+                    __func__, info.t.name, n_dims, GGML_MAX_DIMS);
+                ok = false;
+                break;
+            }
+            for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) {
+                info.t.ne[j] = 1;
+                if (j < n_dims) {
+                    ok = ok && gr.read(info.t.ne[j]);
+                }
+
+                // check that all ne are non-negative
+                if (info.t.ne[j] < 0) {
+                    fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
+                        __func__, info.t.name, j, info.t.ne[j]);
+                    ok = false;
+                    break;
+                }
+            }
+
+            // check that the total number of elements is representable
+            if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) ||
+                       (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
+                       (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
+
+                fprintf(stderr, "%s: total number of elements in tensor '%s' with shape "
+                    "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
+                    __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
+                ok = false;
+                break;
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor type
+        {
+            ok = ok && gr.read(info.t.type);
+
+            // check that tensor type is within defined range
+            if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
+                fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n",
+                    __func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
+                ok = false;
+                break;
+            }
+            const size_t  type_size = ggml_type_size(info.t.type);
+            const int64_t blck_size = ggml_blck_size(info.t.type);
+
+            // check that row size is divisible by block size
+            if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
+                fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
+                    "not a multiple of block size (%" PRId64 ")\n",
+                    __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
+                ok = false;
+                break;
+            }
+
+            // calculate byte offsets given the tensor shape and type
+            info.t.nb[0] = type_size;
+            info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size);
+            for (int j = 2; j < GGML_MAX_DIMS; ++j) {
+                info.t.nb[j] = info.t.nb[j - 1]*info.t.ne[j - 1];
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor data offset within buffer
+        ok = ok && gr.read(info.offset);
+
+        ctx->info.push_back(info);
+    }
+
+    if (!ok) {
+        fprintf(stderr, "%s: failed to read tensor info\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+    GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
+
+    // we require the data section to be aligned, so take into account any padding
+    if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
+        fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+
+    // store the current file offset - this is where the data section starts
+    ctx->offset = ftell(file);
+
+    // compute the total size of the data section, taking into account the alignment
+    {
+        ctx->size = 0;
+        for (size_t i = 0; i < ctx->info.size(); ++i) {
+            const gguf_tensor_info & ti = ctx->info[i];
+            if (ti.offset != ctx->size) {
+                fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
+                    __func__, ti.t.name, ti.offset, ctx->size);
+                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+        }
+    }
+
+    // load the tensor data only if requested
+    if (params.ctx != nullptr) {
+        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
+        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
+        //   the ggml_tensor structs to the appropriate locations in the binary blob
+
+        // compute the exact size needed for the new ggml_context
+        const size_t mem_size =
+            params.no_alloc ?
+            (n_tensors    )*ggml_tensor_overhead() :
+            (n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
+
+        struct ggml_init_params pdata = {
+            /*mem_size   =*/ mem_size,
+            /*mem_buffer =*/ nullptr,
+            /*no_alloc   =*/ params.no_alloc,
+        };
+
+        *params.ctx = ggml_init(pdata);
+        if (*params.ctx == nullptr) {
+            fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        struct ggml_context * ctx_data = *params.ctx;
+
+        struct ggml_tensor * data = nullptr;
+
+        if (!params.no_alloc) {
+            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
+
+            ok = ok && data != nullptr;
+
+            // read the binary blob with the tensor data
+            ok = ok && gr.read(data->data, ctx->size);
+
+            if (!ok) {
+                fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__);
+                ggml_free(ctx_data);
+                *params.ctx = nullptr;
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            ctx->data = data->data;
+        }
+
+        ggml_set_no_alloc(ctx_data, true);
+
+        // create the tensors
+        for (size_t i = 0; i < ctx->info.size(); ++i) {
+            const struct gguf_tensor_info & info = ctx->info[i];
+
+            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, info.t.type, GGML_MAX_DIMS, info.t.ne);
+
+            ok = ok && cur != nullptr;
+
+            if (!ok) {
+                break;
+            }
+
+            ggml_set_name(cur, info.t.name);
+
+            // point the data member to the appropriate location in the binary blob using the tensor info
+            if (!params.no_alloc) {
+                cur->data = (char *) data->data + info.offset;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to create tensors\n", __func__);
+            ggml_free(ctx_data);
+            *params.ctx = nullptr;
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        ggml_set_no_alloc(ctx_data, params.no_alloc);
+    }
+
+    return ctx;
+}
+
+struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
+    FILE * file = ggml_fopen(fname, "rb");
+
+    if (!file) {
+        fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname);
+        return nullptr;
+    }
+
+    struct gguf_context * result = gguf_init_from_file_impl(file, params);
+    fclose(file);
+    return result;
+}
+
+void gguf_free(struct gguf_context * ctx) {
+    if (ctx == nullptr) {
+        return;
+    }
+    delete ctx;
+}
+
+const char * gguf_type_name(enum gguf_type type) {
+    auto it = GGUF_TYPE_NAME.find(type);
+    return it == GGUF_TYPE_NAME.end() ? nullptr : it->second;
+}
+
+uint32_t gguf_get_version(const struct gguf_context * ctx) {
+    return ctx->version;
+}
+
+size_t gguf_get_alignment(const struct gguf_context * ctx) {
+    return ctx->alignment;
+}
+
+size_t gguf_get_data_offset(const struct gguf_context * ctx) {
+    return ctx->offset;
+}
+
+int64_t gguf_get_n_kv(const struct gguf_context * ctx) {
+    return ctx->kv.size();
+}
+
+int64_t gguf_find_key(const struct gguf_context * ctx, const char * key) {
+    // return -1 if key not found
+    int64_t keyfound = -1;
+
+    const int64_t n_kv = gguf_get_n_kv(ctx);
+
+    for (int64_t i = 0; i < n_kv; ++i) {
+        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
+            keyfound = i;
+            break;
+        }
+    }
+
+    return keyfound;
+}
+
+const char * gguf_get_key(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].get_key().c_str();
+}
+
+enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].is_array ? GGUF_TYPE_ARRAY : ctx->kv[key_id].get_type();
+}
+
+enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].is_array);
+    return ctx->kv[key_id].get_type();
+}
+
+const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data.data();
+}
+
+const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data_string[i].c_str();
+}
+
+size_t gguf_get_arr_n(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+
+    if (ctx->kv[key_id].type == GGUF_TYPE_STRING) {
+        return ctx->kv[key_id].data_string.size();
+    }
+
+    const size_t type_size = gguf_type_size(ctx->kv[key_id].type);
+    GGML_ASSERT(ctx->kv[key_id].data.size() % type_size == 0);
+    return ctx->kv[key_id].data.size() / type_size;
+}
+
+uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint8_t>();
+}
+
+int8_t gguf_get_val_i8(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int8_t>();
+}
+
+uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint16_t>();
+}
+
+int16_t gguf_get_val_i16(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int16_t>();
+}
+
+uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint32_t>();
+}
+
+int32_t gguf_get_val_i32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int32_t>();
+}
+
+float gguf_get_val_f32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<float>();
+}
+
+uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint64_t>();
+}
+
+int64_t gguf_get_val_i64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int64_t>();
+}
+
+double gguf_get_val_f64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<double>();
+}
+
+bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<bool>();
+}
+
+const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<std::string>().c_str();
+}
+
+const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data.data();
+}
+
+int64_t gguf_get_n_tensors(const struct gguf_context * ctx) {
+    return ctx->info.size();
+}
+
+int64_t gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
+    // return -1 if tensor not found
+    int64_t tensor_id = -1;
+
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
+            tensor_id = i;
+            break;
+        }
+    }
+
+    return tensor_id;
+}
+
+size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].offset;
+}
+
+const char * gguf_get_tensor_name(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].t.name;
+}
+
+enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].t.type;
+}
+
+size_t gguf_get_tensor_size(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ggml_nbytes(&ctx->info[tensor_id].t);
+}
+
+int64_t gguf_remove_key(struct gguf_context * ctx, const char * key) {
+    const int64_t key_id = gguf_find_key(ctx, key);
+    if (key_id >= 0) {
+        ctx->kv.erase(ctx->kv.begin() + key_id);
+    }
+    return key_id;
+}
+
+template<typename T>
+static void gguf_check_reserved_keys(const std::string & key, const T val) {
+    if (key == GGUF_KEY_GENERAL_ALIGNMENT) {
+        if constexpr (std::is_same<T, uint32_t>::value) {
+            GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2");
+        } else {
+            GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32");
+        }
+    }
+}
+
+void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, std::string(val));
+}
+
+void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n) {
+    gguf_check_reserved_keys(key, data);
+    gguf_remove_key(ctx, key);
+
+    const size_t nbytes = n*gguf_type_size(type);
+    std::vector<int8_t> tmp(nbytes);
+    if (!tmp.empty()) {
+        memcpy(tmp.data(), data, nbytes);
+    }
+    ctx->kv.emplace_back(key, tmp);
+    ctx->kv.back().cast(type);
+}
+
+void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, size_t n) {
+    gguf_check_reserved_keys(key, data);
+    gguf_remove_key(ctx, key);
+
+    std::vector<std::string> tmp(n);
+    for (size_t i = 0; i < n; ++i) {
+        tmp[i] = data[i];
+    }
+    ctx->kv.emplace_back(key, tmp);
+}
+
+// set or add KV pairs from another context
+void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src) {
+    const int64_t n_kv = gguf_get_n_kv(src);
+    for (int64_t i = 0; i < n_kv; ++i) {
+        const struct gguf_kv & kv = src->kv[i];
+
+        if (!kv.is_array) {
+            switch (kv.get_type()) {
+                case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, kv.get_key().c_str(), kv.get_val<uint8_t>());             break;
+                case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, kv.get_key().c_str(), kv.get_val<int8_t>());              break;
+                case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, kv.get_key().c_str(), kv.get_val<uint16_t>());            break;
+                case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, kv.get_key().c_str(), kv.get_val<int16_t>());             break;
+                case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, kv.get_key().c_str(), kv.get_val<uint32_t>());            break;
+                case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, kv.get_key().c_str(), kv.get_val<int32_t>());             break;
+                case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, kv.get_key().c_str(), kv.get_val<float>());               break;
+                case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, kv.get_key().c_str(), kv.get_val<uint64_t>());            break;
+                case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, kv.get_key().c_str(), kv.get_val<int64_t>());             break;
+                case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, kv.get_key().c_str(), kv.get_val<double>());              break;
+                case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, kv.get_key().c_str(), kv.get_val<bool>());                break;
+                case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, kv.get_key().c_str(), kv.get_val<std::string>().c_str()); break;
+                case GGUF_TYPE_ARRAY:
+                default: GGML_ABORT("invalid type");
+            }
+            continue;
+        }
+
+        const size_t ne = kv.get_ne();
+
+        switch (kv.get_type()) {
+            case GGUF_TYPE_UINT8:
+            case GGUF_TYPE_INT8:
+            case GGUF_TYPE_UINT16:
+            case GGUF_TYPE_INT16:
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:
+            case GGUF_TYPE_FLOAT32:
+            case GGUF_TYPE_UINT64:
+            case GGUF_TYPE_INT64:
+            case GGUF_TYPE_FLOAT64:
+            case GGUF_TYPE_BOOL: {
+                gguf_set_arr_data(ctx, kv.get_key().c_str(), kv.get_type(), kv.data.data(), ne);
+            } break;
+            case GGUF_TYPE_STRING: {
+                std::vector<const char *> tmp(ne);
+                for (size_t j = 0; j < ne; ++j) {
+                    tmp[j] = kv.data_string[j].c_str();
+                }
+                gguf_set_arr_str(ctx, kv.get_key().c_str(), tmp.data(), ne);
+            } break;
+            case GGUF_TYPE_ARRAY:
+            default: GGML_ABORT("invalid type");
+        }
+    }
+}
+
+void gguf_add_tensor(
+             struct gguf_context * ctx,
+        const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+    if (gguf_find_tensor(ctx, tensor->name) != -1) {
+        GGML_ABORT("duplicate tensor name: %s", tensor->name);
+    }
+
+    struct gguf_tensor_info ti;
+    ti.t = *tensor;
+    ti.offset = ctx->info.empty() ? 0 :
+        ctx->info.back().offset + GGML_PAD(ggml_nbytes(&ctx->info.back().t), ctx->alignment);
+    ctx->info.push_back(ti);
+}
+
+void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
+    const int64_t tensor_id = gguf_find_tensor(ctx, name);
+    if (tensor_id < 0) {
+        GGML_ABORT("tensor not found: %s", name);
+    }
+    struct ggml_tensor * tensor = &ctx->info[tensor_id].t;
+    const size_t  type_size = ggml_type_size(type);
+    const int64_t blck_size = ggml_blck_size(type);
+
+    tensor->type = type;
+    GGML_ASSERT(tensor->ne[0] % blck_size == 0 && "tensor row size not divisible by block size of new type");
+
+    tensor->nb[0] = type_size;
+    tensor->nb[1] = tensor->nb[0]*(tensor->ne[0]/blck_size);
+    for (int i = 2; i < GGML_MAX_DIMS; i++) {
+        tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1];
+    }
+
+    // update offsets
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+    for (int64_t i = tensor_id + 1; i < n_tensors; ++i) {
+        ctx->info[i].offset = ctx->info[i - 1].offset + GGML_PAD(ggml_nbytes(&ctx->info[i - 1].t), ctx->alignment);
+    }
+}
+
+void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data) {
+    const int64_t tensor_id = gguf_find_tensor(ctx, name);
+    if (tensor_id < 0) {
+        GGML_ABORT("tensor not found: %s", name);
+    }
+
+    ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const
+}
+
+struct gguf_writer {
+    std::vector<int8_t> & buf;
+
+    gguf_writer(std::vector<int8_t> & buf) : buf(buf) {}
+
+    template <typename T>
+    void write(const T & val) const {
+        for (size_t i = 0; i < sizeof(val); ++i) {
+            buf.push_back(reinterpret_cast<const int8_t *>(&val)[i]);
+        }
+    }
+
+    void write(const std::vector<int8_t> & val) const {
+        buf.insert(buf.end(), val.begin(), val.end());
+    }
+
+    void write(const bool & val) const {
+        const int8_t val8 = val ? 1 : 0;
+        write(val8);
+    }
+
+    void write(const std::string & val) const {
+        {
+            const uint64_t n = val.length();
+            write(n);
+        }
+        for (size_t i = 0; i < val.length(); ++i) {
+            buf.push_back(reinterpret_cast<const int8_t *>(val.data())[i]);
+        }
+    }
+
+    void write(const char * val) const {
+        write(std::string(val));
+    }
+
+    void write(const enum ggml_type & val) const {
+        write(int32_t(val));
+    }
+
+    void write(const enum gguf_type & val) const {
+        write(int32_t(val));
+    }
+
+    void write(const struct gguf_kv & kv) const {
+        const uint64_t ne = kv.get_ne();
+
+        write(kv.get_key());
+
+        if (kv.is_array) {
+            write(GGUF_TYPE_ARRAY);
+            write(kv.get_type());
+            write(ne);
+        } else {
+            write(kv.get_type());
+        }
+
+        switch (kv.get_type()) {
+            case GGUF_TYPE_UINT8:
+            case GGUF_TYPE_INT8:
+            case GGUF_TYPE_UINT16:
+            case GGUF_TYPE_INT16:
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:
+            case GGUF_TYPE_FLOAT32:
+            case GGUF_TYPE_UINT64:
+            case GGUF_TYPE_INT64:
+            case GGUF_TYPE_FLOAT64: {
+                write(kv.data);
+            } break;
+            case GGUF_TYPE_BOOL: {
+                for (size_t i = 0; i < ne; ++i) {
+                    write(kv.get_val<bool>(i));
+                }
+            } break;
+            case GGUF_TYPE_STRING: {
+                for (size_t i = 0; i < ne; ++i) {
+                    write(kv.get_val<std::string>(i));
+                }
+            } break;
+            case GGUF_TYPE_ARRAY:
+            default: GGML_ABORT("invalid type");
+        }
+    }
+
+    void write_tensor_meta(const struct gguf_tensor_info & info) const {
+        write(info.t.name);
+
+        const uint32_t n_dims = ggml_n_dims(&info.t);
+        write(n_dims);
+
+        for (uint32_t j = 0; j < n_dims; ++j) {
+            write(info.t.ne[j]);
+        }
+        write(info.t.type);
+        write(info.offset);
+    }
+
+    void pad(const size_t alignment) const {
+        while (buf.size() % alignment != 0) {
+            const int8_t zero = 0;
+            write(zero);
+        }
+    }
+
+    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const {
+        GGML_ASSERT(buf.size() - offset_data == info.offset);
+
+        GGML_ASSERT(ggml_is_contiguous(&info.t));
+        const size_t offset = buf.size();
+        const size_t nbytes = ggml_nbytes(&info.t);
+
+        buf.resize(offset + nbytes);
+        if (info.t.buffer) {
+            ggml_backend_tensor_get(&info.t, buf.data() + offset, 0, nbytes);
+        } else {
+            GGML_ASSERT(info.t.data);
+            memcpy(buf.data() + offset, info.t.data, nbytes);
+        }
+
+        pad(alignment);
+    }
+};
+
+void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
+    const struct gguf_writer gw(buf);
+
+    const int64_t n_kv      = gguf_get_n_kv(ctx);
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+
+    // write header
+    gw.write(GGUF_MAGIC[0]);
+    gw.write(GGUF_MAGIC[1]);
+    gw.write(GGUF_MAGIC[2]);
+    gw.write(GGUF_MAGIC[3]);
+    gw.write(ctx->version);
+    gw.write(n_tensors);
+    gw.write(n_kv);
+
+    // write key-value pairs
+    for (int64_t i = 0; i < n_kv; ++i) {
+        gw.write(ctx->kv[i]);
+    }
+
+    // write tensor info
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        gw.write_tensor_meta(ctx->info[i]);
+    }
+
+    // we require the data section to be aligned
+    gw.pad(ctx->alignment);
+
+    if (only_meta) {
+        return;
+    }
+
+    const size_t offset_data = gw.buf.size();
+
+    // write tensor data
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        gw.write_tensor_data(ctx->info[i], offset_data, ctx->alignment);
+    }
+}
+
+bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
+    FILE * file = ggml_fopen(fname, "wb");
+
+    if (!file) {
+        fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
+        return false;
+    }
+
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, only_meta);
+    const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size();
+    fclose(file);
+    return ok;
+}
+
+size_t gguf_get_meta_size(const struct gguf_context * ctx) {
+    // only return size
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);
+    return buf.size();
+}
+
+void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);
+    memcpy(data, buf.data(), buf.size());
+}
index a05ba4f635c0109539b5a0207299fa4155189a05..6ec709dd323a6d5319d67d64f607c0acc85f3267 100644 (file)
@@ -1,5 +1,6 @@
 #include "llama-impl.h"
 
+#include "gguf.h"
 #include "llama.h"
 
 #include <cinttypes>
@@ -138,7 +139,7 @@ std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
             {
                 const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
                 int arr_n = gguf_get_arr_n(ctx_gguf, i);
-                const void * data = gguf_get_arr_data(ctx_gguf, i);
+                const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i);
                 std::stringstream ss;
                 ss << "[";
                 for (int j = 0; j < arr_n; j++) {
index 7743b46522ce58e5344683c90327f1f492bdcb8d..1c4e30878a1382945493de3220589b0d837b59c3 100644 (file)
@@ -18,7 +18,7 @@ const char * llama_file_version_name(llama_fver version) {
 }
 
 namespace GGUFMeta {
-    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
+    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)>
     struct GKV_Base_Type {
         static constexpr gguf_type gt = gt_;
 
@@ -60,10 +60,11 @@ namespace GGUFMeta {
         public:
         static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
         static ArrayInfo getter(const gguf_context *ctx, const int k) {
+            const enum gguf_type arr_type = gguf_get_arr_type(ctx, k);
             return ArrayInfo {
-                gguf_get_arr_type(ctx, k),
+                arr_type,
                 size_t(gguf_get_arr_n(ctx, k)),
-                gguf_get_arr_data(ctx, k),
+                arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k),
             };
         }
     };
@@ -553,7 +554,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
             const enum gguf_type type   = gguf_get_kv_type(meta.get(), i);
             const std::string type_name =
                 type == GGUF_TYPE_ARRAY
-                ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i))
+                ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i))
                 : gguf_type_name(type);
 
             std::string value          = gguf_kv_to_str(meta.get(), i);
index 104f90343a402edde2e5afccfcc9ceaa2dbf2a54..038cf58ddcce5d812b576b34c2a90d1e60e3b464 100644 (file)
@@ -875,7 +875,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 
         // update the gguf meta data as we go
         gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
-        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
+        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
+        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
 
         // write tensor data + padding
         fout.write((const char *) new_data, new_size);
index 1bb5fb47c4317006d86834eeb322a82c7ab1a7b3..611957ac0d13fc5e9d9a9e477216c93abd7a133f 100644 (file)
@@ -15,66 +15,71 @@ constexpr int offset_has_tensors = 2000;
 constexpr int offset_has_data    = 3000;
 
 enum handcrafted_file_type {
-    HANDCRAFTED_HEADER_BAD_MAGIC          =  10,
-    HANDCRAFTED_HEADER_BAD_VERSION_1      =  20,
-    HANDCRAFTED_HEADER_BAD_VERSION_FUTURE =  30,
-    HANDCRAFTED_HEADER_BAD_N_TENSORS      =  40,
-    HANDCRAFTED_HEADER_BAD_N_KV           =  50,
-    HANDCRAFTED_HEADER_EMPTY              = 800,
-
-    HANDCRAFTED_KV_BAD_KEY_SIZE           =  10 + offset_has_kv,
-    HANDCRAFTED_KV_BAD_TYPE               =  20 + offset_has_kv,
-    HANDCRAFTED_KV_BAD_VALUE_SIZE         =  30 + offset_has_kv,
-    HANDCRAFTED_KV_DUPLICATE_KEY          =  40 + offset_has_kv,
-    HANDCRAFTED_KV_SUCCESS                = 800 + offset_has_kv,
-
-    HANDCRAFTED_TENSORS_BAD_NAME_SIZE     =  10 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_N_DIMS        =  20 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_SHAPE         =  30 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_NE_TOO_BIG        =  40 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_TYPE          =  50 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_OFFSET        =  60 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_DUPLICATE_NAME    =  70 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_BAD_ALIGNMENT     =  80 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_SUCCESS           = 800 + offset_has_tensors,
-    HANDCRAFTED_TENSORS_CUSTOM_ALIGN      = 810 + offset_has_tensors,
-
-    HANDCRAFTED_DATA_NOT_ENOUGH_DATA      =  10 + offset_has_data,
-    HANDCRAFTED_DATA_BAD_ALIGNMENT        =  20 + offset_has_data,
-    HANDCRAFTED_DATA_SUCCESS              = 800 + offset_has_data,
-    HANDCRAFTED_DATA_CUSTOM_ALIGN         = 810 + offset_has_data,
+    HANDCRAFTED_HEADER_BAD_MAGIC           =  10,
+    HANDCRAFTED_HEADER_BAD_VERSION_1       =  20,
+    HANDCRAFTED_HEADER_BAD_VERSION_FUTURE  =  30,
+    HANDCRAFTED_HEADER_BAD_N_TENSORS       =  40,
+    HANDCRAFTED_HEADER_BAD_N_KV            =  50,
+    HANDCRAFTED_HEADER_EMPTY               = 800,
+
+    HANDCRAFTED_KV_BAD_KEY_SIZE            =  10 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_TYPE                =  20 + offset_has_kv,
+    // HANDCRAFTED_KV_BAD_VALUE_SIZE          =  30 + offset_has_kv, // removed because it can result in allocations > 1 TB (default sanitizer limit)
+    HANDCRAFTED_KV_DUPLICATE_KEY           =  40 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_ALIGN               =  50 + offset_has_kv,
+    HANDCRAFTED_KV_SUCCESS                 = 800 + offset_has_kv,
+
+    HANDCRAFTED_TENSORS_BAD_NAME_SIZE      =  10 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_N_DIMS         =  20 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_SHAPE          =  30 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_NE_TOO_BIG         =  40 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_TYPE           =  50 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_OFFSET         =  60 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_DUPLICATE_NAME     =  70 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_ALIGN          =  75 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN =  80 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_SUCCESS            = 800 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_CUSTOM_ALIGN       = 810 + offset_has_tensors,
+
+    HANDCRAFTED_DATA_NOT_ENOUGH_DATA       =  10 + offset_has_data,
+    HANDCRAFTED_DATA_BAD_ALIGN             =  15 + offset_has_data,
+    HANDCRAFTED_DATA_INCONSISTENT_ALIGN    =  20 + offset_has_data,
+    HANDCRAFTED_DATA_SUCCESS               = 800 + offset_has_data,
+    HANDCRAFTED_DATA_CUSTOM_ALIGN          = 810 + offset_has_data,
 };
 
 std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
     switch (hft) {
-        case HANDCRAFTED_HEADER_BAD_MAGIC:          return "HEADER_BAD_MAGIC";
-        case HANDCRAFTED_HEADER_BAD_VERSION_1:      return "HEADER_BAD_VERSION_1";
-        case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE: return "HEADER_BAD_VERSION_FUTURE";
-        case HANDCRAFTED_HEADER_BAD_N_KV:           return "HEADER_BAD_N_KV";
-        case HANDCRAFTED_HEADER_BAD_N_TENSORS:      return "HEADER_BAD_N_TENSORS";
-        case HANDCRAFTED_HEADER_EMPTY:              return "HEADER_EMPTY";
-
-        case HANDCRAFTED_KV_BAD_KEY_SIZE:           return "KV_BAD_KEY_SIZE";
-        case HANDCRAFTED_KV_BAD_TYPE:               return "KV_BAD_TYPE";
-        case HANDCRAFTED_KV_BAD_VALUE_SIZE:         return "KV_BAD_VALUE_SIZE";
-        case HANDCRAFTED_KV_DUPLICATE_KEY:          return "KV_DUPLICATE_KEY";
-        case HANDCRAFTED_KV_SUCCESS:                return "KV_RANDOM_KV";
-
-        case HANDCRAFTED_TENSORS_BAD_NAME_SIZE:     return "TENSORS_BAD_NAME_SIZE";
-        case HANDCRAFTED_TENSORS_BAD_N_DIMS:        return "TENSORS_BAD_N_DIMS";
-        case HANDCRAFTED_TENSORS_BAD_SHAPE:         return "TENSORS_BAD_SHAPE";
-        case HANDCRAFTED_TENSORS_NE_TOO_BIG:        return "TENSORS_NE_TOO_BIG";
-        case HANDCRAFTED_TENSORS_BAD_TYPE:          return "TENSORS_BAD_TYPE";
-        case HANDCRAFTED_TENSORS_BAD_OFFSET:        return "TENSORS_BAD_OFFSET";
-        case HANDCRAFTED_TENSORS_DUPLICATE_NAME:    return "TENSORS_DUPLICATE_NAME";
-        case HANDCRAFTED_TENSORS_BAD_ALIGNMENT:     return "TENSORS_BAD_ALIGNMENT";
-        case HANDCRAFTED_TENSORS_SUCCESS:           return "TENSORS_SUCCESS";
-        case HANDCRAFTED_TENSORS_CUSTOM_ALIGN:      return "TENSORS_CUSTOM_ALIGN";
-
-        case HANDCRAFTED_DATA_NOT_ENOUGH_DATA:      return "DATA_NOT_ENOUGH_DATA";
-        case HANDCRAFTED_DATA_BAD_ALIGNMENT:        return "DATA_BAD_ALIGNMENT";
-        case HANDCRAFTED_DATA_SUCCESS:              return "DATA_SUCCESS";
-        case HANDCRAFTED_DATA_CUSTOM_ALIGN:         return "DATA_CUSTOM_ALIGN";
+        case HANDCRAFTED_HEADER_BAD_MAGIC:           return "HEADER_BAD_MAGIC";
+        case HANDCRAFTED_HEADER_BAD_VERSION_1:       return "HEADER_BAD_VERSION_1";
+        case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE:  return "HEADER_BAD_VERSION_FUTURE";
+        case HANDCRAFTED_HEADER_BAD_N_KV:            return "HEADER_BAD_N_KV";
+        case HANDCRAFTED_HEADER_BAD_N_TENSORS:       return "HEADER_BAD_N_TENSORS";
+        case HANDCRAFTED_HEADER_EMPTY:               return "HEADER_EMPTY";
+
+        case HANDCRAFTED_KV_BAD_KEY_SIZE:            return "KV_BAD_KEY_SIZE";
+        case HANDCRAFTED_KV_BAD_TYPE:                return "KV_BAD_TYPE";
+        case HANDCRAFTED_KV_DUPLICATE_KEY:           return "KV_DUPLICATE_KEY";
+        case HANDCRAFTED_KV_BAD_ALIGN:               return "KV_BAD_ALIGN";
+        case HANDCRAFTED_KV_SUCCESS:                 return "KV_RANDOM_KV";
+
+        case HANDCRAFTED_TENSORS_BAD_NAME_SIZE:      return "TENSORS_BAD_NAME_SIZE";
+        case HANDCRAFTED_TENSORS_BAD_N_DIMS:         return "TENSORS_BAD_N_DIMS";
+        case HANDCRAFTED_TENSORS_BAD_SHAPE:          return "TENSORS_BAD_SHAPE";
+        case HANDCRAFTED_TENSORS_NE_TOO_BIG:         return "TENSORS_NE_TOO_BIG";
+        case HANDCRAFTED_TENSORS_BAD_TYPE:           return "TENSORS_BAD_TYPE";
+        case HANDCRAFTED_TENSORS_BAD_OFFSET:         return "TENSORS_BAD_OFFSET";
+        case HANDCRAFTED_TENSORS_DUPLICATE_NAME:     return "TENSORS_DUPLICATE_NAME";
+        case HANDCRAFTED_TENSORS_BAD_ALIGN:          return "TENSORS_BAD_ALIGN";
+        case HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN: return "TENSORS_INCONSISTENT_ALIGN";
+        case HANDCRAFTED_TENSORS_SUCCESS:            return "TENSORS_SUCCESS";
+        case HANDCRAFTED_TENSORS_CUSTOM_ALIGN:       return "TENSORS_CUSTOM_ALIGN";
+
+        case HANDCRAFTED_DATA_NOT_ENOUGH_DATA:       return "DATA_NOT_ENOUGH_DATA";
+        case HANDCRAFTED_DATA_BAD_ALIGN:             return "DATA_BAD_ALIGN";
+        case HANDCRAFTED_DATA_INCONSISTENT_ALIGN:    return "DATA_INCONSISTENT_ALIGN";
+        case HANDCRAFTED_DATA_SUCCESS:               return "DATA_SUCCESS";
+        case HANDCRAFTED_DATA_CUSTOM_ALIGN:          return "DATA_CUSTOM_ALIGN";
     }
     GGML_ABORT("fatal error");
 }
@@ -140,31 +145,41 @@ std::vector<std::pair<enum gguf_type, enum gguf_type>> get_kv_types(std::mt19937
     return kv_types;
 }
 
-static void helper_write(const void * data, const size_t nbytes, FILE * file) {
+template <typename T>
+static void helper_write(FILE * file, const T & val) {
+    GGML_ASSERT(fwrite(&val, 1, sizeof(val), file) == sizeof(val));
+}
+
+static void helper_write(FILE * file, const void * data, const size_t nbytes) {
     GGML_ASSERT(fwrite(data, 1, nbytes, file) == nbytes);
 }
 
 static FILE * get_handcrafted_file(const unsigned int seed, const enum handcrafted_file_type hft, const int extra_bytes = 0) {
     FILE * file = tmpfile();
 
+    if (!file) {
+        return file;
+    }
+
     std::mt19937 rng(seed);
+    uint32_t alignment = GGUF_DEFAULT_ALIGNMENT;
 
     if (hft == HANDCRAFTED_HEADER_BAD_MAGIC) {
         const char bad_magic[4] = {'F', 'U', 'G', 'G'};
-        helper_write(bad_magic, sizeof(bad_magic), file);
+        helper_write(file, bad_magic, sizeof(bad_magic));
     } else {
-        helper_write(GGUF_MAGIC, 4, file);
+        helper_write(file, GGUF_MAGIC, 4);
     }
 
     if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) {
         const uint32_t version = 1;
-        helper_write(&version, sizeof(version), file);
+        helper_write(file, version);
     } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_FUTURE) {
         const uint32_t version = GGUF_VERSION + 1;
-        helper_write(&version, sizeof(version), file);
+        helper_write(file, version);
     } else {
         const uint32_t version = GGUF_VERSION;
-        helper_write(&version, sizeof(version), file);
+        helper_write(file, version);
     }
 
     std::vector<tensor_config_t> tensor_configs;
@@ -174,10 +189,10 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
 
     if (hft == HANDCRAFTED_HEADER_BAD_N_TENSORS) {
         const uint64_t n_tensors = -1;
-        helper_write(&n_tensors, sizeof(n_tensors), file);
+        helper_write(file, n_tensors);
     } else {
         const uint64_t n_tensors = tensor_configs.size();
-        helper_write(&n_tensors, sizeof(n_tensors), file);
+        helper_write(file, n_tensors);
     }
 
     std::vector<std::pair<enum gguf_type, enum gguf_type>> kv_types;
@@ -186,41 +201,49 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
     }
     {
         uint64_t n_kv = kv_types.size();
-        if (hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+        if (hft == HANDCRAFTED_KV_BAD_ALIGN      ||
+            hft == HANDCRAFTED_TENSORS_BAD_ALIGN || hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN ||
+            hft == HANDCRAFTED_DATA_BAD_ALIGN    || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+
             n_kv += 1;
         } else if (hft == HANDCRAFTED_HEADER_BAD_N_KV) {
             n_kv = -1;
         }
-        helper_write(&n_kv, sizeof(n_kv), file);
+        helper_write(file, n_kv);
     }
 
     if (hft < offset_has_kv) {
+        while (ftell(file) % alignment != 0) {
+            const char pad = 0;
+            helper_write(file, pad);
+        }
+
         for (int i = 0; i < extra_bytes; ++i) {
             const char tmp = 0;
-            helper_write(&tmp, sizeof(tmp), file);
+            helper_write(file, tmp);
         }
         rewind(file);
         return file;
     }
 
     for (int i = 0; i < int(kv_types.size()); ++i) {
-        const enum gguf_type type     = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? -1 : kv_types[i].first);
-        const enum gguf_type type_arr = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? -1 : kv_types[i].second);
+        const enum gguf_type type     = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? GGUF_TYPE_COUNT : kv_types[i].first);
+        const enum gguf_type type_arr = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? GGUF_TYPE_COUNT : kv_types[i].second);
 
         const std::string key = "my_key_" + std::to_string((hft == HANDCRAFTED_KV_DUPLICATE_KEY ? i/2 : i));
 
         if (hft == HANDCRAFTED_KV_BAD_KEY_SIZE) {
             const uint64_t n = -1;
-            helper_write(&n, sizeof(n), file);
+            helper_write(file, n);
         } else {
             const uint64_t n = key.length();
-            helper_write(&n, sizeof(n), file);
+            helper_write(file, n);
         }
-        helper_write(key.data(), key.length(), file);
+        helper_write(file, key.data(), key.length());
 
         {
             const int32_t type32 = int32_t(type);
-            helper_write(&type32, sizeof(type32), file);
+            helper_write(file, type32);
         }
 
         uint32_t data[16];
@@ -233,69 +256,67 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
 
         if (type == GGUF_TYPE_STRING) {
             const uint64_t n = rng() % sizeof(data);
-            helper_write(&n,   sizeof(n), file);
-            helper_write(data,        n,  file);
+            helper_write(file, n);
+            helper_write(file, data, n);
             continue;
         }
 
         if (type == GGUF_TYPE_ARRAY) {
             {
                 const int32_t type32 = int32_t(type_arr);
-                helper_write(&type32, sizeof(type32), file);
+                helper_write(file, type32);
             }
             if (type_arr == GGUF_TYPE_STRING) {
                 const uint64_t nstr = rng() % (16 + 1);
-                helper_write(&nstr, sizeof(nstr), file);
+                helper_write(file, nstr);
                 for (uint64_t istr = 0; istr < nstr; ++istr) {
                     const uint64_t n = rng() % (sizeof(uint32_t) + 1);
-                    helper_write(&n,          sizeof(n), file);
-                    helper_write(&data[istr],        n,  file);
+                    helper_write(file, n);
+                    helper_write(file, &data[istr], n);
                 }
                 continue;
             }
             const size_t type_size = gguf_type_size(type_arr);
             const uint64_t n = (rng() % sizeof(data)) / type_size;
-            helper_write(&n,    sizeof(n),   file);
-            helper_write(&data, n*type_size, file);
+            helper_write(file, n);
+            helper_write(file, &data, n*type_size);
             continue;
         }
 
-        size_t type_size = hft == HANDCRAFTED_KV_BAD_TYPE ? 1 : gguf_type_size(type);
-        if (hft == HANDCRAFTED_KV_BAD_VALUE_SIZE) {
-            type_size += rng() % 3;
-        }
-        helper_write(data, type_size, file);
+        helper_write(file, data, hft == HANDCRAFTED_KV_BAD_TYPE ? 1 : gguf_type_size(type));
     }
 
-    if (hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
-        const std::string key = "general.alignment";
-        {
-            const uint64_t n = key.length();
-            helper_write(&n, sizeof(n), file);
-        }
-        helper_write(key.data(), key.length(), file);
+    if (hft == HANDCRAFTED_KV_BAD_ALIGN      ||
+        hft == HANDCRAFTED_TENSORS_BAD_ALIGN || hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN ||
+        hft == HANDCRAFTED_DATA_BAD_ALIGN    || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+
+        const uint64_t n = strlen(GGUF_KEY_GENERAL_ALIGNMENT);
+        helper_write(file, n);
+        helper_write(file, GGUF_KEY_GENERAL_ALIGNMENT, n);
 
         const int32_t type = gguf_type(GGUF_TYPE_UINT32);
-        helper_write(&type, sizeof(type), file);
+        helper_write(file, type);
 
-        const uint32_t alignment = GGUF_DEFAULT_ALIGNMENT + 1;
-        helper_write(&alignment, sizeof(alignment), file);
+        alignment = expect_context_not_null(hft) ? 1 : 13;
+        helper_write(file, alignment);
     }
 
     if (hft < offset_has_tensors) {
+        while (ftell(file) % alignment != 0) {
+            const char pad = 0;
+            helper_write(file, pad);
+        }
+
         for (int i = 0; i < extra_bytes; ++i) {
             const char tmp = 0;
-            helper_write(&tmp, sizeof(tmp), file);
+            helper_write(file, tmp);
         }
         rewind(file);
         return file;
     }
 
-    uint32_t alignment = GGUF_DEFAULT_ALIGNMENT;
-    if (hft == HANDCRAFTED_TENSORS_BAD_ALIGNMENT || hft == HANDCRAFTED_DATA_BAD_ALIGNMENT) {
-        alignment -= 1;
-    } else if (hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
-        alignment += 1;
+    if (hft == HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN || hft == HANDCRAFTED_DATA_INCONSISTENT_ALIGN) {
+        alignment = 1;
     }
 
     uint64_t offset = 0;
@@ -313,9 +334,9 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         }
         {
             const uint64_t n = name.length();
-            helper_write(&n, sizeof(n), file);
+            helper_write(file, n);
         }
-        helper_write(name.data(), name.length(), file);
+        helper_write(file, name.data(), name.length());
 
         uint32_t n_dims = hft == HANDCRAFTED_TENSORS_NE_TOO_BIG ? 2 : 1;
         for (int i = GGML_MAX_DIMS-1; i >= 1; --i) {
@@ -326,35 +347,35 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         }
         if (hft == HANDCRAFTED_TENSORS_BAD_N_DIMS) {
             const uint32_t n_dims_bad = GGML_MAX_DIMS + 1;
-            helper_write(&n_dims_bad, sizeof(n_dims_bad), file);
+            helper_write(file, n_dims_bad);
         } else {
-            helper_write(&n_dims,     sizeof(n_dims),     file);
+            helper_write(file, n_dims);
         }
 
         if (hft == HANDCRAFTED_TENSORS_BAD_SHAPE) {
             for (uint32_t j = 0; j < n_dims; ++j) {
                 const int64_t bad_dim = -1;
-                helper_write(&bad_dim, sizeof(bad_dim), file);
+                helper_write(file, bad_dim);
             }
         } else if (hft == HANDCRAFTED_TENSORS_NE_TOO_BIG){
             for (uint32_t j = 0; j < n_dims; ++j) {
                 const int64_t big_dim = 4*int64_t(INT32_MAX);
-                helper_write(&big_dim, sizeof(big_dim), file);
+                helper_write(file, big_dim);
             }
         } else {
-            helper_write(shape.data(), n_dims*sizeof(int64_t), file);
+            helper_write(file, shape.data(), n_dims*sizeof(int64_t));
         }
 
         {
-            const int32_t type32 = hft == HANDCRAFTED_TENSORS_BAD_TYPE ? -1 : int32_t(type);
-            helper_write(&type32, sizeof(type32), file);
+            const int32_t type32 = hft == HANDCRAFTED_TENSORS_BAD_TYPE ? GGML_TYPE_COUNT : int32_t(type);
+            helper_write(file, type32);
         }
 
         if (hft == HANDCRAFTED_TENSORS_BAD_OFFSET) {
             const uint64_t bad_offset = -1;
-            helper_write(&bad_offset, sizeof(bad_offset), file);
+            helper_write(file, bad_offset);
         } else {
-            helper_write(&offset, sizeof(offset), file);
+            helper_write(file, offset);
         }
 
         int64_t ne = shape[0];
@@ -364,12 +385,9 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         offset += GGML_PAD(ggml_row_size(type, ne), alignment);
     }
 
-    const uint32_t alignment_overshoot = ftell(file) % alignment;
-    if (alignment_overshoot != 0) {
-        for (size_t i = alignment_overshoot; i < alignment; ++i) {
-            const char pad = 0;
-            helper_write(&pad, sizeof(pad), file);
-        }
+    while (ftell(file) % alignment != 0) {
+        const char pad = 0;
+        helper_write(file, pad);
     }
 
     if (hft >= offset_has_data) {
@@ -380,13 +398,13 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         }
         for (uint64_t i = 0; i < nbytes; ++i) {
             const uint8_t random_byte = i % 256;
-            helper_write(&random_byte, sizeof(random_byte), file);
+            helper_write(file, random_byte);
         }
     }
 
     for (int i = 0; i < extra_bytes; ++i) {
         const char tmp = 0;
-        helper_write(&tmp, sizeof(tmp), file);
+        helper_write(file, tmp);
     }
     rewind(file);
     return file;
@@ -505,6 +523,16 @@ static bool handcrafted_check_kv(const gguf_context * gguf_ctx, const unsigned i
             }
 
             const char * data_gguf = reinterpret_cast<const char *>(gguf_get_arr_data(gguf_ctx, id));
+
+            if (type_arr == GGUF_TYPE_BOOL) {
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    if (bool(data8[arr_i]) != bool(data_gguf[arr_i])) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
             if (!std::equal(data8, data8 + arr_n*type_size, data_gguf)) {
                 ok = false;
             }
@@ -512,12 +540,20 @@ static bool handcrafted_check_kv(const gguf_context * gguf_ctx, const unsigned i
         }
 
         const char * data_gguf = reinterpret_cast<const char *>(gguf_get_val_data(gguf_ctx, id));
+
+        if (type == GGUF_TYPE_BOOL) {
+            if (bool(*data8) != bool(*data_gguf)) {
+                ok = false;
+            }
+            continue;
+        }
+
         if (!std::equal(data8, data8 + gguf_type_size(type), data_gguf)) {
             ok = false;
         }
     }
 
-    const uint32_t expected_alignment = alignment_defined ? GGUF_DEFAULT_ALIGNMENT + 1 : GGUF_DEFAULT_ALIGNMENT;
+    const uint32_t expected_alignment = alignment_defined ? 1 : GGUF_DEFAULT_ALIGNMENT;
     if (gguf_get_alignment(gguf_ctx) != expected_alignment) {
         ok = false;
     }
@@ -539,7 +575,7 @@ static bool handcrafted_check_tensors(const gguf_context * gguf_ctx, const unsig
 
     bool ok = true;
 
-    const int id_alignment = gguf_find_key(gguf_ctx, "general.alignment");
+    const int id_alignment = gguf_find_key(gguf_ctx, GGUF_KEY_GENERAL_ALIGNMENT);
     const uint32_t alignment = id_alignment >= 0 ? gguf_get_val_u32(gguf_ctx, id_alignment) : GGUF_DEFAULT_ALIGNMENT;
 
     uint64_t expected_offset = 0;
@@ -607,7 +643,7 @@ static bool handcrafted_check_tensor_data(const gguf_context * gguf_ctx, const u
 
         std::vector<uint8_t> data(size);
         GGML_ASSERT(fseek(file, gguf_get_data_offset(gguf_ctx) + offset, SEEK_SET) == 0);
-        GGML_ASSERT(fread(data.data(), 1, size, file) == size);
+        GGML_ASSERT(fread(data.data(), 1, data.size(), file) == data.size());
 
         for (size_t j = 0; j < size; ++j) {
             const uint8_t expected_byte = (j + offset) % 256;
@@ -627,15 +663,15 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
     const std::vector<handcrafted_file_type> hfts = {
         HANDCRAFTED_HEADER_BAD_MAGIC,
         HANDCRAFTED_HEADER_BAD_VERSION_1,
-        // HANDCRAFTED_FILE_TYPE_BAD_VERSION_FUTURE, // FIXME
+        HANDCRAFTED_HEADER_BAD_VERSION_FUTURE,
         HANDCRAFTED_HEADER_BAD_N_KV,
         HANDCRAFTED_HEADER_BAD_N_TENSORS,
         HANDCRAFTED_HEADER_EMPTY,
 
         HANDCRAFTED_KV_BAD_KEY_SIZE,
         HANDCRAFTED_KV_BAD_TYPE,
-        // HANDCRAFTED_KV_BAD_VALUE_SIZE, // FIXME sanitizer limit
-        // HANDCRAFTED_FILE_TYPE_DUPLICATE_KEY, // FIXME
+        HANDCRAFTED_KV_DUPLICATE_KEY,
+        HANDCRAFTED_KV_BAD_ALIGN,
         HANDCRAFTED_KV_SUCCESS,
 
         HANDCRAFTED_TENSORS_BAD_NAME_SIZE,
@@ -643,14 +679,16 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
         HANDCRAFTED_TENSORS_BAD_SHAPE,
         HANDCRAFTED_TENSORS_NE_TOO_BIG,
         HANDCRAFTED_TENSORS_BAD_TYPE,
-        // HANDCRAFTED_TENSORS_BAD_OFFSET, // FIXME
+        HANDCRAFTED_TENSORS_BAD_OFFSET,
         HANDCRAFTED_TENSORS_DUPLICATE_NAME,
-        // HANDCRAFTED_TENSORS_BAD_ALIGNMENT, // FIXME
+        HANDCRAFTED_TENSORS_BAD_ALIGN,
+        HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN,
         HANDCRAFTED_TENSORS_SUCCESS,
         HANDCRAFTED_TENSORS_CUSTOM_ALIGN,
 
         HANDCRAFTED_DATA_NOT_ENOUGH_DATA,
-        // HANDCRAFTED_DATA_BAD_ALIGNMENT, // FIXME
+        HANDCRAFTED_DATA_BAD_ALIGN,
+        HANDCRAFTED_DATA_INCONSISTENT_ALIGN,
         HANDCRAFTED_DATA_SUCCESS,
         HANDCRAFTED_DATA_CUSTOM_ALIGN,
     };
@@ -674,6 +712,7 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
             /*no_alloc =*/ false,
             /*ctx      =*/ hft >= offset_has_data ? &ctx : nullptr,
         };
+
         struct gguf_context * gguf_ctx = gguf_init_from_file_impl(file, gguf_params);
 
         if (expect_context_not_null(hft)) {
@@ -689,7 +728,7 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
         }
         ntest++;
 
-        if (false && hft >= offset_has_data && !expect_context_not_null(hft)) { // FIXME
+        if (hft >= offset_has_data && !expect_context_not_null(hft)) {
             printf("%s:   - no_dangling_ggml_context_pointer: ", __func__);
             if (ctx) {
                 printf("\033[1;31mFAIL\033[0m\n");
@@ -700,23 +739,6 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
             ntest++;
         }
 
-        if (false && expect_context_not_null(hft)) { // FIXME
-            FILE * file_eb = get_handcrafted_file(seed, hft, /*extra_bytes =*/ 1);
-            struct gguf_context * gguf_ctx_eb = gguf_init_from_file_impl(file_eb, gguf_params);
-
-            printf("%s:   - context_null_with_extra_bytes: ", __func__);
-            if (gguf_ctx_eb) {
-                printf("\033[1;31mFAIL\033[0m\n");
-            } else {
-                printf("\033[1;32mOK\033[0m\n");
-                npass++;
-            }
-            ntest++;
-
-            gguf_free(gguf_ctx_eb);
-            fclose(file_eb);
-        }
-
         const bool alignment_defined = hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN;
 
         if (expect_context_not_null(hft)) {
@@ -763,14 +785,15 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
             ntest++;
         }
 
+        fclose(file);
         if (gguf_ctx) {
             ggml_free(ctx);
             gguf_free(gguf_ctx);
         }
-        fclose(file);
         printf("\n");
     }
 
+
     return std::make_pair(npass, ntest);
 }
 
@@ -789,10 +812,6 @@ static struct random_gguf_context_result get_random_gguf_context(ggml_backend_t
         const std::string key = "my_key_" + std::to_string(rng() % 1024);
         const enum gguf_type type = gguf_type(rng() % GGUF_TYPE_COUNT);
 
-        if (type == GGUF_TYPE_STRING || type == GGUF_TYPE_ARRAY) {
-            continue; // FIXME memory leak
-        }
-
         switch (type) {
             case GGUF_TYPE_UINT8:   gguf_set_val_u8  (gguf_ctx, key.c_str(), rng() % (1 <<  7));             break;
             case GGUF_TYPE_INT8:    gguf_set_val_i8  (gguf_ctx, key.c_str(), rng() % (1 <<  7) - (1 <<  6)); break;
@@ -826,6 +845,9 @@ static struct random_gguf_context_result get_random_gguf_context(ggml_backend_t
                         std::vector<uint32_t> random_data((nbytes + sizeof(uint32_t) - 1) / sizeof(uint32_t));
                         for (size_t j = 0; j < random_data.size(); ++j) {
                             random_data[j] = rng();
+                            if (type_arr == GGUF_TYPE_BOOL) {
+                                random_data[j] &= 0x01010101; // the sanitizer complains if booleans are not 0 or 1
+                            }
                         }
                         gguf_set_arr_data(gguf_ctx, key.c_str(), type_arr, random_data.data(), ne);
                     } break;
@@ -928,6 +950,17 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
                 continue;
             }
 
+            if (type_arr == GGUF_TYPE_BOOL) {
+                const int8_t * data       = reinterpret_cast<const int8_t *>(gguf_get_arr_data(ctx,   id));
+                const int8_t * data_other = reinterpret_cast<const int8_t *>(gguf_get_arr_data(other, idx_other));
+                for (int arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    if (bool(data[arr_i]) != bool(data_other[arr_i])) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
             if (type_arr == GGUF_TYPE_STRING) {
                 for (int arr_i = 0; arr_i < arr_n; ++arr_i) {
                     const std::string str       = gguf_get_arr_str(ctx,   id,       arr_i);
@@ -939,8 +972,8 @@ static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other
                 continue;
             }
 
-            const char * data       = reinterpret_cast<const char *>(gguf_get_arr_data(ctx,   id));
-            const char * data_other = reinterpret_cast<const char *>(gguf_get_arr_data(other, idx_other));
+            const int8_t * data       = reinterpret_cast<const int8_t *>(gguf_get_arr_data(ctx,   id));
+            const int8_t * data_other = reinterpret_cast<const int8_t *>(gguf_get_arr_data(other, idx_other));
             if (!std::equal(data, data + arr_n*gguf_type_size(type_arr), data_other)) {
                 ok = false;
             }
@@ -1028,21 +1061,6 @@ static bool same_tensor_data(const struct ggml_context * orig, const struct ggml
 }
 
 static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) {
-    FILE * file = tmpfile();
-#ifdef _WIN32
-    if (!file) {
-        printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
-        printf("%s: skipping tests");
-        return std::make_pair(0, 0);
-    }
-#else
-    GGML_ASSERT(file);
-#endif // _WIN32
-
-    if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
-        return std::make_pair(0, 0); // FIXME
-    }
-
     ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
     printf("%s: device=%s, backend=%s, only_meta=%s\n",
         __func__, ggml_backend_dev_description(dev), ggml_backend_name(backend), only_meta ? "yes" : "no");
@@ -1060,10 +1078,24 @@ static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned
         bbuf       = result.buffer;
     }
 
-    struct gguf_buf gbuf = gguf_buf_init(16 * 1024);
-    gguf_write_to_buf(gguf_ctx_0, &gbuf, only_meta);
-    helper_write(gbuf.data, gbuf.offset, file);
-    rewind(file);
+    FILE * file = tmpfile();
+
+#ifdef _WIN32
+    if (!file) {
+        printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
+        printf("%s: skipping tests");
+        return std::make_pair(0, 0);
+    }
+#else
+    GGML_ASSERT(file);
+#endif // _WIN32
+
+    {
+        std::vector<int8_t> buf;
+        gguf_write_to_buf(gguf_ctx_0, buf, only_meta);
+        GGML_ASSERT(fwrite(buf.data(), 1, buf.size(), file) == buf.size());
+        rewind(file);
+    }
 
     struct ggml_context * ctx_1 = nullptr;
     struct gguf_init_params gguf_params = {
@@ -1151,9 +1183,8 @@ static std::pair<int, int> test_roundtrip(ggml_backend_dev_t dev, const unsigned
     ggml_free(ctx_1);
     gguf_free(gguf_ctx_0);
     gguf_free(gguf_ctx_1);
-    gguf_buf_free(gbuf);
     ggml_backend_free(backend);
-    GGML_ASSERT(fclose(file) == 0);
+    fclose(file);
 
     printf("\n");
     return std::make_pair(npass, ntest);