]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (gguf + metal + ROCm + etc.) (#489)
authorGeorgi Gerganov <redacted>
Mon, 28 Aug 2023 11:31:13 +0000 (14:31 +0300)
committerGitHub <redacted>
Mon, 28 Aug 2023 11:31:13 +0000 (14:31 +0300)
* ggml : sync llama.cpp (gguf + metal + ROCm + etc.)

ggml-ci

* cuda : sync rope updates

ggml-ci

include/ggml/ggml-alloc.h
include/ggml/ggml.h
src/ggml-alloc.c
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml-metal.h
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c

index 14a4350ac2e968f5455632f2b7ad5dba5f4a2d8e..9559da75871a608fb29f08edebf860674295e043 100644 (file)
@@ -12,7 +12,7 @@ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
 
 // tell the allocator to parse nodes following the order described in the list
 // you should call this if your graph are optimized to execute out-of-order
-GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
+GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
 
 GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);
 GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
index f83fb5c26857be9f8faa44d0f97c8640518013e6..4ef3d525371fe6ae2621eabfc9e3cd349da4c8b8 100644 (file)
 #define GGML_EXIT_ABORTED 1
 
 #define GGUF_MAGIC   0x46554747 // "GGUF"
-#define GGUF_VERSION 1
+#define GGUF_VERSION 2
 
 #define GGUF_DEFAULT_ALIGNMENT 32
 
@@ -1835,6 +1835,9 @@ extern "C" {
         GGUF_TYPE_BOOL    = 7,
         GGUF_TYPE_STRING  = 8,
         GGUF_TYPE_ARRAY   = 9,
+        GGUF_TYPE_UINT64  = 10,
+        GGUF_TYPE_INT64   = 11,
+        GGUF_TYPE_FLOAT64 = 12,
         GGUF_TYPE_COUNT,       // marks the end of the enum
     };
 
@@ -1875,6 +1878,9 @@ extern "C" {
     GGML_API uint32_t     gguf_get_val_u32 (struct gguf_context * ctx, int i);
     GGML_API int32_t      gguf_get_val_i32 (struct gguf_context * ctx, int i);
     GGML_API float        gguf_get_val_f32 (struct gguf_context * ctx, int i);
+    GGML_API uint64_t     gguf_get_val_u64 (struct gguf_context * ctx, int i);
+    GGML_API int64_t      gguf_get_val_i64 (struct gguf_context * ctx, int i);
+    GGML_API double       gguf_get_val_f64 (struct gguf_context * ctx, int i);
     GGML_API bool         gguf_get_val_bool(struct gguf_context * ctx, int i);
     GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
     GGML_API int          gguf_get_arr_n   (struct gguf_context * ctx, int i);
@@ -1894,6 +1900,9 @@ extern "C" {
     GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
     GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
     GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
+    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
+    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
+    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
     GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
     GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
     GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
@@ -1952,6 +1961,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_clblast    (void);
     GGML_API int ggml_cpu_has_gpublas    (void);
     GGML_API int ggml_cpu_has_sse3       (void);
+    GGML_API int ggml_cpu_has_ssse3      (void);
     GGML_API int ggml_cpu_has_vsx        (void);
 
     //
index 6810a20f318cb58c71e2c5b81a0ef54b3817fc62..140e9a2a7370a6fb6c32f88f6abbf3f4ae081979 100644 (file)
@@ -8,6 +8,7 @@
 
 #define UNUSED(x) (void)(x)
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
 
 //#define GGML_ALLOCATOR_DEBUG
 
@@ -67,8 +68,8 @@ struct ggml_allocr {
     struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
     size_t max_size;
     bool measure;
-    int parse_seq[GGML_MAX_NODES];
-    bool has_parse_seq;
+    int parse_seq[GGML_MAX_CONCUR];
+    int parse_seq_len;
 
 #ifdef GGML_ALLOCATOR_DEBUG
     struct ggml_tensor * allocated_tensors[1024];
@@ -238,15 +239,11 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
     alloc->n_free_blocks++;
 }
 
-void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
-    int pos = 0;
+void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
     for (int i = 0; i < n; i++) {
-        if (list[i] != -1) {
-            alloc->parse_seq[pos] = list[i];
-            pos++;
-        }
+        alloc->parse_seq[i] = list[i];
     }
-    alloc->has_parse_seq = true;
+    alloc->parse_seq_len = n;
 }
 
 void ggml_allocr_reset(struct ggml_allocr * alloc) {
@@ -269,7 +266,7 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
         /*.max_size      = */ 0,
         /*.measure       = */ false,
         /*.parse_seq     = */ {0},
-        /*.has_parse_seq = */ false,
+        /*.parse_seq_len = */ 0,
 #ifdef GGML_ALLOCATOR_DEBUG
         /*.allocated_tensors = */ {0},
 #endif
@@ -298,7 +295,7 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
         /*.max_size      = */ 0,
         /*.measure       = */ true,
         /*.parse_seq     = */ {0},
-        /*.has_parse_seq = */ false,
+        /*.parse_seq_len = */ 0,
 #ifdef GGML_ALLOCATOR_DEBUG
         /*.allocated_tensors = */ {0},
 #endif
@@ -445,8 +442,8 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
                         else {
                             AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
                             node->data = parent->data;
+                            return;
                         }
-                        return;
                     }
                 }
             }
@@ -497,69 +494,86 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
                 allocate_node(alloc, input);
             }
         }
