]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sync : ggml (Metal F32 support + reduce ggml-alloc size) (#3192)
authorGeorgi Gerganov <redacted>
Fri, 15 Sep 2023 16:06:03 +0000 (19:06 +0300)
committerGitHub <redacted>
Fri, 15 Sep 2023 16:06:03 +0000 (19:06 +0300)
* sync : ggml (Metal F32 support + reduce ggml-alloc size)

ggml-ci

* llama-bench : fix ggml_cpu_has_metal() duplicate function

ggml-ci

examples/llama-bench/llama-bench.cpp
ggml-alloc.c
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h

index dedaa34fd84ba2dc4e1030c350973802181694ed..34ddfde39d29575fa581575fafbf2a2bfd90c6f6 100644 (file)
@@ -74,14 +74,6 @@ static T stdev(const std::vector<T> & v) {
     return stdev;
 }
 
-static bool ggml_cpu_has_metal() {
-#if defined(GGML_USE_METAL)
-    return true;
-#else
-    return false;
-#endif
-}
-
 static std::string get_cpu_info() {
     std::string id;
 #ifdef __linux__
index a1f6e7bf4f66eec4d9437c86624cb85d7d72bc39..304964be4f3f07e1c10db0f46c4e013e813c989b 100644 (file)
@@ -131,6 +131,10 @@ static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_ten
     return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
 }
 
+static bool ggml_is_view(struct ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
 void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
 #ifdef GGML_ALLOCATOR_DEBUG
     GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
@@ -338,8 +342,8 @@ static void free_vmem(void * base_addr, size_t size) {
 
 // allocate uncommitted virtual memory to measure the size of the graph
 static void alloc_measure_vmem(void ** base_addr, size_t * size) {
-    // 1TB for 64-bit, 1GB for 32-bit
-    *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<40;
+    // 128GB for 64-bit, 1GB for 32-bit
+    *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
     do {
         *base_addr = alloc_vmem(*size);
         if (*base_addr != NULL) {
@@ -399,10 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
 
 //////////// compute graph allocator
 
-static bool ggml_is_view(struct ggml_tensor * t) {
-    return t->view_src != NULL;
-}
-
 static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
     if (a->type != b->type) {
         return false;
index 3e3be98c5d91d112c520f560b7d5e61388eadf56..1139ee3114610e05726b7e4319843eade838271c 100644 (file)
@@ -78,6 +78,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
+    GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
@@ -89,6 +90,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@@ -237,6 +239,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
+        GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
@@ -248,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -309,6 +313,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
@@ -320,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -885,6 +891,7 @@ void ggml_metal_graph_compute(
                                 ne00%32 == 0 &&
                                 ne11 > 1) {
                                 switch (src0->type) {
+                                    case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
                                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                                     case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@@ -919,6 +926,11 @@ void ggml_metal_graph_compute(
 
                                 // use custom matrix x vector kernel
                                 switch (src0t) {
+                                    case GGML_TYPE_F32:
+                                        {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+                                            nrows = 4;
+                                        } break;
                                     case GGML_TYPE_F16:
                                         {
                                             nth0 = 32;
index ea8b428448b8d1eb8ade086171a772763a393c2c..3087ecda812d90d510ffc6859ff3d4dc52f12e7e 100644 (file)
@@ -523,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
+#define N_F32_F32 4
+
+kernel void kernel_mul_mat_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t im = tgpig.z;
+
+    device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const float4 * x4 = (device const float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
@@ -1399,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne01 [[buffer(4)]],
+        constant   int64_t & ne02 [[buffer(5)]],
+        constant   int64_t & ne10 [[buffer(9)]],
+        constant   int64_t & ne12 [[buffer(11)]],
+        constant   int64_t & ne0  [[buffer(15)]],
+        constant   int64_t & ne1  [[buffer(16)]],
+        constant   uint    & gqa  [[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2012,7 +2085,6 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
-
 }
 
 template <typename type4x4>
@@ -2269,6 +2341,7 @@ typedef void (mat_mm_t)(
         constant       uint & gqa,
         threadgroup uchar *, uint3, uint, uint);
 
+template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1,     dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
diff --git a/ggml.c b/ggml.c
index 96edebeb8e340bf59b0cda75212f2647a332bc57..a0be068d6c9f73524de04d687eb273bb235b6573 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -17294,10 +17294,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         } else {
             // wait for other threads to finish
             const int last = node_n;
-            do {
-                //sched_yield();
+            while (true) {
+                // TODO: this sched_yield can have significant impact on the performance - either positive or negative
+                //       depending on the workload and the operating system.
+                //       since it is not clear what is the best approach, it should potentially become user-configurable
+                //       ref: https://github.com/ggerganov/ggml/issues/291
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                sched_yield();
+#endif
+
                 node_n = atomic_load(&state->shared->node_n);
-            } while (node_n == last);
+                if (node_n != last) break;
+            };
         }
 
         // check if we should stop
@@ -18348,7 +18356,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     for (int i = 0; i < cgraph->n_leafs; i++) {
         struct ggml_tensor * node = cgraph->leafs[i];
 
-        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
+        GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
                 i,
                 node->ne[0], node->ne[1],
                 ggml_op_name(node->op),
@@ -20111,27 +20119,27 @@ const char * gguf_type_name(enum gguf_type type) {
     return GGUF_TYPE_NAME[type];
 }
 
-int gguf_get_version(struct gguf_context * ctx) {
+int gguf_get_version(const struct gguf_context * ctx) {
     return ctx->header.version;
 }
 
-size_t gguf_get_alignment(struct gguf_context * ctx) {
+size_t gguf_get_alignment(const struct gguf_context * ctx) {
     return ctx->alignment;
 }
 
-size_t gguf_get_data_offset(struct gguf_context * ctx) {
+size_t gguf_get_data_offset(const struct gguf_context * ctx) {
     return ctx->offset;
 }
 
-void * gguf_get_data(struct gguf_context * ctx) {
+void * gguf_get_data(const struct gguf_context * ctx) {
     return ctx->data;
 }
 
-int gguf_get_n_kv(struct gguf_context * ctx) {
+int gguf_get_n_kv(const struct gguf_context * ctx) {
     return ctx->header.n_kv;
 }
 
-int gguf_find_key(struct gguf_context * ctx, const char * key) {
+int gguf_find_key(const struct gguf_context * ctx, const char * key) {
     // return -1 if key not found
     int keyfound = -1;
 
@@ -20147,85 +20155,85 @@ int gguf_find_key(struct gguf_context * ctx, const char * key) {
     return keyfound;
 }
 
-const char * gguf_get_key(struct gguf_context * ctx, int i) {
+const char * gguf_get_key(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].key.data;
 }
 
-enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) {
+enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].type;
 }
 
-enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
+enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.arr.type;
 }
 
-const void * gguf_get_arr_data(struct gguf_context * ctx, int i) {
+const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.arr.data;
 }
 
-const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
+const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
     struct gguf_kv * kv = &ctx->kv[key_id];
     struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
     return str->data;
 }
 
-int gguf_get_arr_n(struct gguf_context * ctx, int i) {
+int gguf_get_arr_n(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.arr.n;
 }
 
-uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) {
+uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.uint8;
 }
 
-int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) {
+int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.int8;
 }
 
-uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) {
+uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.uint16;
 }
 
-int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) {
+int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.int16;
 }
 
-uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) {
+uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.uint32;
 }
 
-int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) {
+int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.int32;
 }
 
-float gguf_get_val_f32(struct gguf_context * ctx, int i) {
+float gguf_get_val_f32(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.float32;
 }
 
-uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) {
+uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.uint64;
 }
 
-int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) {
+int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.int64;
 }
 
-double gguf_get_val_f64(struct gguf_context * ctx, int i) {
+double gguf_get_val_f64(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.float64;
 }
 
-bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
+bool gguf_get_val_bool(const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.bool_;
 }
 
-const char * gguf_get_val_str (struct gguf_context * ctx, int i) {
+const char * gguf_get_val_str (const struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.str.data;
 }
 
-int gguf_get_n_tensors(struct gguf_context * ctx) {
+int gguf_get_n_tensors(const struct gguf_context * ctx) {
     return ctx->header.n_tensors;
 }
 
-int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
+int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
     // return -1 if tensor not found
     int tensorfound = -1;
 
@@ -20241,11 +20249,11 @@ int gguf_find_tensor(struct gguf_context * ctx, const char * name) {
     return tensorfound;
 }
 
-size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) {
+size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
     return ctx->infos[i].offset;
 }
 
-char * gguf_get_tensor_name(struct gguf_context * ctx, int i) {
+char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
     return ctx->infos[i].name.data;
 }
 
@@ -20528,7 +20536,7 @@ static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_si
     buf->offset += el_size;
 }
 
-static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
+static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
     // write header
     gguf_bwrite_el(buf, &ctx->header.magic,     sizeof(ctx->header.magic));
     gguf_bwrite_el(buf, &ctx->header.version,   sizeof(ctx->header.version));
@@ -20643,7 +20651,7 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
     }
 }
 
-void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) {
+void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
     FILE * file = fopen(fname, "wb");
     if (!file) {
         GGML_ASSERT(false && "failed to open file for writing");
@@ -20660,7 +20668,7 @@ void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only
     fclose(file);
 }
 
-size_t gguf_get_meta_size(struct gguf_context * ctx) {
+size_t gguf_get_meta_size(const struct gguf_context * ctx) {
     // no allocs - only compute size
     struct gguf_buf buf = gguf_buf_init(0);
 
@@ -20669,7 +20677,7 @@ size_t gguf_get_meta_size(struct gguf_context * ctx) {
     return buf.offset;
 }
 
-void gguf_get_meta_data(struct gguf_context * ctx, void * data) {
+void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
     struct gguf_buf buf = gguf_buf_init(16*1024);
 
     gguf_write_to_buf(ctx, &buf, true);
@@ -20745,6 +20753,14 @@ int ggml_cpu_has_arm_fma(void) {
 #endif
 }
 
+int ggml_cpu_has_metal(void) {
+#if defined(GGML_USE_METAL)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_f16c(void) {
 #if defined(__F16C__)
     return 1;
diff --git a/ggml.h b/ggml.h
index 6d4cf465d62b56df3561504d42f796ae24e8762d..f45456876da621d5beafdc2419d9d575a9fb02ff 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #    define GGML_DEPRECATED(func, hint) func
 #endif
 
+#ifndef __GNUC__
+#    define GGML_ATTRIBUTE_FORMAT(...)
+#elif defined(__MINGW32__)
+#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+
 #include <stdint.h>
 #include <stddef.h>
 #include <stdbool.h>
@@ -270,7 +278,7 @@ extern "C" {
 
 #if defined(__ARM_NEON) && defined(__CUDACC__)
     typedef half ggml_fp16_t;
-#elif defined(__ARM_NEON) && !defined(_MSC_VER)
+#elif defined(__ARM_NEON)
     typedef __fp16 ggml_fp16_t;
 #else
     typedef uint16_t ggml_fp16_t;
@@ -685,6 +693,7 @@ extern "C" {
 
     GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
+    GGML_ATTRIBUTE_FORMAT(2, 3)
     GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);
 
     //
@@ -1866,39 +1875,39 @@ extern "C" {
 
     GGML_API const char * gguf_type_name(enum gguf_type type);
 
-    GGML_API int    gguf_get_version    (struct gguf_context * ctx);
-    GGML_API size_t gguf_get_alignment  (struct gguf_context * ctx);
-    GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
-    GGML_API void * gguf_get_data       (struct gguf_context * ctx);
+    GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
+    GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
+    GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
+    GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
 
-    GGML_API int          gguf_get_n_kv(struct gguf_context * ctx);
-    GGML_API int          gguf_find_key(struct gguf_context * ctx, const char * key);
-    GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
+    GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
+    GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
 
-    GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
-    GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
 
     // results are undefined if the wrong type is used for the key
-    GGML_API uint8_t      gguf_get_val_u8  (struct gguf_context * ctx, int i);
-    GGML_API int8_t       gguf_get_val_i8  (struct gguf_context * ctx, int i);
-    GGML_API uint16_t     gguf_get_val_u16 (struct gguf_context * ctx, int i);
-    GGML_API int16_t      gguf_get_val_i16 (struct gguf_context * ctx, int i);
-    GGML_API uint32_t     gguf_get_val_u32 (struct gguf_context * ctx, int i);
-    GGML_API int32_t      gguf_get_val_i32 (struct gguf_context * ctx, int i);
-    GGML_API float        gguf_get_val_f32 (struct gguf_context * ctx, int i);
-    GGML_API uint64_t     gguf_get_val_u64 (struct gguf_context * ctx, int i);
-    GGML_API int64_t      gguf_get_val_i64 (struct gguf_context * ctx, int i);
-    GGML_API double       gguf_get_val_f64 (struct gguf_context * ctx, int i);
-    GGML_API bool         gguf_get_val_bool(struct gguf_context * ctx, int i);
-    GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
-    GGML_API int          gguf_get_arr_n   (struct gguf_context * ctx, int i);
-    GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i);
-    GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i);
-
-    GGML_API int    gguf_get_n_tensors    (struct gguf_context * ctx);
-    GGML_API int    gguf_find_tensor      (struct gguf_context * ctx, const char * name);
-    GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i);
-    GGML_API char * gguf_get_tensor_name  (struct gguf_context * ctx, int i);
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int i);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int i);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int i);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int i);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int i);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int i);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int i);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int i);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int i);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int i);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int i);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
+    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int i);
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
+    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
+
+    GGML_API int    gguf_get_n_tensors    (const struct gguf_context * ctx);
+    GGML_API int    gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
+    GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
+    GGML_API char * gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
 
     // overrides existing values or adds a new one
     GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
@@ -1943,11 +1952,11 @@ extern "C" {
     //
 
     // write the entire context to a binary file
-    GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta);
+    GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
 
     // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
-    GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx);
-    GGML_API void   gguf_get_meta_data(struct gguf_context * ctx, void * data);
+    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
+    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
 
     //
     // system info
@@ -1961,6 +1970,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
+    GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);
     GGML_API int ggml_cpu_has_fp16_va    (void);
     GGML_API int ggml_cpu_has_wasm_simd  (void);