-        for (int ind = 0; ind < gf->n_nodes; ind++) {
-            int i;
-            if (alloc->has_parse_seq) {
-                i = alloc->parse_seq[ind];
-            } else {
-                i = ind;
-            }
-            struct ggml_tensor * node = gf->nodes[i];
-
-            // allocate parents (leafs)
-            for (int j = 0; j < GGML_MAX_SRC; j++) {
-                struct ggml_tensor * parent = node->src[j];
-                if (parent == NULL) {
-                    break;
+        // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
+        int last_barrier_pos = 0;
+        int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
+
+        for (int ind = 0; ind < n_nodes; ind++) {
+            // allocate a node if there is no parse_seq or this is not a barrier
+            if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
+                int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
+                struct ggml_tensor * node = gf->nodes[i];
+
+                // allocate parents (leafs)
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    struct ggml_tensor * parent = node->src[j];
+                    if (parent == NULL) {
+                        break;
+                    }
+                    allocate_node(alloc, parent);
                 }
-                allocate_node(alloc, parent);
-            }
 
-            // allocate node
-            allocate_node(alloc, node);
+                // allocate node
+                allocate_node(alloc, node);
 
-            AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
-            for (int j = 0; j < GGML_MAX_SRC; j++) {
-                struct ggml_tensor * parent = node->src[j];
-                if (parent == NULL) {
-                    break;
-                }
-                AT_PRINTF("%s", parent->name);
-                if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
-                    AT_PRINTF(", ");
+                AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    struct ggml_tensor * parent = node->src[j];
+                    if (parent == NULL) {
+                        break;
+                    }
+                    AT_PRINTF("%s", parent->name);
+                    if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+                        AT_PRINTF(", ");
+                    }
                 }
+                AT_PRINTF("\n");
             }
-            AT_PRINTF("\n");
+
 
             // update parents
-            for (int j = 0; j < GGML_MAX_SRC; j++) {
-                struct ggml_tensor * parent = node->src[j];
-                if (parent == NULL) {
-                    break;
-                }
-                struct hash_node * p_hn = hash_get(ht, parent);
-                p_hn->n_children -= 1;
-
-                //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
-
-                if (p_hn->n_children == 0 && p_hn->n_views == 0) {
-                    if (ggml_is_view(parent)) {
-                        struct ggml_tensor * view_src = get_view_source(parent);
-                        struct hash_node * view_src_hn = hash_get(ht, view_src);
-                        view_src_hn->n_views -= 1;
-                        AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
-                        if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
-                            ggml_allocator_free_tensor(alloc, view_src);
+            // update immediately if there is no parse_seq
+            // update only at barriers if there is parse_seq
+            if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) {
+                int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
+                int update_end   = alloc->parse_seq_len ? ind              : ind + 1;
+                for (int i = update_start; i < update_end; i++) {
+                    int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
+                    struct ggml_tensor * node = gf->nodes[node_i];
+
+                    for (int j = 0; j < GGML_MAX_SRC; j++) {
+                        struct ggml_tensor * parent = node->src[j];
+                        if (parent == NULL) {
+                            break;
                         }
-                    }
-                    else {
-                        if (parent->data != node->data) {
-                            ggml_allocator_free_tensor(alloc, parent);
+                        struct hash_node * p_hn = hash_get(ht, parent);
+                        p_hn->n_children -= 1;
+
+                        //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
+
+                        if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+                            if (ggml_is_view(parent)) {
+                                struct ggml_tensor * view_src = get_view_source(parent);
+                                struct hash_node * view_src_hn = hash_get(ht, view_src);
+                                view_src_hn->n_views -= 1;
+                                AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+                                if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
+                                    ggml_allocator_free_tensor(alloc, view_src);
+                                }
+                            }
+                            else {
+                                if (parent->data != node->data) {
+                                    ggml_allocator_free_tensor(alloc, parent);
+                                }
+                            }
                         }
                     }
                 }
+                AT_PRINTF("\n");
+                if (alloc->parse_seq_len) {
+                    last_barrier_pos = ind + 1;
+                }
             }
-            AT_PRINTF("\n");
         }
         // free graph outputs here that wouldn't be freed otherwise because they have no children
         if (outputs != NULL && outputs[g] != NULL) {
index c0fb9fb650e0d22b04a444ef9056b8129f0ac392..5fd62563022967f0a02dccd1361aec23617fb11d 100644 (file)
 #include <atomic>
 #include <assert.h>
 
+#if defined(GGML_USE_HIPBLAS)
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#ifdef __HIP_PLATFORM_AMD__
+// for rocblas_initialize()
+#include "rocblas/rocblas.h"
+#endif
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F  HIPBLAS_R_16F
+#define CUDA_R_32F  HIPBLAS_R_32F
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEvent_t hipEvent_t
+#define cudaEventDestroy hipEventDestroy
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#else
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#endif
 
 #include "ggml-cuda.h"
 #include "ggml.h"
 
 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#ifndef CC_TURING
 #define CC_TURING   700
+#endif
+
+#if defined(GGML_USE_HIPBLAS)
+#define __CUDA_ARCH__ 1300
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int&>(c);
+}
+
+static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(__gfx1100__)
+    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+    int tmp1;
+    int tmp2;
+    asm("\n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        "
+        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+        : "v"(a), "v"(b)
+    );
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+}
+#endif
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -205,11 +306,11 @@ typedef struct {
 #define QI4_K (QK_K / (4*QR4_K))
 #ifdef GGML_QKK_64
 typedef struct {
-    half    d[2];              // super-block scales/mins
+    half    dm[2];             // super-block scales/mins
     uint8_t scales[2];         // 4-bit block scales/mins
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     half2 dm;                  // super-block scale for quantized scales/mins
@@ -287,7 +388,7 @@ static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
-static bool g_mul_mat_q = false;
+static bool g_mul_mat_q = true;
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
@@ -424,8 +525,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
 static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const dfloat d = x[ib].dm.x;
-    const dfloat m = x[ib].dm.y;
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
 
     const int vui = x[ib].qs[iqs];
 
@@ -467,8 +568,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
 static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const dfloat d = x[ib].dm.x;
-    const dfloat m = x[ib].dm.y;
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -520,8 +621,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
-    float dall = x[i].dm.x;
-    float dmin = x[i].dm.y;
+    float dall = __low2half(x[i].dm);
+    float dmin = __high2half(x[i].dm);
     y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
@@ -531,8 +632,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
     const int il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
     float * y = yy + i*QK_K + 16*is + il;
-    float dall = x[i].dm.x;
-    float dmin = x[i].dm.y;
+    float dall = __low2half(x[i].dm);
+    float dmin = __high2half(x[i].dm);
     y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
     y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
 #endif
@@ -618,8 +719,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
 
     float * y = yy + i*QK_K + 64*il + n*ir;
 
-    const float dall = x[i].dm.x;
-    const float dmin = x[i].dm.y;
+    const float dall = __low2half(x[i].dm);
+    const float dmin = __high2half(x[i].dm);
 
     const uint8_t * q = x[i].qs + 32*il + n*ir;
 
@@ -636,8 +737,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     float * y = yy + i*QK_K;
-    const float d = (float)x[i].d[0];
-    const float m = (float)x[i].d[1];
+    const float d = (float)x[i].dm[0];
+    const float m = (float)x[i].dm[1];
     y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
     y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
 #endif
@@ -657,8 +758,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
 
     float * y = yy + i*QK_K + 64*il + 2*ir;
 
-    const float dall = x[i].dm.x;
-    const float dmin = x[i].dm.y;
+    const float dall = __low2half(x[i].dm);
+    const float dmin = __high2half(x[i].dm);
 
     const uint8_t * ql = x[i].qs + 32*il + 2*ir;
     const uint8_t * qh = x[i].qh + 2*ir;
@@ -770,8 +871,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
         const float   * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
 
-        const float dall = x[i].dm.x;
-        const float dmin = x[i].dm.y;
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
 
         const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
         aux[0] = a[0] & 0x0f0f0f0f;
@@ -991,8 +1092,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         const float   * y1 = yy + i*QK_K + y_offset;
         const float   * y2 = y1 + 128;
 
-        const float dall = x[i].dm.x;
-        const float dmin = x[i].dm.y;
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
 
         const uint16_t * a = (const uint16_t *)x[i].scales;
         aux[0] = a[im+0] & kmask1;
@@ -1054,8 +1155,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         const uint16_t * a = (const uint16_t *)x[i].scales;
         aux16[0] = a[0] & 0x0f0f;
         aux16[1] = (a[0] >> 4) & 0x0f0f;
-        const float d = (float)x[i].d[0];
-        const float m = (float)x[i].d[1];
+        const float d = (float)x[i].dm[0];
+        const float m = (float)x[i].dm[1];
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
             sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
@@ -1124,8 +1225,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
         const float   * y1  = yy + i*QK_K + y_offset;
         const float   * y2  = y1 + 128;
 
-        const float dall = x[i].dm.x;
-        const float dmin = x[i].dm.y;
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
 
         const uint16_t * a = (const uint16_t *)x[i].scales;
         aux[0] = a[im+0] & kmask1;
@@ -1348,8 +1449,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
         return;
     }
 
-    y[ib].ds.x = d;
-    y[ib].ds.y = sum;
+    reinterpret_cast<half&>(y[ib].ds.x) = d;
+    reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -2346,7 +2447,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
         u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
     }
 
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2432,7 +2533,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR2_K; ++ i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
     }
 
     return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
@@ -2551,7 +2652,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR3_K; ++i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
     }
 
     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
@@ -2720,7 +2821,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 
     for (int i = 0; i < QR4_K; ++i) {
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds.x;
+        d8[i] = __low2half(bq8i->ds);
 
         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
         u[2*i+0] = q8[0];
@@ -2744,11 +2845,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     aux16[0] = a[0] & 0x0f0f;
     aux16[1] = (a[0] >> 4) & 0x0f0f;
 
-    const float dall = bq4_K->d[0];
-    const float dmin = bq4_K->d[1];
+    const float dall = bq4_K->dm[0];
+    const float dmin = bq4_K->dm[1];
 
-    const float d8_1 = bq8_1[0].ds.x;
-    const float d8_2 = bq8_1[1].ds.x;
+    const float d8_1 = __low2float(bq8_1[0].ds);
+    const float d8_2 = __low2float(bq8_1[1].ds);
 
     const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
     const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
@@ -2828,7 +2929,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
+#if QK_K == 256
         x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#else
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
+#endif
     }
 
 #pragma unroll
@@ -2901,7 +3006,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR5_K; ++i) {
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds.x;
+        d8[i] = __low2float(bq8i->ds);
 
         const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
         u[2*i+0] = q8[0];
@@ -2919,8 +3024,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 
     const float d = bq5_K->d;
 
-    const float d8_1 = bq8_1[0].ds.x;
-    const float d8_2 = bq8_1[1].ds.x;
+    const float d8_1 = __low2half(bq8_1[0].ds);
+    const float d8_2 = __low2half(bq8_1[1].ds);
 
     const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
     const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
@@ -3018,7 +3123,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
+#if QK_K == 256
         x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#endif
     }
 
 #pragma unroll
@@ -3075,7 +3182,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 #pragma unroll
     for (int i = 0; i < QR6_K; ++i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
+        d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds);
     }
 
     return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
@@ -3243,7 +3350,7 @@ static __device__ __forceinline__ void mul_mat_q(
                     *dsi_dst = *dsi_src;
                 } else {
                     float * dfi_dst = (float *) dsi_dst;
-                    *dfi_dst = (*dsi_src).x;
+                    *dfi_dst = __low2half(*dsi_src);
                 }
             }
 
@@ -3887,13 +3994,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
         return;
     }
 
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;
 
     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3907,6 +4014,28 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
+static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
+                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i = row*ncols + col/2;
+
+    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + ncols/2];
+
+    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
+    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+}
+
 static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
@@ -3965,8 +4094,8 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }
 
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (col >= ncols) {
         return;
@@ -3979,24 +4108,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-// values are also not normalized to the maximum value by subtracting it in the exponential function
-// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;
 
-    float tmp = 0.0;
+    float max_val = -INFINITY;
 
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        max_val = max(max_val, x[i]);
+    }
 
-        if (col >= ncols) {
-            break;
-        }
+    // find the max value in the block
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    }
 
+    float tmp = 0.f;
+
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        const float val = expf(x[i]);
+        const float val = expf(x[i] - max_val);
         tmp += val;
         dst[i] = val;
     }
@@ -4007,15 +4141,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
-
-        if (col >= ncols) {
-            break;
-        }
+    const float inv_tmp = 1.f / tmp;
 
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        dst[i] /= tmp;
+        dst[i] *= inv_tmp;
     }
 }
 
@@ -4585,6 +4715,8 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
+#if QK_K == 256
+
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
@@ -4616,6 +4748,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
         mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
+#endif
 }
 
 static void ggml_mul_mat_q4_K_q8_1_cuda(
@@ -4775,13 +4908,22 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    GGML_ASSERT(ncols % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
+static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+    GGML_ASSERT(ncols % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
+    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+}
+
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 4 == 0);
     const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
@@ -4800,15 +4942,15 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
 }
 
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
@@ -4913,10 +5055,18 @@ void ggml_init_cublas() {
     static bool initialized = false;
 
     if (!initialized) {
+
+#ifdef __HIP_PLATFORM_AMD__
+        // Workaround for a rocBLAS bug when using multiple graphics cards:
+        // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
+        rocblas_initialize();
+        CUDA_CHECK(cudaDeviceSynchronize());
+#endif
+
         CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
-        fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
+        fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
         for (int id = 0; id < g_device_count; ++id) {
             cudaDeviceProp prop;
             CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
@@ -5514,7 +5664,8 @@ inline void ggml_cuda_op_rope(
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
-    const bool is_glm = mode & 4;
+    const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;
 
     // compute
     if (is_glm) {
@@ -5522,6 +5673,10 @@ inline void ggml_cuda_op_rope(
         const float id_p = min(p, n_ctx - 2.f);
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+    } else if (is_neox) {
+        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
+        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+        rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     } else {
         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
@@ -6183,9 +6338,11 @@ void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml
 
 void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
 
     const int mode = ((int32_t *) dst->op_params)[2];
     const bool is_glm = mode & 4;
+
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }
 
@@ -6313,7 +6470,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -6322,14 +6479,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
+
+    if (scratch && no_alloc) {
+        return;
+    }
+
     struct ggml_tensor_extra_gpu * extra;
 
     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6381,16 +6543,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     tensor->extra = extra;
 }
 
+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
 }
 
 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
index cad05f5fa47ab68e7692ddf16b6a8013825a72a3..a72e82069b9f1430fdbcc007f93ae26ea6c2f924 100644 (file)
@@ -2,6 +2,14 @@
 
 #include "ggml.h"
 
+#ifdef GGML_USE_HIPBLAS
+#define GGML_CUDA_NAME "ROCm"
+#define GGML_CUBLAS_NAME "hipBLAS"
+#else
+#define GGML_CUDA_NAME "CUDA"
+#define GGML_CUBLAS_NAME "cuBLAS"
+#endif
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
@@ -16,9 +24,14 @@ GGML_API bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
 GGML_API void   ggml_cuda_set_tensor_split(const float * tensor_split);
 GGML_API void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_free_data(struct ggml_tensor * tensor);
+
 GGML_API void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
 GGML_API void   ggml_cuda_set_main_device(int main_device);
 GGML_API void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void   ggml_cuda_set_scratch_size(size_t scratch_size);
index 00202b787c8043173b8308f49e6ae61080d42094..fca28d37ef97069ca7800fc890ea4d013c46437f 100644 (file)
@@ -24,6 +24,7 @@
 
 // max memory buffers that can be mapped to the device
 #define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;
 struct ggml_cgraph;
index 835c5f297cf95d0b6cff9a16eab111691076f3ad..ad2ee8cf5fef098169d765b7fa92b12f57c9f70c 100644 (file)
@@ -33,12 +33,15 @@ struct ggml_metal_buffer {
 struct ggml_metal_context {
     int n_cb;
 
-    float * logits;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
+    id<MTLCommandBuffer>         command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
+    id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
+
+    dispatch_queue_t d_queue;
+
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
@@ -63,6 +66,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
     GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +77,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +86,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -111,12 +117,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
-    ctx->n_cb   = n_cb;
+    ctx->n_cb   = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
+    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
 #if 0
     // compile from source string and show compile log
@@ -167,7 +174,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+                (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
             fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
@@ -186,6 +195,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
         GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -196,6 +206,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -203,6 +214,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -218,12 +230,12 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    fprintf(stderr, "%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    fprintf(stderr, "%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate              = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        fprintf(stderr, "%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate              = built-in GPU\n", __func__);
+        fprintf(stderr, "%s: maxTransferRate               = built-in GPU\n", __func__);
     }
 
     return ctx;
@@ -231,9 +243,67 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
     fprintf(stderr, "%s: deallocating\n", __func__);
+#define GGML_METAL_DEL_KERNEL(name) \
+    [ctx->function_##name release]; \
+    [ctx->pipeline_##name release];
+
+    GGML_METAL_DEL_KERNEL(add);
+    GGML_METAL_DEL_KERNEL(add_row);
+    GGML_METAL_DEL_KERNEL(mul);
+    GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(relu);
+    GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(soft_max);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(get_rows_f16);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
+    GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(rope);
+    GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+
+#undef GGML_METAL_DEL_KERNEL
+
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
+
+    [ctx->library release];
+    [ctx->queue release];
+    [ctx->device release];
+
+    dispatch_release(ctx->d_queue);
+
     free(ctx);
 }
 
@@ -253,7 +323,7 @@ void ggml_metal_host_free(void * data) {
 }
 
 void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
-    ctx->n_cb = n_cb;
+    ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }
 
 int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
@@ -499,6 +569,8 @@ void ggml_metal_graph_compute(
                struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    @autoreleasepool {
+
     // if there is ctx->concur_list, dispatch concurrently
     // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
@@ -513,32 +585,28 @@ void ggml_metal_graph_compute(
 
     const int n_cb = ctx->n_cb;
 
-    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
-
     for (int i = 0; i < n_cb; ++i) {
-        command_buffers[i] = [ctx->queue commandBuffer];
+        ctx->command_buffers[i] = [ctx->queue commandBuffer];
 
         // enqueue the command buffers in order to specify their execution order
-        [command_buffers[i] enqueue];
-    }
+        [ctx->command_buffers[i] enqueue];
 
-    // TODO: is this the best way to start threads?
-    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+    }
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
-        dispatch_async(queue, ^{
+        dispatch_async(ctx->d_queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src1 = 0;
             size_t offs_dst  = 0;
 
-            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
-            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+            id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
+            id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
 
-            const int node_start =                                  (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
 
             for (int ind = node_start; ind < node_end; ++ind) {
                 const int i = has_concur ? ctx->concur_list[ind] : ind;
@@ -744,32 +812,32 @@ void ggml_metal_graph_compute(
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
                                 ne11 > 1) {
-                                    switch (src0->type) {
-                                        case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
-                                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
-                                        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
-                                        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
-                                        case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
-                                        case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
-                                        case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
-                                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
-                                        default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
-                                    }
-                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
-                                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
-                                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
-                                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
-                                    [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
-                                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                switch (src0->type) {
+                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
+                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
+                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+                                    default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                                 }
-                            else {
+                                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
+                                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
+                                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
+                                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:8];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:9];
+                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:10];
+                                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
 
@@ -799,6 +867,15 @@ void ggml_metal_graph_compute(
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
                                         } break;
+                                    case GGML_TYPE_Q8_0:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                                        } break;
                                     case GGML_TYPE_Q2_K:
                                         {
                                             GGML_ASSERT(ne02 == 1);
@@ -868,24 +945,24 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                                [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
-                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -895,9 +972,10 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                         {
                             switch (src0->type) {
-                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                                case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
                                 case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -938,16 +1016,17 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_NORM:
                         {
-                            const float eps = 1e-5f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 256;
 
                             [encoder setComputePipelineState:ctx->pipeline_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
                             [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
@@ -990,7 +1069,9 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
                             [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
+
                             const int nth = 32;
+
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ROPE:
@@ -1005,8 +1086,8 @@ void ggml_metal_graph_compute(
                             memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
                             [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
@@ -1057,24 +1138,24 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1096,17 +1177,19 @@ void ggml_metal_graph_compute(
     }
 
     // wait for all threads to finish
-    dispatch_barrier_sync(queue, ^{});
-
-    [command_buffers[n_cb - 1] waitUntilCompleted];
+    dispatch_barrier_sync(ctx->d_queue, ^{});
 
     // check status of command buffers
     // needed to detect if the device ran out-of-memory for example (#1881)
     for (int i = 0; i < n_cb; i++) {
-        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
+        [ctx->command_buffers[i] waitUntilCompleted];
+
+        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
             fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
     }
+
+    }
 }
index ce3541f4bb55f6a06178975ad6b0fa51a2435906..82e1a0c7aca06e124898dc1c8caa7064bafef261 100644 (file)
@@ -18,6 +18,12 @@ typedef struct {
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 
+#define QK8_0 32
+typedef struct {
+    half    d;         // delta
+    int8_t  qs[QK8_0]; // quants
+} block_q8_0;
+
 kernel void kernel_add(
         device const float * src0,
         device const float * src1,
@@ -87,7 +93,12 @@ kernel void kernel_gelu(
     device       float * dst,
     uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
-    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 kernel void kernel_soft_max(
@@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     const int first_row = (r0 * nsg + sgitg) * nr;
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16];       // src1 vector cache
     float sumf[nr]={0.f};
 
@@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
      mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mat_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += QK8_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
         device const  char * src0,
         device const  char * src1,
@@ -475,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
-
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
@@ -571,7 +643,25 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-        // TODO: implement
+        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                const float cos_theta = cos(theta);
+                const float sin_theta = sin(theta);
+
+                theta *= theta_scale;
+
+                const int64_t i0 = ib*n_dims + ic/2;
+
+                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                const float x0 = src[0];
+                const float x1 = src[n_dims/2];
+
+                dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            }
+        }
     }
 }
 
@@ -1598,12 +1688,12 @@ template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? (-8.h * 16.h) : -8.h;
+    const half m = il ? ( -8.h * 16.h) : -8.h;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
+        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
     }
 }
@@ -1617,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
+        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
     }
 }
 
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i=0;i<16;i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const half d = xb->d;
@@ -1924,9 +2024,10 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1937,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
                              constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
                              constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
 
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
index 97232219be5f59c6f00c3c69beb31caf833fd357..dadb30757a962c7d5db44a0d854403f7dd5b8c21 100644 (file)
@@ -2430,7 +2430,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     const int nb = n / qk;
 
     assert(n % qk == 0);
-    assert(nb % 2 == 0);
 
     const block_q4_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
@@ -2439,6 +2438,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
@@ -2617,6 +2617,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     // Main loop
+    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
     for (int i = 2; i < nb; i+=2) {
         _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
         _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
@@ -2700,7 +2701,6 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
     const int nb = n / qk;
 
     assert(n % qk == 0);
-    assert(nb % 2 == 0);
 
     const block_q4_1 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
@@ -2712,6 +2712,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
 
     float summs = 0;
 
+    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
@@ -2826,7 +2827,6 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     const int nb = n / qk;
 
     assert(n % qk == 0);
-    assert(nb % 2 == 0);
     assert(qk == QK5_0);
 
     const block_q5_0 * restrict x = vx;
@@ -2842,6 +2842,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     uint64_t tmp0[4];
     uint64_t tmp1[4];
 
+    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
     for (int i = 0; i < nb; i += 2) {
         const block_q5_0 * restrict x0 = &x[i];
         const block_q5_0 * restrict x1 = &x[i + 1];
@@ -3066,7 +3067,6 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
     const int nb = n / qk;
 
     assert(n % qk == 0);
-    assert(nb % 2 == 0);
     assert(qk == QK5_1);
 
     const block_q5_1 * restrict x = vx;
@@ -3085,6 +3085,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
     uint64_t tmp0[4];
     uint64_t tmp1[4];
 
+    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
     for (int i = 0; i < nb; i += 2) {
         const block_q5_1 * restrict x0 = &x[i];
         const block_q5_1 * restrict x1 = &x[i + 1];
@@ -3322,7 +3323,6 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     const int nb = n / qk;
 
     assert(n % qk == 0);
-    assert(nb % 2 == 0);
 
     const block_q8_0 * restrict x = vx;
     const block_q8_0 * restrict y = vy;
@@ -3331,6 +3331,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
+    GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
     for (int i = 0; i < nb; i += 2) {
         const block_q8_0 * restrict x0 = &x[i + 0];
         const block_q8_0 * restrict x1 = &x[i + 1];
@@ -3548,9 +3549,9 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
-static const float GELU_COEF_A    = 0.044715f;
-static const float GELU_QUICK_COEF    = -1.702f;
-static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+static const float GELU_COEF_A     = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
+static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -5549,10 +5550,6 @@ struct ggml_tensor * ggml_repeat(
         is_node = true;
     }
 
-    if (ggml_are_same_shape(a, b) && !is_node) {
-        return a;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
 
     result->op   = GGML_OP_REPEAT;
@@ -12537,7 +12534,7 @@ static void ggml_compute_forward_rope_f32(
                         dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                     }
                 } else {
-                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
@@ -12666,7 +12663,7 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
                 } else {
-                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
@@ -16352,10 +16349,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     float freq_scale;
                     float xpos_base;
                     bool  xpos_down;
-                                       memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
-                                       memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
-                                       memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
-                                       memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
+                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
+                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
+                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
 
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
@@ -16383,10 +16380,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     float freq_scale;
                     float xpos_base;
                     bool  xpos_down;
-                                       memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
-                                       memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
-                                       memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
-                                       memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
+                    memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
+                    memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+                    memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
+                    memcpy(&xpos_down,  (int32_t *) tensor->op_params + 7, sizeof(bool));
 
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
@@ -19392,7 +19389,7 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
 ////////////////////////////////////////////////////////////////////////////////
 
 struct gguf_str {
-    uint32_t n;
+    uint64_t n;  // GGUFv2
     char * data;
 };
 
@@ -19406,9 +19403,12 @@ static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
     [GGUF_TYPE_FLOAT32] = sizeof(float),
     [GGUF_TYPE_BOOL]    = sizeof(bool),
     [GGUF_TYPE_STRING]  = sizeof(struct gguf_str),
+    [GGUF_TYPE_UINT64]  = sizeof(uint64_t),
+    [GGUF_TYPE_INT64]   = sizeof(int64_t),
+    [GGUF_TYPE_FLOAT64] = sizeof(double),
     [GGUF_TYPE_ARRAY]   = 0, // undefined
 };
-static_assert(GGUF_TYPE_COUNT == 10, "GGUF_TYPE_COUNT != 10");
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
 
 static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
     [GGUF_TYPE_UINT8]   = "u8",
@@ -19421,8 +19421,11 @@ static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
     [GGUF_TYPE_BOOL]    = "bool",
     [GGUF_TYPE_STRING]  = "str",
     [GGUF_TYPE_ARRAY]   = "arr",
+    [GGUF_TYPE_UINT64]  = "u64",
+    [GGUF_TYPE_INT64]   = "i64",
+    [GGUF_TYPE_FLOAT64] = "f64",
 };
-static_assert(GGUF_TYPE_COUNT == 10, "GGUF_TYPE_COUNT != 10");
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
 
 union gguf_value {
     uint8_t  uint8;
@@ -19432,6 +19435,9 @@ union gguf_value {
     uint32_t uint32;
     int32_t  int32;
     float    float32;
+    uint64_t uint64;
+    int64_t  int64;
+    double   float64;
     bool     bool_;
 
     struct gguf_str str;
@@ -19439,7 +19445,7 @@ union gguf_value {
     struct {
         enum gguf_type type;
 
-        uint32_t n;
+        uint64_t n;  // GGUFv2
         void * data;
     } arr;
 };
@@ -19447,8 +19453,6 @@ union gguf_value {
 struct gguf_kv {
     struct gguf_str key;
 
-    uint32_t n_bytes; // TODO: is this actually needed?
-
     enum  gguf_type  type;
     union gguf_value value;
 };
@@ -19456,15 +19460,15 @@ struct gguf_kv {
 struct gguf_header {
     uint32_t magic;
     uint32_t version;
-    uint32_t n_tensors;
-    uint32_t n_kv;
+    uint64_t n_tensors; // GGUFv2
+    uint64_t n_kv;      // GGUFv2
 };
 
 struct gguf_tensor_info {
     struct gguf_str name;
 
     uint32_t n_dims;
-    uint32_t ne[GGML_MAX_DIMS];
+    uint64_t ne[GGML_MAX_DIMS];
 
     enum ggml_type type;
 
@@ -19495,19 +19499,32 @@ static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset)
     return n == size;
 }
 
-static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
+// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
+static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset) {
     p->n    = 0;
     p->data = NULL;
 
     bool ok = true;
 
-    // TODO: how to avoid mallocs for strings?
     ok = ok && gguf_fread_el(file, &p->n,    sizeof(p->n), offset); p->data = calloc(p->n + 1, 1);
     ok = ok && gguf_fread_el(file,  p->data, p->n,         offset);
 
     return ok;
 }
 
+static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) {
+    p->n    = 0;
+    p->data = NULL;
+
+    bool ok = true;
+
+    uint32_t n = 0;
+    ok = ok && gguf_fread_el(file, &n,       sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n;
+    ok = ok && gguf_fread_el(file,  p->data, p->n,      offset);
+
+    return ok;
+}
+
 struct gguf_context * gguf_init_empty(void) {
     struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
 
@@ -19563,8 +19580,21 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
         ctx->data  = NULL;
 
         ok = ok && gguf_fread_el(file, &ctx->header.version,   sizeof(ctx->header.version),   &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
+
+        if (ctx->header.version == 1) {
+            // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
+            uint32_t n_tensors = 0;
+            uint32_t n_kv      = 0;
+
+            ok = ok && gguf_fread_el(file, &n_tensors, sizeof(n_tensors), &offset);
+            ok = ok && gguf_fread_el(file, &n_kv,      sizeof(n_kv),      &offset);
+
+            ctx->header.n_tensors = n_tensors;
+            ctx->header.n_kv      = n_kv;
+        } else {
+            ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
+            ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
+        }
 
         if (!ok) {
             fprintf(stderr, "%s: failed to read header\n", __func__);
@@ -19574,6 +19604,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
         }
     }
 
+    // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
+    bool (* gguf_fread_str)(FILE *, struct gguf_str *, size_t *) = gguf_fread_str_cur;
+    if (ctx->header.version == 1) {
+        gguf_fread_str = gguf_fread_str_v1;
+    }
+
     // read the kv pairs
     {
         ctx->kv = GGML_ALIGNED_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv));
@@ -19583,9 +19619,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
             //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
 
-            ok = ok && gguf_fread_str(file, &kv->key,                          &offset);
-          //ok = ok && gguf_fread_el (file, &kv->n_bytes, sizeof(kv->n_bytes), &offset);
-            ok = ok && gguf_fread_el (file, &kv->type,    sizeof(kv->type),    &offset);
+            ok = ok && gguf_fread_str(file, &kv->key,                    &offset);
+            ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
 
             //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
 
@@ -19597,12 +19632,23 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                 case GGUF_TYPE_UINT32:  ok = ok && gguf_fread_el (file, &kv->value.uint32,  sizeof(kv->value.uint32),  &offset); break;
                 case GGUF_TYPE_INT32:   ok = ok && gguf_fread_el (file, &kv->value.int32,   sizeof(kv->value.int32),   &offset); break;
                 case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
+                case GGUF_TYPE_UINT64:  ok = ok && gguf_fread_el (file, &kv->value.uint64,  sizeof(kv->value.uint64),  &offset); break;
+                case GGUF_TYPE_INT64:   ok = ok && gguf_fread_el (file, &kv->value.int64,   sizeof(kv->value.int64),   &offset); break;
+                case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
                 case GGUF_TYPE_BOOL:    ok = ok && gguf_fread_el (file, &kv->value.bool_,   sizeof(kv->value.bool_),   &offset); break;
                 case GGUF_TYPE_STRING:  ok = ok && gguf_fread_str(file, &kv->value.str,                                &offset); break;
                 case GGUF_TYPE_ARRAY:
                     {
                         ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.n,    sizeof(kv->value.arr.n),    &offset);
+
+                        if (ctx->header.version == 1) {
+                            // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
+                            uint32_t n = 0;
+                            ok = ok && gguf_fread_el(file, &n, sizeof(n), &offset);
+                            kv->value.arr.n = n;
+                        } else {
+                            ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
+                        }
 
                         switch (kv->value.arr.type) {
                             case GGUF_TYPE_UINT8:
@@ -19612,6 +19658,9 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                             case GGUF_TYPE_UINT32:
                             case GGUF_TYPE_INT32:
                             case GGUF_TYPE_FLOAT32:
+                            case GGUF_TYPE_UINT64:
+                            case GGUF_TYPE_INT64:
+                            case GGUF_TYPE_FLOAT64:
                             case GGUF_TYPE_BOOL:
                                 {
                                     kv->value.arr.data = malloc(kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]);
@@ -19658,7 +19707,14 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
             ok = ok && gguf_fread_str(file, &info->name,                          &offset);
             ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims),  &offset);
             for (uint32_t j = 0; j < info->n_dims; ++j) {
-                ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
+                if (ctx->header.version == 1) {
+                    // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
+                    uint32_t t = 0;
+                    ok = ok && gguf_fread_el(file, &t, sizeof(t), &offset);
+                    info->ne[j] = t;
+                } else {
+                    ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
+                }
             }
             ok = ok && gguf_fread_el (file, &info->type,   sizeof(info->type),    &offset);
             ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset),  &offset);
@@ -19952,6 +20008,18 @@ float gguf_get_val_f32(struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.float32;
 }
 
+uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) {
+    return ctx->kv[i].value.uint64;
+}
+
+int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) {
+    return ctx->kv[i].value.int64;
+}
+
+double gguf_get_val_f64(struct gguf_context * ctx, int i) {
+    return ctx->kv[i].value.float64;
+}
+
 bool gguf_get_val_bool(struct gguf_context * ctx, int i) {
     return ctx->kv[i].value.bool_;
 }
@@ -19998,7 +20066,7 @@ static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
     const int n_kv = gguf_get_n_kv(ctx);
 
     ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
-    ctx->kv[n_kv].key.n    = strlen(key) + 1;
+    ctx->kv[n_kv].key.n    = strlen(key);
     ctx->kv[n_kv].key.data = strdup(key);
     ctx->header.n_kv++;
 
@@ -20054,6 +20122,27 @@ void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
     ctx->kv[idx].value.float32 = val;
 }
 
+void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type         = GGUF_TYPE_UINT64;
+    ctx->kv[idx].value.uint64 = val;
+}
+
+void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type        = GGUF_TYPE_INT64;
+    ctx->kv[idx].value.int64 = val;
+}
+
+void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type          = GGUF_TYPE_FLOAT64;
+    ctx->kv[idx].value.float64 = val;
+}
+
 void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
     const int idx = gguf_get_or_add_key(ctx, key);
 
@@ -20065,7 +20154,7 @@ void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char *
     const int idx = gguf_get_or_add_key(ctx, key);
 
     ctx->kv[idx].type           = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.str.n    = strlen(val) + 1;
+    ctx->kv[idx].value.str.n    = strlen(val);
     ctx->kv[idx].value.str.data = strdup(val);
 }
 
@@ -20088,7 +20177,7 @@ void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char **
     ctx->kv[idx].value.arr.data = malloc(n*sizeof(struct gguf_str));
     for (int i = 0; i < n; i++) {
         struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
-        str->n    = strlen(data[i]) + 1;
+        str->n    = strlen(data[i]);
         str->data = strdup(data[i]);
     }
 }
@@ -20104,6 +20193,9 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
             case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32);   break;
             case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32);    break;
             case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32);  break;
+            case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64);   break;
+            case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64);    break;
+            case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64);  break;
             case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_);    break;
             case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
             case GGUF_TYPE_ARRAY:
@@ -20132,7 +20224,7 @@ void gguf_add_tensor(
     const int idx = ctx->header.n_tensors;
     ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
 
-    ctx->infos[idx].name.n    = strlen(tensor->name) + 1;
+    ctx->infos[idx].name.n    = strlen(tensor->name);
     ctx->infos[idx].name.data = strdup(tensor->name);
 
     for (int i = 0; i < GGML_MAX_DIMS; ++i) {
@@ -20265,6 +20357,9 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
             case GGUF_TYPE_UINT32:  gguf_bwrite_el (buf, &kv->value.uint32,  sizeof(kv->value.uint32) ); break;
             case GGUF_TYPE_INT32:   gguf_bwrite_el (buf, &kv->value.int32,   sizeof(kv->value.int32)  ); break;
             case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
+            case GGUF_TYPE_UINT64:  gguf_bwrite_el (buf, &kv->value.uint64,  sizeof(kv->value.uint64) ); break;
+            case GGUF_TYPE_INT64:   gguf_bwrite_el (buf, &kv->value.int64,   sizeof(kv->value.int64)  ); break;
+            case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
             case GGUF_TYPE_BOOL:    gguf_bwrite_el (buf, &kv->value.bool_,   sizeof(kv->value.bool_)  ); break;
             case GGUF_TYPE_STRING:  gguf_bwrite_str(buf, &kv->value.str                               ); break;
             case GGUF_TYPE_ARRAY:
@@ -20280,6 +20375,9 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf,
                         case GGUF_TYPE_UINT32:
                         case GGUF_TYPE_INT32:
                         case GGUF_TYPE_FLOAT32:
+                        case GGUF_TYPE_UINT64:
+                        case GGUF_TYPE_INT64:
+                        case GGUF_TYPE_FLOAT64:
                         case GGUF_TYPE_BOOL:
                             {
                                 gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]);
@@ -20514,6 +20612,14 @@ int ggml_cpu_has_sse3(void) {
 #endif
 }
 
+int ggml_cpu_has_ssse3(void) {
+#if defined(__SSSE3__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_vsx(void) {
 #if defined(__POWER9_VECTOR__)
     return 1;