]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : backport llama.cpp updates (close #709)
authorGeorgi Gerganov <redacted>
Mon, 10 Apr 2023 19:28:54 +0000 (22:28 +0300)
committerGeorgi Gerganov <redacted>
Mon, 10 Apr 2023 19:28:54 +0000 (22:28 +0300)
- About x2 overall performance improvement on Apple Silicon
- Results should now be the same for different number of threads (not
  tested)

ggml.c
ggml.h
whisper.cpp

diff --git a/ggml.c b/ggml.c
index ba0441940f8076912f3cb292c7873dbef76eed29..3942379ecdd2685436a9a84526fed7d72dd44fa7 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -16,6 +16,7 @@
 #include <stdlib.h>
 #include <string.h>
 #include <stdint.h>
+#include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
 
@@ -79,20 +80,22 @@ static int sched_yield (void) {
 typedef void* thread_ret_t;
 #endif
 
-#ifdef __HAIKU__
-#define static_assert(cond, msg) _Static_assert(cond, msg)
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#ifndef __SSE3__
+#define __SSE3__
 #endif
-
-#define GGML_MLOCK_SUPPORT 0
-
-#ifdef __has_include
-    #if __has_include(<sys/mman.h>)
-        #undef GGML_MLOCK_SUPPORT
-        #define GGML_MLOCK_SUPPORT 1
-        #include <sys/mman.h>
-    #endif
 #endif
 
+#ifdef __HAIKU__
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#endif
 
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
@@ -150,10 +153,10 @@ typedef double ggml_float;
 //
 #include <arm_neon.h>
 
-#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
+#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
 #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
 
-#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP16_TO_FP32(x) ((float) (x))
 #define GGML_FP32_TO_FP16(x) (x)
 
 #else
@@ -172,8 +175,13 @@ typedef double ggml_float;
 
 #ifdef __F16C__
 
+#ifdef _MSC_VER
+#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+#else
 #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
 #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#endif
 
 #elif defined(__POWER9_VECTOR__)
 
@@ -322,7 +330,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 // note: do not use these inside ggml.c
 // these are meant to be used via the ggml.h API
 float ggml_fp16_to_fp32(ggml_fp16_t x) {
-    return GGML_FP16_TO_FP32(x);
+    return (float) GGML_FP16_TO_FP32(x);
 }
 
 ggml_fp16_t ggml_fp32_to_fp16(float x) {
@@ -443,22 +451,65 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
 }
+#elif __AVX__
+static inline __m128i bytesFromNibbles( const uint8_t* rsi )
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
+
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m128i lowByte = _mm_set1_epi16( 0xFF );
+    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+    __m128i low = _mm_and_si128( lowByte, bytes1 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes1 = _mm_or_si128( low, high );
+    high = _mm_andnot_si128( lowByte, bytes2 );
+    low = _mm_and_si128( lowByte, bytes2 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes2 = _mm_or_si128( low, high );
+
+    return _mm_packus_epi16( bytes1, bytes2);
+}
 #endif
 
 // method 5
 // blocks of QK elements
 // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
+typedef struct {
+    float   d; // delta
+    uint8_t qs[QK / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
+
+// method 4
+// blocks of QK elements
+// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+typedef struct {
+    float   d;
+    float   m;
+    uint8_t qs[QK / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
 
 // reference implementation for deterministic creation of model files
-static void quantize_row_q4_0_reference(const float * restrict x, void * restrict y, int k) {
+static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
     assert(k % QK == 0);
     const int nb = k / QK;
 
-    const size_t bs = sizeof(float) + QK/2;
-
-    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
-    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));
-
     uint8_t pp[QK/2];
 
     for (int i = 0; i < nb; i++) {
@@ -472,39 +523,30 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric
         const float d = amax / ((1 << 3) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        *(float *)pd = d;
-        pd += bs;
+        y[i].d = d;
 
         for (int l = 0; l < QK; l += 2) {
             const float v0 = x[i*QK + l + 0]*id;
             const float v1 = x[i*QK + l + 1]*id;
 
-            const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
-            const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
+            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
+            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
 
-            assert(vi0 >= 0 && vi0 < 16);
-            assert(vi1 >= 0 && vi1 < 16);
+            assert(vi0 < 16);
+            assert(vi1 < 16);
 
             pp[l/2] = vi0 | (vi1 << 4);
         }
 
-        memcpy(pb, pp, sizeof(pp));
-        pb += bs;
+        memcpy(y[i].qs, pp, sizeof(pp));
     }
 }
 
-void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK == 0);
-
-#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
     const int nb = k / QK;
-    const size_t bs = sizeof(float) + QK/2;
 
-    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
-    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));
-
-    uint8_t pp[QK/2];
-#endif
+    block_q4_0 * restrict y = vy;
 
 #if defined(__POWER9_VECTOR__)
     const vector float v85 = vec_splats(8.5f);
@@ -532,10 +574,10 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         const float d = amax / ((1 << 3) - 1);
         const float id = d ? 1.0/d : 0.0;
 
-        *(float *)pd = d;
-        pd += bs;
+        y[i].d = d;
 
         const vector float vid = vec_splats(id);
+        uint8_t * restrict pb = y[i].qs;
         for (int l = 0; l < 8; l++) {
             const vector float vf  = vec_madd(srcv[l], vid, v85);
             const vector signed int vi = vec_signed(vf);
@@ -543,14 +585,9 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
             pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
             pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
         }
-
-        //memcpy(pb, pp, sizeof(pp));
-        pb += bs;
     }
 #elif __ARM_NEON
     for (int i = 0; i < nb; i++) {
-        float amax = 0.0f; // absolute max
-
         float32x4_t srcv [8];
         float32x4_t asrcv[8];
         float32x4_t amaxv[8];
@@ -562,27 +599,21 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
         for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
 
-        amax = MAX(
-                MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
-                MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+        const float amax = vmaxvq_f32(amaxv[0]);
 
         const float d = amax / ((1 << 3) - 1);
-        const float id = d ? 1.0/d : 0.0;
+        const float id = d ? 1.0f/d : 0.0f;
 
-        *(float *)pd = d;
-        pd += bs;
+        y[i].d = d;
 
         for (int l = 0; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
             const int32x4_t   vi = vcvtq_s32_f32(vf);
 
-            pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
-            pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
         }
-
-        memcpy(pb, pp, sizeof(pp));
-        pb += bs;
     }
 #elif defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
@@ -607,8 +638,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
 
         // Quantize these floats
         const float d = maxScalar / 7.0f;
-        *(float *)pd = d;
-        pd += bs;
+        y[i].d = d;
         const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
         const __m256 mul = _mm256_set1_ps( id );
 
@@ -648,8 +678,81 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
 
         // Compress the vector into 4 bit/value, and store
         __m128i res = packNibbles( i0 );
-        _mm_storeu_si128( ( __m128i* )pb, res );
-        pb += bs;
+        _mm_storeu_si128( ( __m128i* )y[i].qs, res );
+    }
+#elif defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 7.0f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
+        const __m128i off = _mm_set1_epi8( 8);
+        ni0 = _mm_add_epi8( ni0, off );
+        ni4 = _mm_add_epi8( ni4, off );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( ni0, ni4 );
+        _mm_storeu_si128( ( __m128i* )y[i].qs, res );
     }
 #elif defined(__wasm_simd128__)
     for (int i = 0; i < nb; i++) {
@@ -673,20 +776,16 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
         const float d = amax / ((1 << 3) - 1);
         const float id = d ? 1.0/d : 0.0;
 
-        *(float *)pd = d;
-        pd += bs;
+        y[i].d = d;
 
         for (int l = 0; l < 8; l++) {
             const v128_t v  = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
             const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
             const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
 
-            pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
-            pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+            y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
+            y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
         }
-
-        memcpy(pb, pp, sizeof(pp));
-        pb += bs;
     }
 #else
     // scalar
@@ -694,18 +793,11 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
 #endif
 }
 
-// method 4
-// blocks of QK elements
-// represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
-void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK == 0);
-
     const int nb = k / QK;
-    const size_t bs = 2*sizeof(float) + QK/2;
 
-    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
-    uint8_t * restrict pm = ((uint8_t *)y + 0*bs +   sizeof(float));
-    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
+    block_q4_1 * restrict y = vy;
 
     uint8_t pp[QK/2];
 
@@ -722,45 +814,161 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
         const float d = (max - min) / ((1 << 4) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        *(float *)pm = min;
-        *(float *)pd = d;
-        pm += bs;
-        pd += bs;
+        y[i].d = d;
+        y[i].m = min;
 
         for (int l = 0; l < QK; l += 2) {
             const float v0 = (x[i*QK + l + 0] - min)*id;
             const float v1 = (x[i*QK + l + 1] - min)*id;
 
-            const uint8_t vi0 = round(v0);
-            const uint8_t vi1 = round(v1);
+            const uint8_t vi0 = roundf(v0);
+            const uint8_t vi1 = roundf(v1);
 
-            assert(vi0 >= 0 && vi0 < 16);
-            assert(vi1 >= 0 && vi1 < 16);
+            assert(vi0 < 16);
+            assert(vi1 < 16);
 
             pp[l/2] = vi0 | (vi1 << 4);
         }
 
-        memcpy(pb, pp, sizeof(pp));
-        pb += bs;
+        memcpy(y[i].qs, pp, sizeof(pp));
     }
 }
 
-// TODO: vectorize
-void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
+static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK == 0);
 
     const int nb = k / QK;
-    const size_t bs = sizeof(float) + QK/2;
 
-    const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
-    const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
+    block_q4_1 * restrict y = vy;
+
+#if defined(__AVX2__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max for the block
+        __m256 vmax;
+        vmax = _mm256_max_ps( v0, v1 );
+        vmax = _mm256_max_ps( vmax, v2 );
+        vmax = _mm256_max_ps( vmax, v3 );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Compute min for the block
+        __m256 vmin;
+        vmin = _mm256_min_ps( v0, v1 );
+        vmin = _mm256_min_ps( vmin, v2 );
+        vmin = _mm256_min_ps( vmin, v3 );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
+        // Quantize these floats
+        const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].m = minScalar;
+        y[i].d = d;
+
+        // x = (x-min)*id
+        const __m256 mul = _mm256_set1_ps( id );
+        const __m256 off = _mm256_set1_ps( minScalar );
+        v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
+        v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
+        v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
+        v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( i0 );
+        _mm_storeu_si128( ( __m128i* )y[i].qs, res );
+    }
+#elif __ARM_NEON
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv[8];
+        float32x4_t minv[8];
+        float32x4_t maxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
+
+        for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
+        for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
+        for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]);
+
+        for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
+        for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
+        for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]);
+
+        const float min = vminvq_f32(minv[0]);
+        const float max = vmaxvq_f32(maxv[0]);
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+        y[i].m = min;
+
+        const float32x4_t minv0 = vdupq_n_f32(min);
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
+            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
+            const int32x4_t   vi = vcvtq_s32_f32(vf);
+
+            y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+            y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+        }
+    }
+#else
+    // scalar
+    quantize_row_q4_1_reference(x, vy, k);
+#endif
+}
+
+static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
+
+    const block_q4_0 * restrict x = vx;
 
 #if defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
         // scale factor
-        const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
+        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
 
-        const uint8_t * restrict pp = pb + i*bs;
+        const uint8_t * restrict pp = x[i].qs;
 
         for (int l = 0; l < QK; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
@@ -790,17 +998,15 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
     }
 #elif defined(__ARM_NEON)
     for (int i = 0; i < nb; i++) {
-        const float d = *(const float *) (pd + i*bs);
+        const float32x4_t vd = vdupq_n_f32(x[i].d);
 
-        const uint8_t * restrict pp = pb + i*bs;
-
-        const float32x4_t vd = vdupq_n_f32(d);
+        const uint8_t * restrict pp = x[i].qs;
 
         for (int l = 0; l < QK; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
-            // Expand 4-bit nibbles to 8-bit bytes
+            // Expand 4-bit qs to 8-bit bytes
             const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
             const uint8x8_t v1 = vshr_n_u8(v8, 4);
 
@@ -844,9 +1050,9 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d = *(const float *) (pd + i*bs);
+        const float d = x[i].d;
 
-        const uint8_t * restrict pp = pb + i*bs;
+        const uint8_t * restrict pp = x[i].qs;
 
         for (int l = 0; l < QK; l += 2) {
             const uint8_t vi = pp[l/2];
@@ -869,22 +1075,18 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
 #endif
 }
 
-void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
+static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
     assert(k % QK == 0);
-
     const int nb = k / QK;
-    const size_t bs = 2*sizeof(float) + QK/2;
 
-    const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
-    const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
-    const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
+    const block_q4_1 * restrict x = vx;
 
 #if defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
-        const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
-        const __m256 d_m = _mm256_broadcast_ss((const float *) (pm + i*bs));
+        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
+        const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
 
-        const uint8_t * restrict pp = pb + i*bs;
+        const uint8_t * restrict pp = x[i].qs;
 
         for (int l = 0; l < QK; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
@@ -909,12 +1111,56 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
             }
         }
     }
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float32x4_t vd = vdupq_n_f32(x[i].d);
+        const float32x4_t vm = vdupq_n_f32(x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit qs to 8-bit bytes
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Interleave and combine
+            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
+            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
+
+            const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
+
+            // convert to 2x uint16x8_t
+            const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
+            const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
+
+            // multiply by d and add m
+            const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
+            const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
+            const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
+            const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l +  0, r0);
+            vst1q_f32(y + i*QK + l +  4, r1);
+            vst1q_f32(y + i*QK + l +  8, r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+    }
 #else
     for (int i = 0; i < nb; i++) {
-        const float d = *(const float *) (pd + i*bs);
-        const float m = *(const float *) (pm + i*bs);
+        const float d = x[i].d;
+        const float m = x[i].m;
 
-        const uint8_t * restrict pp = pb + i*bs;
+        const uint8_t * restrict pp = x[i].qs;
 
         for (int l = 0; l < QK; l += 2) {
             const uint8_t vi = pp[l/2];
@@ -1027,7 +1273,7 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
         }                                                         \
         const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
         const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
-        res = vaddvq_f32(vaddq_f32(t0, t1));                      \
+        res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
     }
 
     #define GGML_F16_VEC                GGML_F16x8
@@ -1122,13 +1368,36 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
 #define GGML_F16_EPR  8
 
 // F16 arithmetic is not supported by AVX, so we use F32 instead
-// we take advantage of the _mm256_cvt intrinsics to convert F16 <-> F32
 
 #define GGML_F32Cx8             __m256
 #define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
 #define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
+
+#if defined(__F16C__)
+// the  _mm256_cvt intrinsics require F16C
 #define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
 #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#else
+static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++)
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
+
+    _mm256_storeu_ps(arr, y);
+
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
+#endif
+
 #define GGML_F32Cx8_FMA         GGML_F32x8_FMA
 #define GGML_F32Cx8_ADD         _mm256_add_ps
 #define GGML_F32Cx8_MUL         _mm256_mul_ps
@@ -1440,9 +1709,8 @@ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, co
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
 inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
-    ggml_float sumf = 0.0;
-
 #ifdef GGML_SIMD
+    float sumf = 0.0f;
     const int np = (n & ~(GGML_F32_STEP - 1));
 
     GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -1468,8 +1736,9 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     }
 #else
     // scalar
+    ggml_float sumf = 0.0;
     for (int i = 0; i < n; ++i) {
-        sumf += x[i]*y[i];
+        sumf += (ggml_float)(x[i]*y[i]);
     }
 #endif
 
@@ -1479,25 +1748,15 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
 #if __AVX512F__ && QK == 32
 static inline __m512 dot_q4_0_oneblock_avx512(
     __m512 acc,
-    const uint8_t * pd0,
-    const uint8_t * pd1,
-    const uint8_t * pb0,
-    const uint8_t * pb1,
-    size_t bs,
+    const block_q4_0 * restrict x,
+    const block_q4_0 * restrict y,
     int i
 ) {
-    const float * d0_0 = (const float *) (pd0 + i*bs);
-    const float * d1_0 = (const float *) (pd1 + i*bs);
-
-    const uint8_t * restrict p0 = pb0 + (i+0)*bs;
-    const uint8_t * restrict p1 = pb1 + (i+0)*bs;
-
     // Compute combined scale for the block
-    float scaleScalar = d0_0[0] * d1_0[0];
-    __m512 scale = _mm512_set1_ps( scaleScalar );
+    __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
 
-    __m256i bx = bytesFromNibbles( p0 );
-    __m256i by = bytesFromNibbles( p1 );
+    __m256i bx = bytesFromNibbles( x[i].qs );
+    __m256i by = bytesFromNibbles( y[i].qs );
 
     // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
     const __m256i off = _mm256_set1_epi8( 8 );
@@ -1513,7 +1772,7 @@ static inline __m512 dot_q4_0_oneblock_avx512(
     // Convert int32_t to float
     __m512 p = _mm512_cvtepi32_ps( i64 );
     // Apply the scale, and accumulate
-    return _mm512_fmadd_ps( scale, p, acc );
+    return _mm512_fmadd_ps( d, p, acc );
 }
 #endif
 
@@ -1542,30 +1801,25 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     // leftovers
     for (int i = np; i < n; ++i) {
-        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
 #else
     for (int i = 0; i < n; ++i) {
-        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
 #endif
 
     *s = sumf;
 }
 
-inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK;
 
     assert(n % QK == 0);
     assert(nb % 2 == 0);
 
-    const size_t bs = sizeof(float) + QK/2;
-
-    const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
-    const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
-
-    const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + sizeof(float));
-    const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + sizeof(float));
+    const block_q4_0 * restrict x = vx;
+    const block_q4_0 * restrict y = vy;
 
     float sumf = 0.0;
 
@@ -1574,23 +1828,18 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     float sum1 = 0.0f;
 
     for (int i = 0; i < nb; i += 2) {
-        const float d0_0 = *(const float *) (pd0 + i*bs);
-        const float d1_0 = *(const float *) (pd1 + i*bs);
-        const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
-        const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
-
-        //printf("d0_0: %f, d1_0: %f, d0_1: %f, d1_1: %f\n", d0_0, d1_0, d0_1, d1_1);
-
-        const uint8_t * restrict p0 = pb0 + i*bs;
-        const uint8_t * restrict p1 = pb1 + i*bs;
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict y0 = &y[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q4_0 * restrict y1 = &y[i + 1];
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
         const int8x16_t  s8b = vdupq_n_s8(0x8);
 
-        const uint8x16_t v0_0 = vld1q_u8(p0);
-        const uint8x16_t v1_0 = vld1q_u8(p1);
-        const uint8x16_t v0_1 = vld1q_u8(p0 + bs);
-        const uint8x16_t v1_1 = vld1q_u8(p1 + bs);
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
 
         // 4-bit -> 8-bit
         const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
@@ -1628,11 +1877,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 
         // scalar
 #if defined(__ARM_FEATURE_QRDMX)
-        sum0 += d0_0*d1_0*vaddvq_s32(p_0);
-        sum1 += d0_1*d1_1*vaddvq_s32(p_1);
+        sum0 += x0->d * y0->d * vaddvq_s32(p_0);
+        sum1 += x1->d * y1->d * vaddvq_s32(p_1);
 #else
-        sum0 += d0_0*d1_0*(vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
-        sum1 += d0_1*d1_1*(vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
+        sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
+        sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
 #endif
 #else
            const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
@@ -1658,11 +1907,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 
         // scalar
 #if defined(__ARM_FEATURE_QRDMX)
-        sum0 += d0_0*d1_0*vaddvq_s16(p_0);
-        sum1 += d0_1*d1_1*vaddvq_s16(p_1);
+        sum0 += x0->d * y0->d * vaddvq_s16(p_0);
+        sum1 += x1->d * y1->d * vaddvq_s16(p_1);
 #else
-        sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
-        sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
+        sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
+        sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
 #endif
 #endif
     }
@@ -1675,70 +1924,139 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 
     const int superblock_size = 8;
     const int superblock_count = nb / superblock_size;
-    const int remainder = nb % superblock_size;
 
     for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
         int i = superblock_ix * superblock_size;
 
-        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
     }
 
     // Remainders
     for (int i = superblock_count * superblock_size; i < nb; ++i) {
-        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
     }
 
     // Horizontal sum of all lanes of the accumulator
     sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
 #elif defined(__AVX2__)
-    const size_t countBlocks = nb;
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    /* Prepare the constants we will need during execution */
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    const __m256i offset_8 = _mm256_set1_epi16( 8 );
+
+#define UNROLL_COUNT 8
+    // make sure we only unroll multiples of the block count
+    assert(nb % UNROLL_COUNT == 0);
+
+    // Main loop
+    for (int i = 0; i < nb; i+=UNROLL_COUNT) {
+        // This loop will be unrolled by the compiler
+        for (int u=0;u<UNROLL_COUNT;u++)  {
+            /* Compute combined scale for the block */
+            const __m256 scale = _mm256_mul_ps(
+                    _mm256_broadcast_ss( &x[i+u].d ),
+                    _mm256_broadcast_ss( &y[i+u].d ) );
+
+            /* get input from x
+               Input: 32 Nibbles (16 bytes) at *x[i+u]
+               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
+
+            /* Load 16 bytes from memory */
+            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
+            /* Expand bytes into uint16_t values */
+            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
+            /* Unpack values into individual bytes */
+            __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
+            const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
+            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
+            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */
+            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
+            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
+
+            /* get input from y
+               Input: 32 Nibbles (16 bytes) at *y[i+u]
+               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
+
+            /* Load 16 bytes from memory */
+            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
+            /* Expand bytes into uint16_t values */
+            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
+            /* Unpack values into individual bytes */
+            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
+            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
+            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
+            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */
+            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
+            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
+
+            /* Compute products of int16_t integers, add pairwise, store as int32_t */
+            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
+            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
+
+            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
+            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
+
+            /* Convert to vectore of 8 int32_t to 8 floats */
+            __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+            /* Multiply q with scale and accumulate */
+            acc = _mm256_fmadd_ps( scale, q, acc );
+        }
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
     // Main loop
     for (int i = 0; i < nb; ++i) {
-        const float * d0_0 = (const float *) (pd0 + i*bs);
-        const float * d1_0 = (const float *) (pd1 + i*bs);
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
 
-        const uint8_t * restrict p0 = pb0 + i*bs;
-        const uint8_t * restrict p1 = pb1 + i*bs;
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = bytesFromNibbles( y[i].qs + 8*j );
 
-        // Compute combined scale for the block
-        const __m256 scale = _mm256_mul_ps( _mm256_broadcast_ss( d0_0 ), _mm256_broadcast_ss( d1_0 ) );
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+            by = _mm_sub_epi8( by, off );
 
-        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( p0 );
-        __m256i by = bytesFromNibbles( p1 );
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
 
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-        const __m256i off = _mm256_set1_epi8( 8 );
-        bx = _mm256_sub_epi8( bx, off );
-        by = _mm256_sub_epi8( by, off );
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
 
-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
 
-        // Sign-extend last 16 signed bytes into int16_t vectors
-        x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
 
         // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
         // Apply the scale, and accumulate
-        acc = _mm256_fmadd_ps( scale, p, acc );
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
     }
 
     // Return horizontal sum of the acc vector
@@ -1754,21 +2072,18 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     float sum1 = 0.0f;
 
     for (int i = 0; i < nb; i += 2) {
-        const float d0_0 = *(const float *) (pd0 + i*bs);
-        const float d1_0 = *(const float *) (pd1 + i*bs);
-        const float d0_1 = *(const float *) (pd0 + (i + 1)*bs);
-        const float d1_1 = *(const float *) (pd1 + (i + 1)*bs);
-
-        const uint8_t * restrict p0 = pb0 + i*bs;
-        const uint8_t * restrict p1 = pb1 + i*bs;
+        const block_q4_0 * restrict x0 = &px[i + 0];
+        const block_q4_0 * restrict y0 = &py[i + 0];
+        const block_q4_0 * restrict x1 = &px[i + 1];
+        const block_q4_0 * restrict y1 = &py[i + 1];
 
         const v128_t m4b = wasm_u8x16_splat(0xf);
         const v128_t s8b = wasm_i8x16_splat(0x8);
 
-        const v128_t v0_0 = wasm_v128_load(p0);
-        const v128_t v0_1 = wasm_v128_load(p0 + bs);
-        const v128_t v1_0 = wasm_v128_load(p1);
-        const v128_t v1_1 = wasm_v128_load(p1 + bs);
+        const v128_t v0_0 = wasm_v128_load(x0.qs);
+        const v128_t v0_1 = wasm_v128_load(y0.qs);
+        const v128_t v1_0 = wasm_v128_load(x1.qs);
+        const v128_t v1_1 = wasm_v128_load(y1.qs);
 
         // 4-bit -> 8-bit
         const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -1818,12 +2133,12 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
         const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
         const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
 
-        sum0 += d0_0*d1_0*(
+        sum0 += x0->d * y0->d * (
                 wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
                 wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
                 wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
                 wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
-        sum1 += d0_1*d1_1*(
+        sum1 += x1->d * y1->d * (
                 wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
                 wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
                 wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
@@ -1834,11 +2149,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
-        const float d0 = *(const float *) (pd0 + i*bs);
-        const float d1 = *(const float *) (pd1 + i*bs);
+        const float d0 = x[i].d;
+        const float d1 = y[i].d;
 
-        const uint8_t * restrict p0 = pb0 + i*bs;
-        const uint8_t * restrict p1 = pb1 + i*bs;
+        const uint8_t * restrict p0 = x[i].qs;
+        const uint8_t * restrict p1 = y[i].qs;
 
         for (int j = 0; j < QK/2; j++) {
             const uint8_t v0 = p0[j];
@@ -1858,19 +2173,11 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     *s = sumf;
 }
 
-inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK;
 
-    const size_t bs = 2*sizeof(float) + QK/2;
-
-    const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
-    const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
-
-    const uint8_t * restrict pm0 = ((const uint8_t *)x + 0*bs + sizeof(float));
-    const uint8_t * restrict pm1 = ((const uint8_t *)y + 0*bs + sizeof(float));
-
-    const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
-    const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + 2*sizeof(float));
+    const block_q4_1 * restrict x = vx;
+    const block_q4_1 * restrict y = vy;
 
     float sumf = 0.0;
 
@@ -1882,32 +2189,28 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
 
     // Main loop
     for (int i = 0; i < nb; ++i) {
-        const float * m0 = (const float *) (pm0 + i*bs);
-        const float * m1 = (const float *) (pm1 + i*bs);
+        const float * d0 = &x[i].d;
+        const float * d1 = &y[i].d;
 
-        const float * d0 = (const float *) (pd0 + i*bs);
-        const float * d1 = (const float *) (pd1 + i*bs);
-
-        const uint8_t * restrict p0 = pb0 + i*bs;
-        const uint8_t * restrict p1 = pb1 + i*bs;
+        const float * m0 = &x[i].m;
+        const float * m1 = &y[i].m;
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
         const __m256 m0v = _mm256_broadcast_ss( m0 );
         const __m256 m1v = _mm256_broadcast_ss( m1 );
 
-
         // Compute combined scale for the block
         const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
 
         // Compute cross scales for the block
         const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
         const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
-        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
+        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( p0 );
-        __m256i by = bytesFromNibbles( p1 );
+        __m256i bx = bytesFromNibbles( x[i].qs );
+        __m256i by = bytesFromNibbles( y[i].qs );
 
         // Now we have a vector with bytes in [ 0 .. 15 ] interval.
 
@@ -1949,17 +2252,56 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
     sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
-#else
-    // scalar
-    for (int i = 0; i < nb; i++) {
-        const float m0 = *(const float *) (pm0 + i*bs);
-        const float m1 = *(const float *) (pm1 + i*bs);
+#elif defined(__ARM_NEON)
+    float sum00 = 0.0f;
+    float sum01 = 0.0f;
+    float sum10 = 0.0f;
+    float sum11 = 0.0f;
 
-        const float d0 = *(const float *) (pd0 + i*bs);
-        const float d1 = *(const float *) (pd1 + i*bs);
+    for (int i = 0; i < nb; ++i) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict y0 = &y[i + 0];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+
+        // and with 0xf
+        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
+        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
+
+        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
+        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
 
-        const uint8_t * restrict p0 = pb0 + i*bs;
-        const uint8_t * restrict p1 = pb1 + i*bs;
+        // dot product into uint16x8_t
+        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
+        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
+
+        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
+        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
+
+        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
+        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+
+        sum00 += x0->m*y0->m;
+        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+        sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+    }
+
+    sumf = QK*sum00 + sum01 + sum10 + sum11;
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float d1 = y[i].d;
+
+        const float m0 = x[i].m;
+        const float m1 = y[i].m;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const uint8_t * restrict p1 = y[i].qs;
 
         for (int j = 0; j < QK/2; j++) {
             const uint8_t v0 = p0[j];
@@ -2018,13 +2360,13 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * re
     // leftovers
     for (int i = np; i < n; ++i) {
         for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
         }
     }
 #else
     for (int i = 0; i < n; ++i) {
         for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
         }
     }
 #endif
@@ -2095,19 +2437,19 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #endif
 }
 
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s);   }
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
-inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
-static const ggml_float GELU_COEF_A    = 0.044715;
-static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
-    return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
@@ -2136,7 +2478,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
-    return x/(1.0 + exp(-x));
+    return x/(1.0f + expf(-x));
 }
 
 inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
@@ -2167,7 +2509,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
     for (int i = 0; i < n; ++i) {
-        sum += x[i];
+        sum += (ggml_float)x[i];
     }
     *s = sum;
 #else
@@ -2177,7 +2519,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 
 inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
-    ggml_float max = -INFINITY;
+    float max = -INFINITY;
     for (int i = 0; i < n; ++i) {
         max = MAX(max, x[i]);
     }
@@ -2187,7 +2529,10 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #endif
 }
 
-inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
+    ggml_vec_norm_f32(n, s, x);
+    *s = 1.f/(*s);
+}
 
 //
 // logging
@@ -2230,8 +2575,8 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
 static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
-    sizeof(float  )   + QK/2,
-    sizeof(float  )*2 + QK/2,
+    sizeof(block_q4_0),
+    sizeof(block_q4_1),
     sizeof(int8_t ),
     sizeof(int16_t),
     sizeof(int32_t),
@@ -2269,6 +2614,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
     "SCALE",
     "CPY",
+    "CONT",
     "RESHAPE",
     "VIEW",
     "PERMUTE",
@@ -2284,7 +2630,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2313,6 +2659,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "x*v",
     "x-\\>y",
+    "cont(x)",
     "reshape(x)",
     "view(x)",
     "permute(x)",
@@ -2328,22 +2675,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
-
-//
-// ggml object
-//
-
-struct ggml_object {
-    size_t offs;
-    size_t size;
-
-    struct ggml_object * next;
-
-    char padding[8];
-};
-
-static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -2356,9 +2688,9 @@ struct ggml_context {
     size_t mem_size;
     void * mem_buffer;
     bool   mem_buffer_owned;
-    bool   mem_buffer_mlocked;
+    bool   no_alloc;
 
-    int n_objects;
+    int    n_objects;
 
     struct ggml_object * objects_begin;
     struct ggml_object * objects_end;
@@ -2443,7 +2775,7 @@ void ggml_print_objects(const struct ggml_context * ctx) {
     GGML_PRINT("%s: --- end ---\n", __func__);
 }
 
-int ggml_nelements(const struct ggml_tensor * tensor) {
+int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
@@ -2575,6 +2907,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     static bool is_first_call = true;
 
     if (is_first_call) {
+        // initialize time system (required on Windows)
+        ggml_time_init();
+
         // initialize GELU, SILU and EXP F32 tables
         {
             const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
@@ -2586,7 +2921,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
                 const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
                 table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
                 table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
-                table_exp_f16[i]  = GGML_FP32_TO_FP16(exp(f));
+                table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -2639,7 +2974,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         /*.mem_size           =*/ params.mem_size,
         /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
         /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
-        /*.mem_buffer_mlocked =*/ false,
+        /*.no_alloc           =*/ params.no_alloc,
         /*.n_objects          =*/ 0,
         /*.objects_begin      =*/ NULL,
         /*.objects_end        =*/ NULL,
@@ -2671,14 +3006,6 @@ void ggml_free(struct ggml_context * ctx) {
             GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
                     __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
 
-#if GGML_MLOCK_SUPPORT
-            if (ctx->mem_buffer_mlocked) {
-                if (munlock(ctx->mem_buffer, ctx->mem_size)) {
-                    fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
-                }
-            }
-#endif
-
             if (ctx->mem_buffer_owned) {
                 free(ctx->mem_buffer);
             }
@@ -2707,44 +3034,13 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
     return result;
 }
 
-bool ggml_mlock_supported(void) {
-    return GGML_MLOCK_SUPPORT;
-}
-
-#if GGML_MLOCK_SUPPORT
-#ifdef __APPLE__
-    #define MLOCK_SUGGESTION "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or\n" \
-                             "decreasing 'vm.global_no_user_wire_amount'.  Also try increasing RLIMIT_MLOCK (ulimit -l)."
-#else
-    #define MLOCK_SUGGESTION "Try increasing RLIMIT_MLOCK (ulimit -l)."
-#endif
-bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
-    if (ctx->mem_buffer_mlocked) {
-        return true;
-    }
-    if (mlock(ctx->mem_buffer, ctx->mem_size)) {
-        int ret = asprintf(err_p, "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
-                           ctx->mem_size, strerror(errno));
-        GGML_ASSERT(ret >= 0);
-        return false;
-    }
-    ctx->mem_buffer_mlocked = true;
-    return true;
-}
-#else // GGML_MLOCK_SUPPORT
-bool ggml_mlock(struct ggml_context * ctx, char ** err_p) {
-    *err_p = strdup("can't mlock because it's not supported on this system");
-    return false;
-}
-#endif // GGML_MLOCK_SUPPORT
-
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_tensor * ggml_new_tensor_impl(
         struct ggml_context * ctx,
         enum   ggml_type type,
         int    n_dims,
-        const int* ne,
+        const int64_t* ne,
         void*  data) {
     // always insert objects at the end of the context's memory pool
     struct ggml_object * obj_cur = ctx->objects_end;
@@ -2755,7 +3051,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
 
     size_t size_needed = 0;
 
-    if (data == NULL) {
+    if (data == NULL && !ctx->no_alloc) {
         size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
         for (int i = 1; i < n_dims; i++) {
             size_needed *= ne[i];
@@ -2839,11 +3135,12 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
-        /*.data         =*/ data == NULL ? (void *)(result + 1) : data,
+        /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.pad          =*/ { 0 },
     };
 
-    ggml_assert_aligned(result->data);
+    // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
+    //ggml_assert_aligned(result->data);
 
     for (int i = 0; i < n_dims; i++) {
         result->ne[i] = ne[i];
@@ -2864,44 +3161,44 @@ struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
         int    n_dims,
-        const int * ne) {
+        const int64_t * ne) {
     return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
 }
 
 struct ggml_tensor * ggml_new_tensor_1d(
         struct ggml_context * ctx,
         enum   ggml_type type,
-        int    ne0) {
+        int64_t ne0) {
     return ggml_new_tensor(ctx, type, 1, &ne0);
 }
 
 struct ggml_tensor * ggml_new_tensor_2d(
         struct ggml_context * ctx,
         enum   ggml_type type,
-        int    ne0,
-        int    ne1) {
-    const int ne[2] = { ne0, ne1 };
+        int64_t ne0,
+        int64_t ne1) {
+    const int64_t ne[2] = { ne0, ne1 };
     return ggml_new_tensor(ctx, type, 2, ne);
 }
 
 struct ggml_tensor * ggml_new_tensor_3d(
         struct ggml_context * ctx,
         enum   ggml_type type,
-        int    ne0,
-        int    ne1,
-        int    ne2) {
-    const int ne[3] = { ne0, ne1, ne2 };
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2) {
+    const int64_t ne[3] = { ne0, ne1, ne2 };
     return ggml_new_tensor(ctx, type, 3, ne);
 }
 
 struct ggml_tensor * ggml_new_tensor_4d(
         struct ggml_context * ctx,
         enum   ggml_type type,
-        int    ne0,
-        int    ne1,
-        int    ne2,
-        int    ne3) {
-    const int ne[4] = { ne0, ne1, ne2, ne3 };
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3) {
+    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
     return ggml_new_tensor(ctx, type, 4, ne);
 }
 
@@ -3244,7 +3541,14 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
-    return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+
+    result->nb[0] = src->nb[0];
+    result->nb[1] = src->nb[1];
+    result->nb[2] = src->nb[2];
+    result->nb[3] = src->nb[3];
+
+    return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -3548,7 +3852,7 @@ struct ggml_tensor * ggml_mean(
         is_node = true;
     }
 
-    int ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+    int64_t ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne);
 
     result->op   = GGML_OP_MEAN;
@@ -3909,7 +4213,7 @@ struct ggml_tensor * ggml_mul_mat(
         is_node = true;
     }
 
-    const int ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
+    const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
 
     result->op   = GGML_OP_MUL_MAT;
@@ -4004,6 +4308,41 @@ struct ggml_tensor * ggml_cpy_inplace(
     return ggml_cpy_impl(ctx, a, b, true);
 }
 
+// ggml_cont
+
+struct ggml_tensor * ggml_cont_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CONT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cont(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_cont_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cont_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_cont_impl(ctx, a, true);
+}
+
 // ggml_reshape
 
 struct ggml_tensor * ggml_reshape(
@@ -4034,8 +4373,8 @@ struct ggml_tensor * ggml_reshape(
 struct ggml_tensor * ggml_reshape_2d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   ne0,
-        int                   ne1) {
+        int64_t               ne0,
+        int64_t               ne1) {
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
 
@@ -4046,7 +4385,7 @@ struct ggml_tensor * ggml_reshape_2d(
         is_node = true;
     }
 
-    const int ne[2] = { ne0, ne1 };
+    const int64_t ne[2] = { ne0, ne1 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
 
     result->op   = GGML_OP_RESHAPE;
@@ -4060,9 +4399,9 @@ struct ggml_tensor * ggml_reshape_2d(
 struct ggml_tensor * ggml_reshape_3d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   ne0,
-        int                   ne1,
-        int                   ne2) {
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2) {
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
 
@@ -4073,7 +4412,7 @@ struct ggml_tensor * ggml_reshape_3d(
         is_node = true;
     }
 
-    const int ne[3] = { ne0, ne1, ne2 };
+    const int64_t ne[3] = { ne0, ne1, ne2 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
 
     result->op   = GGML_OP_RESHAPE;
@@ -4089,7 +4428,7 @@ struct ggml_tensor * ggml_reshape_3d(
 struct ggml_tensor * ggml_view_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   ne0,
+        int64_t               ne0,
         size_t                offset) {
     if (a->grad) {
         GGML_ASSERT(false); // gradient propagation is not supported
@@ -4110,15 +4449,15 @@ struct ggml_tensor * ggml_view_1d(
 struct ggml_tensor * ggml_view_2d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   ne0,
-        int                   ne1,
+        int64_t               ne0,
+        int64_t               ne1,
         size_t                nb1,
         size_t                offset) {
     if (a->grad) {
         GGML_ASSERT(false); // gradient propagation is not supported
     }
 
-    const int ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
+    const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
 
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
 
@@ -4134,6 +4473,37 @@ struct ggml_tensor * ggml_view_2d(
     return result;
 }
 
+// ggml_view_3d
+
+struct ggml_tensor * ggml_view_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                offset) {
+    if (a->grad) {
+        GGML_ASSERT(false); // gradient propagation is not supported
+    }
+
+    const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = nb2;
+    result->nb[3] = result->nb[2]*ne2;
+
+    result->op   = GGML_OP_VIEW;
+    result->grad = NULL;
+    result->src0 = a;
+    result->src1 = NULL; // TODO: maybe store the offset here?
+
+    return result;
+}
+
 // ggml_permute
 
 struct ggml_tensor * ggml_permute(
@@ -4349,7 +4719,7 @@ struct ggml_tensor * ggml_conv_1d_1s(
         is_node = true;
     }
 
-    const int ne[4] = { b->ne[0], a->ne[2], 1, 1, };
+    const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
     result->op   = GGML_OP_CONV_1D_1S;
@@ -4376,7 +4746,7 @@ struct ggml_tensor * ggml_conv_1d_2s(
         is_node = true;
     }
 
-    const int ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
+    const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
     result->op   = GGML_OP_CONV_1D_2S;
@@ -4469,102 +4839,191 @@ static void ggml_compute_forward_dup_f16(
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
-    if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
         memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
         return;
     }
 
-    if (src0->nb[0] == sizeof(ggml_fp16_t)) {
-        if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            const size_t rs = ne00*nb00;
+    if (src0->type == dst->type &&
+        src0->ne[0] == dst->ne[0] &&
+        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
 
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    for (int i01 = 0; i01 < ne01; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        char * dst_ptr = (char *) dst->data + id*rs;
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
 
-                        memcpy(dst_ptr, src0_ptr, rs);
+    if (ggml_is_contiguous(dst)) {
+        if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+            if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                const size_t rs = ne00*nb00;
 
-                        id++;
-                    }
-                }
-            }
-        } else if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            char * dst_ptr = (char *) dst->data + id*rs;
 
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    for (int i01 = 0; i01 < ne01; i01++) {
-                        for (int i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                            memcpy(dst_ptr, src0_ptr, rs);
 
-                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
                             id++;
                         }
                     }
                 }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
             }
         } else {
-            GGML_ASSERT(false); // TODO: implement
-        }
-    } else {
-        //printf("%s: this is not optimal - fix me\n", __func__);
+            //printf("%s: this is not optimal - fix me\n", __func__);
 
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
 
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    for (int i01 = 0; i01 < ne01; i01++) {
-                        for (int i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
-                            id++;
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
                         }
                     }
                 }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
             }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    for (int i01 = 0; i01 < ne01; i01++) {
-                        for (int i00 = 0; i00 < ne00; i00++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+        }
+        return;
+    }
 
-                            dst_ptr[id] = *src0_ptr;
-                            id++;
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
+    } else {
+        GGML_ASSERT(false); // TODO: implement
     }
 }
 
@@ -4573,102 +5032,191 @@ static void ggml_compute_forward_dup_f32(
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
-    if (ggml_is_contiguous(src0) && src0->type == dst->type) {
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
         memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
         return;
     }
 
-    if (src0->nb[0] == sizeof(float)) {
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            const size_t rs = ne00*nb00;
-
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    for (int i01 = 0; i01 < ne01; i01++) {
-                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                        char * dst_ptr = (char *) dst->data + id*rs;
-
-                        memcpy(dst_ptr, src0_ptr, rs);
-
-                        id++;
-                    }
+    if (src0->type == dst->type &&
+        src0->ne[0] == dst->ne[0] &&
+        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
                 }
             }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+        }
+        return;
+    }
 
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    for (int i01 = 0; i01 < ne01; i01++) {
-                        for (int i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+    if (ggml_is_contiguous(dst)) {
+        // TODO: simplify
+        if (src0->nb[0] == sizeof(float)) {
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                const size_t rs = ne00*nb00;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            char * dst_ptr = (char *) dst->data + id*rs;
+
+                            memcpy(dst_ptr, src0_ptr, rs);
 
-                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
                             id++;
                         }
                     }
                 }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
             }
         } else {
-            GGML_ASSERT(false); // TODO: implement
-        }
-    } else {
-        //printf("%s: this is not optimal - fix me\n", __func__);
+            //printf("%s: this is not optimal - fix me\n", __func__);
 
-        if (dst->type == GGML_TYPE_F32) {
-            size_t id = 0;
-            float * dst_ptr = (float *) dst->data;
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
 
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    for (int i01 = 0; i01 < ne01; i01++) {
-                        for (int i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                            dst_ptr[id] = *src0_ptr;
-                            id++;
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
                         }
                     }
                 }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
             }
-        } else if (dst->type == GGML_TYPE_F16) {
-            size_t id = 0;
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+        }
 
-            for (int i03 = 0; i03 < ne03; i03++) {
-                for (int i02 = 0; i02 < ne02; i02++) {
-                    for (int i01 = 0; i01 < ne01; i01++) {
-                        for (int i00 = 0; i00 < ne00; i00++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+        return;
+    }
 
-                            dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                            id++;
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(float));
+
+                        if (++i10 == dst->ne[0]) {
+                            i10 = 0;
+                            if (++i11 == dst->ne[1]) {
+                                i11 = 0;
+                                if (++i12 == dst->ne[2]) {
+                                    i12 = 0;
+                                    if (++i13 == dst->ne[3]) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+
+                        if (++i10 == dst->ne[0]) {
+                            i10 = 0;
+                            if (++i11 == dst->ne[1]) {
+                                i11 = 0;
+                                if (++i12 == dst->ne[2]) {
+                                    i12 = 0;
+                                    if (++i13 == dst->ne[3]) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
                         }
                     }
                 }
             }
-        } else {
-            GGML_ASSERT(false); // TODO: implement
         }
+    } else {
+        GGML_ASSERT(false); // TODO: implement
     }
 }
 
@@ -4729,14 +5277,18 @@ static void ggml_compute_forward_add_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        const int j0 = (n/nth)*ith;
-        const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
-
-        for (int j = j0; j < j1; j++) {
+        for (int j = ith; j < n; j += nth) {
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vadd(
+                    (float *) ((char *) src0->data + j*nb01), 1,
+                    (float *) ((char *) src1->data + j*nb11), 1,
+                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
+#else
             ggml_vec_add_f32(nc,
                     (float *) ((char *) dst->data  + j*nb1),
                     (float *) ((char *) src0->data + j*nb01),
                     (float *) ((char *) src1->data + j*nb11));
+#endif
         }
     } else {
         // src1 is not contiguous
@@ -5043,18 +5595,18 @@ static void ggml_compute_forward_sum_f32(
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
-    for (int i03 = 0; i03 < ne03; i03++) {
-        for (int i02 = 0; i02 < ne02; i02++) {
-            for (int i01 = 0; i01 < ne01; i01++) {
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
                 ggml_vec_sum_f32(ne00,
                         (float *) (dst->data),
                         (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -5099,19 +5651,19 @@ static void ggml_compute_forward_mean_f32(
 
     assert(src0->nb[0] == sizeof(float));
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
     const size_t nb03 = src0->nb[3];
 
-    const int ne0 = dst->ne[0];
-    const int ne1 = dst->ne[1];
-    const int ne2 = dst->ne[2];
-    const int ne3 = dst->ne[3];
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
 
     assert(ne0 == 1);
     assert(ne1 == ne01);
@@ -5127,9 +5679,9 @@ static void ggml_compute_forward_mean_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    for (int i03 = 0; i03 < ne03; i03++) {
-        for (int i02 = 0; i02 < ne02; i02++) {
-            for (int i01 = 0; i01 < ne01; i01++) {
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
                 ggml_vec_sum_f32(ne00,
                         (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                         (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
@@ -5616,10 +6168,10 @@ static void ggml_compute_forward_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
@@ -5629,31 +6181,32 @@ static void ggml_compute_forward_norm_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+    const float eps = 1e-5f; // TODO: make this a parameter
 
     // TODO: optimize
-    for (int i03 = 0; i03 < ne03; i03++) {
-        for (int i02 = 0; i02 < ne02; i02++) {
-            for (int i01 = ith; i01 < ne01; i01 += nth) {
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float mean = 0.0;
-                for (int i00 = 0; i00 < ne00; i00++) {
-                    mean += x[i00];
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)x[i00];
                 }
 
-                mean /= ne00;
+                float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
                 ggml_float sum2 = 0.0;
-                for (int i00 = 0; i00 < ne00; i00++) {
-                    ggml_float v = x[i00] - mean;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    float v = x[i00] - mean;
                     y[i00] = v;
-                    sum2 += v*v;
+                    sum2 += (ggml_float)(v*v);
                 }
 
-                const float scale = 1.0/sqrt(sum2/ne00 + eps);
+                float variance = sum2/ne00;
+                const float scale = 1.0f/sqrtf(variance + eps);
 
                 ggml_vec_scale_f32(ne00, y, scale);
             }
@@ -5698,10 +6251,10 @@ static void ggml_compute_forward_rms_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
     const size_t nb01 = src0->nb[1];
     const size_t nb02 = src0->nb[2];
@@ -5711,20 +6264,20 @@ static void ggml_compute_forward_rms_norm_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    const ggml_float eps = 1e-6f; // TODO: make this a parameter
+    const float eps = 1e-6f; // TODO: make this a parameter
 
     // TODO: optimize
-    for (int i03 = 0; i03 < ne03; i03++) {
-        for (int i02 = 0; i02 < ne02; i02++) {
-            for (int i01 = ith; i01 < ne01; i01 += nth) {
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float mean = 0.0;
-                for (int i00 = 0; i00 < ne00; i00++) {
-                    mean += x[i00] * x[i00];
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
-                mean /= ne00;
+                float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
@@ -5733,7 +6286,7 @@ static void ggml_compute_forward_rms_norm_f32(
                 //     y[i00] = x[i00];
                 // }
 
-                const float scale = 1.0/sqrt(mean + eps);
+                const float scale = 1.0f/sqrtf(mean + eps);
 
                 ggml_vec_scale_f32(ne00, y, scale);
             }
@@ -5773,191 +6326,27 @@ static bool ggml_compute_forward_mul_mat_use_blas(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
-    //const int ne00 = src0->ne[0];
-    //const int ne01 = src0->ne[1];
-
-    const int ne10 = src1->ne[0];
-
-    const int ne0 = dst->ne[0];
-    const int ne1 = dst->ne[1];
-
-    // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
-
-        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
-        return true;
-    }
-
-    return false;
-}
-#endif
+    //const int64_t ne00 = src0->ne[0];
+    //const int64_t ne01 = src0->ne[1];
 
-static void ggml_compute_forward_mul_mat_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
-
-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    const int ne12 = src1->ne[2];
-    const int ne13 = src1->ne[3];
-
-    const int ne0  = dst->ne[0];
-    const int ne1  = dst->ne[1];
-    const int ne2  = dst->ne[2];
-    const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
-
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    assert(ne02 == ne12);
-    assert(ne03 == ne13);
-    assert(ne2  == ne12);
-    assert(ne3  == ne13);
-
-    // TODO: we don't support permuted src0
-    assert(nb00 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    assert(nb0 == sizeof(float));
-    assert(nb0 <= nb1);
-    assert(nb1 <= nb2);
-    assert(nb2 <= nb3);
-
-    assert(ne0 == ne01);
-    assert(ne1 == ne11);
-    assert(ne2 == ne02);
-    assert(ne3 == ne03);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(nb10 == sizeof(float));
-
-        if (params->ith != 0) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_INIT) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_FINALIZE) {
-            return;
-        }
-
-        for (int i03 = 0; i03 < ne03; i03++) {
-            for (int i02 = 0; i02 < ne02; i02++) {
-                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
-                // zT = y * xT
-                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne10,
-                        0.0f,    d, ne01);
-            }
-        }
-
-        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
-
-        return;
-    }
-#endif
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // TODO: do not support transposed src1
-    assert(nb10 == sizeof(float));
-
-    // parallelize by src0 rows using ggml_vec_dot_f32
-
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        for (int ic = 0; ic < ne11; ++ic) {
-            // src1 indices
-            const int i13 = i03;
-            const int i12 = i02;
-            const int i11 = ic;
+    const int64_t ne10 = src1->ne[0];
 
-            // dst indices
-            const int i0 = i01;
-            const int i1 = i11;
-            const int i2 = i02;
-            const int i3 = i03;
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
 
-            ggml_vec_dot_f32(ne00,
-                    (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
-                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
-        }
-    }
+    // TODO: find the optimal values for these
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
 
-    //int64_t t1 = ggml_perf_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
+        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
+        return true;
+    }
 
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
+    return false;
 }
+#endif
 
-static void ggml_compute_forward_mul_mat_f16_f32(
+static void ggml_compute_forward_mul_mat_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -5965,28 +6354,33 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    const int ne12 = src1->ne[2];
-    const int ne13 = src1->ne[3];
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+    const int64_t ne10 = src1->ne[0];
+#endif
+    const int64_t ne11 = src1->ne[1];
+#ifndef NDEBUG
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
-    const int ne0  = dst->ne[0];
-    const int ne1  = dst->ne[1];
-    const int ne2  = dst->ne[2];
-    const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
 
     const int nb00 = src0->nb[0];
+#endif
     const int nb01 = src0->nb[1];
     const int nb02 = src0->nb[2];
     const int nb03 = src0->nb[3];
 
+#ifndef NDEBUG
     const int nb10 = src1->nb[0];
+#endif
     const int nb11 = src1->nb[1];
     const int nb12 = src1->nb[2];
     const int nb13 = src1->nb[3];
@@ -5999,32 +6393,31 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
+    assert(ne02 == ne12);
+    assert(ne03 == ne13);
+    assert(ne2  == ne12);
+    assert(ne3  == ne13);
 
-    // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    // we don't support permuted src0 or src1
+    assert(nb00 == sizeof(float));
+    assert(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
+    assert(nb0 == sizeof(float));
+    assert(nb0 <= nb1);
+    assert(nb1 <= nb2);
+    assert(nb2 <= nb3);
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
+    assert(ne0 == ne01);
+    assert(ne1 == ne11);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
 
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(nb10 == sizeof(float));
-
         if (params->ith != 0) {
             return;
         }
@@ -6037,20 +6430,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
-
-        for (int i03 = 0; i03 < ne03; i03++) {
-            for (int i02 = 0; i02 < ne02; i02++) {
-                {
-                    size_t id = 0;
-                    for (int i01 = 0; i01 < ne01; ++i01) {
-                        for (int i00 = 0; i00 < ne00; ++i00) {
-                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
-                        }
-                    }
-                }
-
-                const float * x = wdata;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -6064,28 +6446,13 @@ static void ggml_compute_forward_mul_mat_f16_f32(
             }
         }
 
-        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
+        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
         return;
     }
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        ggml_fp16_t * const wdata = params->wdata;
-
-        size_t id = 0;
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                for (int i11 = 0; i11 < ne11; ++i11) {
-                    for (int i10 = 0; i10 < ne10; ++i10) {
-                        wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                    }
-                }
-            }
-        }
-
-        GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
-
         return;
     }
 
@@ -6093,11 +6460,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         return;
     }
 
-    // fp16 -> half the size, so divide by 2
-    // TODO: do not support transposed src1
-    assert(nb10/2 == sizeof(ggml_fp16_t));
-
-    // parallelize by src0 rows using ggml_vec_dot_f16
+    // parallelize by src0 rows using ggml_vec_dot_f32
 
     // total rows in src0
     const int nr = ne01*ne02*ne03;
@@ -6109,32 +6472,32 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    ggml_fp16_t * wdata = params->wdata;
-
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
         const int i03 = ir/(ne02*ne01);
         const int i02 = (ir - i03*ne02*ne01)/ne01;
         const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-        const int i13 = i03;
-        const int i12 = i02;
-
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
-
-        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
+        for (int64_t ic = 0; ic < ne11; ++ic) {
+            // src1 indices
+            const int i13 = i03;
+            const int i12 = i02;
+            const int i11 = ic;
 
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+            // dst indices
+            const int i0 = i01;
+            const int i1 = i11;
+            const int i2 = i02;
+            const int i3 = i03;
 
-        for (int ic = 0; ic < ne11; ++ic) {
-            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
+            ggml_vec_dot_f32(ne00,
+                    (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
         }
     }
 
-    //int64_t t1 = ggml_time_us();
+    //int64_t t1 = ggml_perf_time_us();
     //static int64_t acc = 0;
     //acc += t1 - t0;
     //if (t1 - t0 > 10) {
@@ -6142,12 +6505,13 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
     //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
     //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
+    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
 
     //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
     //}
 }
 
-static void ggml_compute_forward_mul_mat_q4_0_f32(
+static void ggml_compute_forward_mul_mat_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -6155,21 +6519,21 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    const int ne12 = src1->ne[2];
-    const int ne13 = src1->ne[3];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
-    const int ne0  = dst->ne[0];
-    const int ne1  = dst->ne[1];
-    const int ne2  = dst->ne[2];
-    const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
+    //const int64_t ne   = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -6195,7 +6559,7 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
     GGML_ASSERT(ne3  == ne13);
 
     // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -6229,13 +6593,14 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
 
         float * const wdata = params->wdata;
 
-        for (int i03 = 0; i03 < ne03; i03++) {
-            for (int i02 = 0; i02 < ne02; i02++) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
                 {
                     size_t id = 0;
-                    for (int i01 = 0; i01 < ne01; ++i01) {
-                        dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
-                        id += ne00;
+                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
+                        for (int64_t i00 = 0; i00 < ne00; ++i00) {
+                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
+                        }
                     }
                 }
 
@@ -6253,24 +6618,28 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
             }
         }
 
-        /*printf("CBLAS Q4_0 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
+        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
         return;
     }
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        char * wdata = params->wdata;
+        ggml_fp16_t * const wdata = params->wdata;
 
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                for (int i11 = 0; i11 < ne11; ++i11) {
-                    quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                    wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+        size_t id = 0;
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                        wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+                    }
                 }
             }
         }
 
+        GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
+
         return;
     }
 
@@ -6278,9 +6647,11 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
         return;
     }
 
+    // fp16 -> half the size, so divide by 2
     // TODO: do not support transposed src1
+    assert(nb10/2 == sizeof(ggml_fp16_t));
 
-    // parallelize by src0 rows using ggml_vec_dot_q4_0
+    // parallelize by src0 rows using ggml_vec_dot_f16
 
     // total rows in src0
     const int nr = ne01*ne02*ne03;
@@ -6292,7 +6663,7 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    void * wdata = params->wdata;
+    ggml_fp16_t * wdata = params->wdata;
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
@@ -6307,15 +6678,13 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
         const int i2 = i02;
         const int i3 = i03;
 
-        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
+        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
 
         float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-        assert(ne00 % 32 == 0);
-
-        for (int ic = 0; ic < ne11; ++ic) {
-            ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
+        for (int64_t ic = 0; ic < ne11; ++ic) {
+            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
         }
     }
 
@@ -6332,7 +6701,28 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
     //}
 }
 
-static void ggml_compute_forward_mul_mat_q4_1_f32(
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_1,
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
+static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -6340,21 +6730,20 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
 
-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    const int ne12 = src1->ne[2];
-    const int ne13 = src1->ne[3];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
-    const int ne0  = dst->ne[0];
-    const int ne1  = dst->ne[1];
-    const int ne2  = dst->ne[2];
-    const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    const int64_t ne2  = dst->ne[2];
+    const int64_t ne3  = dst->ne[3];
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -6379,8 +6768,13 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
     GGML_ASSERT(ne2  == ne12);
     GGML_ASSERT(ne3  == ne13);
 
-    // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
+    const enum ggml_type type = src0->type;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+    vec_dot_q_t      const vec_dot_q      = quantize_fns[type].vec_dot_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -6398,8 +6792,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(nb10 == sizeof(float));
-
         if (params->ith != 0) {
             return;
         }
@@ -6413,13 +6805,14 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
         }
 
         float * const wdata = params->wdata;
+        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
-        for (int i03 = 0; i03 < ne03; i03++) {
-            for (int i02 = 0; i02 < ne02; i02++) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
                 {
                     size_t id = 0;
-                    for (int i01 = 0; i01 < ne01; ++i01) {
-                        dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
+                        dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
                         id += ne00;
                     }
                 }
@@ -6446,15 +6839,13 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
 
     if (params->type == GGML_TASK_INIT) {
         char * wdata = params->wdata;
+        const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
 
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                for (int i11 = 0; i11 < ne11; ++i11) {
-                    //for (int i10 = 0; i10 < ne10; ++i10) {
-                    //    wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                    //}
-                    quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                    wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    wdata += row_size;
                 }
             }
         }
@@ -6466,9 +6857,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
         return;
     }
 
-    // TODO: do not support transposed src1
-
-    // parallelize by src0 rows using ggml_vec_dot_q4_1
+    // parallelize by src0 rows using ggml_vec_dot_q
 
     // total rows in src0
     const int nr = ne01*ne02*ne03;
@@ -6481,6 +6870,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     void * wdata = params->wdata;
+    const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
@@ -6496,14 +6886,14 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
         const int i3 = i03;
 
         void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]);
+        char * src1_col =          ((char *)      wdata + (      (0 + i12*ne11 + i13*ne12*ne11)*row_size));
 
         float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
         assert(ne00 % 32 == 0);
 
-        for (int ic = 0; ic < ne11; ++ic) {
-            ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
+        for (int64_t ic = 0; ic < ne11; ++ic) {
+            vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
         }
     }
 
@@ -6527,12 +6917,9 @@ static void ggml_compute_forward_mul_mat(
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            {
-                ggml_compute_forward_mul_mat_q4_0_f32(params, src0, src1, dst);
-            } break;
         case GGML_TYPE_Q4_1:
             {
-                ggml_compute_forward_mul_mat_q4_1_f32(params, src0, src1, dst);
+                ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
             {
@@ -6649,6 +7036,15 @@ static void ggml_compute_forward_cpy(
     ggml_compute_forward_dup(params, src0, dst);
 }
 
+// ggml_compute_forward_cont
+
+static void ggml_compute_forward_cont(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, src0, dst);
+}
+
 // ggml_compute_forward_reshape
 
 static void ggml_compute_forward_reshape(
@@ -6693,34 +7089,7 @@ static void ggml_compute_forward_transpose(
 
 // ggml_compute_forward_get_rows
 
-static void ggml_compute_forward_get_rows_q4_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
-
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
-
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
-
-        dequantize_row_q4_0(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
-                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
-    }
-}
-
-static void ggml_compute_forward_get_rows_q4_1(
+static void ggml_compute_forward_get_rows_q(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -6733,15 +7102,17 @@ static void ggml_compute_forward_get_rows_q4_1(
 
     const int nc = src0->ne[0];
     const int nr = ggml_nelements(src1);
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 
     assert( dst->ne[0] == nc);
     assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
+    assert(src0->nb[0] == GGML_TYPE_SIZE[type]);
 
     for (int i = 0; i < nr; ++i) {
         const int r = ((int32_t *) src1->data)[i];
 
-        dequantize_row_q4_1(
+        dequantize_row_q(
                 (const void *) ((char *) src0->data + r*src0->nb[1]),
                      (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
     }
@@ -6809,12 +7180,9 @@ static void ggml_compute_forward_get_rows(
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            {
-                ggml_compute_forward_get_rows_q4_0(params, src0, src1, dst);
-            } break;
         case GGML_TYPE_Q4_1:
             {
-                ggml_compute_forward_get_rows_q4_1(params, src0, src1, dst);
+                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
             {
@@ -6966,12 +7334,12 @@ static void ggml_compute_forward_soft_max_f32(
                 ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
-                sum += val;
+                sum += (ggml_float)val;
                 p[i] = val;
             }
         }
 
-        assert(sum > 0.0f);
+        assert(sum > 0.0);
 
         sum = 1.0/sum;
         ggml_vec_scale_f32(nc, p, sum);
@@ -7014,7 +7382,6 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 3);
 
@@ -7026,10 +7393,10 @@ static void ggml_compute_forward_rope_f32(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
 
-    //const int ne0 = src0->ne[0];
-    const int ne1 = src0->ne[1];
-    const int ne2 = src0->ne[2];
-    const int ne3 = src0->ne[3];
+    //const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
 
     const int nb0 = src0->nb[0];
     const int nb1 = src0->nb[1];
@@ -7041,22 +7408,39 @@ static void ggml_compute_forward_rope_f32(
 
     assert(nb0 == sizeof(float));
 
-    // TODO: optimize
-    for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
             const int p = (mode == 0 ? n_past + i2 : i2);
-            for (int i1 = 0; i1 < ne1; i1++) {
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
-                    const double cos_theta = cos(p*theta);
-                    const double sin_theta = sin(p*theta);
+                    const float cos_theta = cosf(p*theta);
+                    const float sin_theta = sinf(p*theta);
 
                     const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    double x0 = src[0];
-                    double x1 = src[1];
+                    const float x0 = src[0];
+                    const float x1 = src[1];
 
                     dst_data[0] = x0*cos_theta - x1*sin_theta;
                     dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -7071,7 +7455,6 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 3);
 
@@ -7083,10 +7466,10 @@ static void ggml_compute_forward_rope_f16(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
 
-    //const int ne0 = src0->ne[0];
-    const int ne1 = src0->ne[1];
-    const int ne2 = src0->ne[2];
-    const int ne3 = src0->ne[3];
+    //const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
 
     const int nb0 = src0->nb[0];
     const int nb1 = src0->nb[1];
@@ -7098,21 +7481,39 @@ static void ggml_compute_forward_rope_f16(
 
     assert(nb0 == sizeof(ggml_fp16_t));
 
-    for (int i3 = 0; i3 < ne3; i3++) {
-        for (int i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
             const int p = (mode == 0 ? n_past + i2 : i2);
-            for (int i1 = 0; i1 < ne1; i1++) {
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
-                    const double cos_theta = cos(p*theta);
-                    const double sin_theta = sin(p*theta);
+                    const float cos_theta = cosf(p*theta);
+                    const float sin_theta = sinf(p*theta);
 
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    double x0 = ggml_fp16_to_fp32(src[0]);
-                    double x1 = ggml_fp16_to_fp32(src[1]);
+                    const float x0 = ggml_fp16_to_fp32(src[0]);
+                    const float x1 = ggml_fp16_to_fp32(src[1]);
 
                     dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
                     dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
@@ -7162,21 +7563,21 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    //const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
 
-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    //const int ne12 = src1->ne[2];
-    //const int ne13 = src1->ne[3];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    //const int64_t ne12 = src1->ne[2];
+    //const int64_t ne13 = src1->ne[3];
 
-    //const int ne0  = dst->ne[0];
-    //const int ne1  = dst->ne[1];
-    //const int ne2  = dst->ne[2];
-    //const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
+    //const int64_t ne0  = dst->ne[0];
+    //const int64_t ne1  = dst->ne[1];
+    //const int64_t ne2  = dst->ne[2];
+    //const int64_t ne3  = dst->ne[3];
+    //const int64_t ne   = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -7213,11 +7614,11 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int i02 = 0; i02 < ne02; i02++) {
-                for (int i01 = 0; i01 < ne01; i01++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
                     ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
-                    for (int i00 = 0; i00 < ne00; i00++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                         dst_data[i00*ew0 + i01] = src[i00];
                     }
                 }
@@ -7228,10 +7629,10 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
 
-            for (int i11 = 0; i11 < ne11; i11++) {
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
                 ggml_fp16_t * dst_data = wdata;
-                for (int i10 = 0; i10 < ne10; i10++) {
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
                     dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
                 }
             }
@@ -7256,7 +7657,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32(
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        for (int i0 = 0; i0 < ne10; ++i0) {
+        for (int64_t i0 = 0; i0 < ne10; ++i0) {
             dst_data[i0] = 0;
             for (int k = -nh; k <= nh; k++) {
                 float v = 0.0f;
@@ -7282,21 +7683,21 @@ static void ggml_compute_forward_conv_1d_1s_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    //const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
 
-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    //const int ne12 = src1->ne[2];
-    //const int ne13 = src1->ne[3];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    //const int64_t ne12 = src1->ne[2];
+    //const int64_t ne13 = src1->ne[3];
 
-    //const int ne0  = dst->ne[0];
-    //const int ne1  = dst->ne[1];
-    //const int ne2  = dst->ne[2];
-    //const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
+    //const int64_t ne0  = dst->ne[0];
+    //const int64_t ne1  = dst->ne[1];
+    //const int64_t ne2  = dst->ne[2];
+    //const int64_t ne3  = dst->ne[3];
+    //const int64_t ne   = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -7333,11 +7734,11 @@ static void ggml_compute_forward_conv_1d_1s_f32(
         {
             float * const wdata = (float *) params->wdata + 0;
 
-            for (int i02 = 0; i02 < ne02; i02++) {
-                for (int i01 = 0; i01 < ne01; i01++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
                     const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
                     float * dst_data = wdata + i02*ew0*ne00;
-                    for (int i00 = 0; i00 < ne00; i00++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                         dst_data[i00*ew0 + i01] = src[i00];
                     }
                 }
@@ -7348,10 +7749,10 @@ static void ggml_compute_forward_conv_1d_1s_f32(
         {
             float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
 
-            for (int i11 = 0; i11 < ne11; i11++) {
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
                 float * dst_data = wdata;
-                for (int i10 = 0; i10 < ne10; i10++) {
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
                     dst_data[(i10 + nh)*ew0 + i11] = src[i10];
                 }
             }
@@ -7376,7 +7777,7 @@ static void ggml_compute_forward_conv_1d_1s_f32(
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        for (int i0 = 0; i0 < ne10; ++i0) {
+        for (int64_t i0 = 0; i0 < ne10; ++i0) {
             dst_data[i0] = 0;
             for (int k = -nh; k <= nh; k++) {
                 float v = 0.0f;
@@ -7430,21 +7831,21 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    //const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
 
-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    //const int ne12 = src1->ne[2];
-    //const int ne13 = src1->ne[3];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    //const int64_t ne12 = src1->ne[2];
+    //const int64_t ne13 = src1->ne[3];
 
-    //const int ne0  = dst->ne[0];
-    //const int ne1  = dst->ne[1];
-    //const int ne2  = dst->ne[2];
-    //const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
+    //const int64_t ne0  = dst->ne[0];
+    //const int64_t ne1  = dst->ne[1];
+    //const int64_t ne2  = dst->ne[2];
+    //const int64_t ne3  = dst->ne[3];
+    //const int64_t ne   = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -7481,11 +7882,11 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int i02 = 0; i02 < ne02; i02++) {
-                for (int i01 = 0; i01 < ne01; i01++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
                     ggml_fp16_t * dst_data = wdata + i02*ew0*ne00;
-                    for (int i00 = 0; i00 < ne00; i00++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                         dst_data[i00*ew0 + i01] = src[i00];
                     }
                 }
@@ -7496,10 +7897,10 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00;
 
-            for (int i11 = 0; i11 < ne11; i11++) {
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
                 ggml_fp16_t * dst_data = wdata;
-                for (int i10 = 0; i10 < ne10; i10++) {
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
                     dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]);
                 }
             }
@@ -7524,7 +7925,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32(
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        for (int i0 = 0; i0 < ne10; i0 += 2) {
+        for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
             dst_data[i0/2] = 0;
             for (int k = -nh; k <= nh; k++) {
                 float v = 0.0f;
@@ -7550,21 +7951,21 @@ static void ggml_compute_forward_conv_1d_2s_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    //const int ne03 = src0->ne[3];
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    //const int64_t ne03 = src0->ne[3];
 
-    const int ne10 = src1->ne[0];
-    const int ne11 = src1->ne[1];
-    //const int ne12 = src1->ne[2];
-    //const int ne13 = src1->ne[3];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    //const int64_t ne12 = src1->ne[2];
+    //const int64_t ne13 = src1->ne[3];
 
-    //const int ne0  = dst->ne[0];
-    //const int ne1  = dst->ne[1];
-    //const int ne2  = dst->ne[2];
-    //const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
+    //const int64_t ne0  = dst->ne[0];
+    //const int64_t ne1  = dst->ne[1];
+    //const int64_t ne2  = dst->ne[2];
+    //const int64_t ne3  = dst->ne[3];
+    //const int64_t ne   = ne0*ne1*ne2*ne3;
 
     const int nb00 = src0->nb[0];
     const int nb01 = src0->nb[1];
@@ -7601,11 +8002,11 @@ static void ggml_compute_forward_conv_1d_2s_f32(
         {
             float * const wdata = (float *) params->wdata + 0;
 
-            for (int i02 = 0; i02 < ne02; i02++) {
-                for (int i01 = 0; i01 < ne01; i01++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
                     const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
                     float * dst_data = wdata + i02*ew0*ne00;
-                    for (int i00 = 0; i00 < ne00; i00++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                         dst_data[i00*ew0 + i01] = src[i00];
                     }
                 }
@@ -7616,10 +8017,10 @@ static void ggml_compute_forward_conv_1d_2s_f32(
         {
             float * const wdata = (float *) params->wdata + ne02*ew0*ne00;
 
-            for (int i11 = 0; i11 < ne11; i11++) {
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
                 const float * const src = (float *)((char *) src1->data + i11*nb11);
                 float * dst_data = wdata;
-                for (int i10 = 0; i10 < ne10; i10++) {
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
                     dst_data[(i10 + nh)*ew0 + i11] = src[i10];
                 }
             }
@@ -7644,7 +8045,7 @@ static void ggml_compute_forward_conv_1d_2s_f32(
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        for (int i0 = 0; i0 < ne10; i0 += 2) {
+        for (int64_t i0 = 0; i0 < ne10; i0 += 2) {
             dst_data[i0/2] = 0;
             for (int k = -nh; k <= nh; k++) {
                 float v = 0.0f;
@@ -7696,25 +8097,25 @@ static void ggml_compute_forward_flash_attn_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int neq0 = q->ne[0];
-    const int neq1 = q->ne[1];
-    const int neq2 = q->ne[2];
-    const int neq3 = q->ne[3];
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
 
-    const int nek0 = k->ne[0];
-    const int nek1 = k->ne[1];
-    //const int nek2 = k->ne[2];
-    //const int nek3 = k->ne[3];
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
 
-    //const int nev0 = v->ne[0];
-    const int nev1 = v->ne[1];
-    //const int nev2 = v->ne[2];
-    //const int nev3 = v->ne[3];
+    //const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
 
-    const int ne0  = dst->ne[0];
-    const int ne1  = dst->ne[1];
-    //const int ne2  = dst->ne[2];
-    //const int ne3  = dst->ne[3];
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    //const int64_t ne2  = dst->ne[2];
+    //const int64_t ne3  = dst->ne[3];
 
     const int nbk0 = k->nb[0];
     const int nbk1 = k->nb[1];
@@ -7739,10 +8140,10 @@ static void ggml_compute_forward_flash_attn_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int D = neq0;
-    const int N = neq1;
-    const int P = nek1 - N;
-    const int M = P + N;
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
 
     const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
 
@@ -7788,7 +8189,7 @@ static void ggml_compute_forward_flash_attn_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    const float scale = 1.0/sqrt((double) D);
+    const float scale = 1.0f/sqrtf(D);
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
@@ -7804,7 +8205,7 @@ static void ggml_compute_forward_flash_attn_f32(
             S[i] = -INFINITY;
         }
 
-        for (int ic = 0; ic < nek1; ++ic) {
+        for (int64_t ic = 0; ic < nek1; ++ic) {
             // k indices
             const int ik3 = iq3;
             const int ik2 = iq2;
@@ -7823,7 +8224,7 @@ static void ggml_compute_forward_flash_attn_f32(
         ggml_vec_scale_f32(nek1, S, scale);
 
         if (masked) {
-            for (int i = P; i < M; i++) {
+            for (int64_t i = P; i < M; i++) {
                 if (i > P + iq1) {
                     S[i] = -INFINITY;
                 }
@@ -7835,7 +8236,7 @@ static void ggml_compute_forward_flash_attn_f32(
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
 
-            float sum = 0.0f;
+            ggml_float sum = 0.0;
             {
 #ifdef GGML_SOFT_MAX_ACCELERATE
                 max = -max;
@@ -7856,7 +8257,7 @@ static void ggml_compute_forward_flash_attn_f32(
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += val;
+                            sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
                     }
@@ -7868,7 +8269,7 @@ static void ggml_compute_forward_flash_attn_f32(
 #endif
             }
 
-            assert(sum > 0.0f);
+            assert(sum > 0.0);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
@@ -7881,7 +8282,7 @@ static void ggml_compute_forward_flash_attn_f32(
 #endif
         }
 
-        for (int ic = 0; ic < nev1; ++ic) {
+        for (int64_t ic = 0; ic < nev1; ++ic) {
             // dst indices
             const int i1 = iq1;
             const int i2 = iq2;
@@ -7905,25 +8306,25 @@ static void ggml_compute_forward_flash_attn_f16(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int neq0 = q->ne[0];
-    const int neq1 = q->ne[1];
-    const int neq2 = q->ne[2];
-    const int neq3 = q->ne[3];
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
 
-    const int nek0 = k->ne[0];
-    const int nek1 = k->ne[1];
-    //const int nek2 = k->ne[2];
-    //const int nek3 = k->ne[3];
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
 
-    //const int nev0 = v->ne[0];
-    const int nev1 = v->ne[1];
-    //const int nev2 = v->ne[2];
-    //const int nev3 = v->ne[3];
+    //const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
 
-    const int ne0  = dst->ne[0];
-    const int ne1  = dst->ne[1];
-    //const int ne2  = dst->ne[2];
-    //const int ne3  = dst->ne[3];
+    const int64_t ne0  = dst->ne[0];
+    const int64_t ne1  = dst->ne[1];
+    //const int64_t ne2  = dst->ne[2];
+    //const int64_t ne3  = dst->ne[3];
 
     const int nbk0 = k->nb[0];
     const int nbk1 = k->nb[1];
@@ -7948,10 +8349,10 @@ static void ggml_compute_forward_flash_attn_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int D = neq0;
-    const int N = neq1;
-    const int P = nek1 - N;
-    const int M = P + N;
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
 
     const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
 
@@ -7997,7 +8398,7 @@ static void ggml_compute_forward_flash_attn_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    const float scale = 1.0/sqrt((double) D);
+    const float scale = 1.0f/sqrtf(D);
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
@@ -8014,7 +8415,7 @@ static void ggml_compute_forward_flash_attn_f16(
         }
 
         if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
-            for (int ic = 0; ic < nek1; ++ic) {
+            for (int64_t ic = 0; ic < nek1; ++ic) {
                 // k indices
                 const int ik3 = iq3;
                 const int ik2 = iq2;
@@ -8029,7 +8430,7 @@ static void ggml_compute_forward_flash_attn_f16(
                         (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
             }
         } else {
-            for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
+            for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
                 // k indices
                 const int ik3 = iq3;
                 const int ik2 = iq2;
@@ -8049,7 +8450,7 @@ static void ggml_compute_forward_flash_attn_f16(
         ggml_vec_scale_f32(nek1, S, scale);
 
         if (masked) {
-            for (int i = P; i < M; i++) {
+            for (int64_t i = P; i < M; i++) {
                 if (i > P + iq1) {
                     S[i] = -INFINITY;
                 }
@@ -8061,7 +8462,7 @@ static void ggml_compute_forward_flash_attn_f16(
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
 
-            float sum = 0.0f;
+            ggml_float sum = 0.0;
             {
 #ifdef GGML_SOFT_MAX_ACCELERATE
                 max = -max;
@@ -8082,7 +8483,7 @@ static void ggml_compute_forward_flash_attn_f16(
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += val;
+                            sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
                     }
@@ -8094,7 +8495,7 @@ static void ggml_compute_forward_flash_attn_f16(
 #endif
             }
 
-            assert(sum > 0.0f);
+            assert(sum > 0.0);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
@@ -8109,12 +8510,12 @@ static void ggml_compute_forward_flash_attn_f16(
 
         ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
 
-        for (int i = 0; i < M; i++) {
+        for (int64_t i = 0; i < M; i++) {
             S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
         if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
-            for (int ic = 0; ic < nev1; ++ic) {
+            for (int64_t ic = 0; ic < nev1; ++ic) {
                 // dst indices
                 const int i1 = iq1;
                 const int i2 = iq2;
@@ -8126,7 +8527,7 @@ static void ggml_compute_forward_flash_attn_f16(
                         S16);
             }
         } else {
-            for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
+            for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
                 // dst indices
                 const int i1 = iq1;
                 const int i2 = iq2;
@@ -8182,35 +8583,35 @@ static void ggml_compute_forward_flash_ff_f16(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int nea0 = a->ne[0];
-    const int nea1 = a->ne[1];
-    const int nea2 = a->ne[2];
-    const int nea3 = a->ne[3];
+    const int64_t nea0 = a->ne[0];
+    const int64_t nea1 = a->ne[1];
+    const int64_t nea2 = a->ne[2];
+    const int64_t nea3 = a->ne[3];
 
-    const int neb00 = b0->ne[0];
-    const int neb01 = b0->ne[1];
-    //const int neb02 = b0->ne[2];
-    //const int neb03 = b0->ne[3];
+    const int64_t neb00 = b0->ne[0];
+    const int64_t neb01 = b0->ne[1];
+    //const int64_t neb02 = b0->ne[2];
+    //const int64_t neb03 = b0->ne[3];
 
-    const int neb10 = b1->ne[0];
-    const int neb11 = b1->ne[1];
-    //const int neb12 = b1->ne[2];
-    //const int neb13 = b1->ne[3];
+    const int64_t neb10 = b1->ne[0];
+    const int64_t neb11 = b1->ne[1];
+    //const int64_t neb12 = b1->ne[2];
+    //const int64_t neb13 = b1->ne[3];
 
-    const int nec00 = c0->ne[0];
-    const int nec01 = c0->ne[1];
-    //const int nec02 = c0->ne[2];
-    //const int nec03 = c0->ne[3];
+    const int64_t nec00 = c0->ne[0];
+    const int64_t nec01 = c0->ne[1];
+    //const int64_t nec02 = c0->ne[2];
+    //const int64_t nec03 = c0->ne[3];
 
-    const int nec10 = c1->ne[0];
-    const int nec11 = c1->ne[1];
-    //const int nec12 = c1->ne[2];
-    //const int nec13 = c1->ne[3];
+    const int64_t nec10 = c1->ne[0];
+    const int64_t nec11 = c1->ne[1];
+    //const int64_t nec12 = c1->ne[2];
+    //const int64_t nec13 = c1->ne[3];
 
-    const int ne0 = dst->ne[0];
-    const int ne1 = dst->ne[1];
-    const int ne2 = dst->ne[2];
-    //const int ne3 = dst->ne[3];
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    //const int64_t ne3 = dst->ne[3];
 
     const int nba0 = a->nb[0];
     const int nba1 = a->nb[1];
@@ -8245,9 +8646,9 @@ static void ggml_compute_forward_flash_ff_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int D = nea0;
-    //const int N = nea1;
-    const int M = neb01;
+    const int64_t D = nea0;
+    //const int64_t N = nea1;
+    const int64_t M = neb01;
 
     GGML_ASSERT(ne0 == nea0);
     GGML_ASSERT(ne1 == nea1);
@@ -8303,7 +8704,7 @@ static void ggml_compute_forward_flash_ff_f16(
 
         float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
 
-        for (int ic = 0; ic < neb01; ++ic) {
+        for (int64_t ic = 0; ic < neb01; ++ic) {
             // b0 indices
             const int ib03 = ia3;
             const int ib02 = ia2;
@@ -8323,7 +8724,7 @@ static void ggml_compute_forward_flash_ff_f16(
 
         ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
 
-        for (int i = 0; i < M; i++) {
+        for (int64_t i = 0; i < M; i++) {
             S16[i] = GGML_FP32_TO_FP16(S[i]);
         }
 
@@ -8335,7 +8736,7 @@ static void ggml_compute_forward_flash_ff_f16(
             const int i2 = ia2;
             const int i3 = ia3;
 
-            for (int ic = 0; ic < nec01; ++ic) {
+            for (int64_t ic = 0; ic < nec01; ++ic) {
 
                 ggml_vec_dot_f16(neb01,
                         (float *)       ((char *) dst->data + (ic*nb0 + i1*nb1   + i2*nb2   + i3*nb3)),
@@ -8474,6 +8875,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_cpy(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_CONT:
+            {
+                ggml_compute_forward_cont(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_RESHAPE:
             {
                 ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8718,8 +9123,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad =
                         ggml_add_impl(ctx,
                                 src1->grad,
-                                // TODO: fix transpose, the node will break the graph connections
-                                ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+                                ggml_mul_mat(ctx,
+                                    ggml_cont(ctx, ggml_transpose(ctx, src0)),
+                                    tensor->grad),
                                 inplace);
                 }
             } break;
@@ -8731,6 +9137,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONT:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_RESHAPE:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -9147,8 +9557,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         size_t cur = 0;
 
-                        if (node->src0->type == GGML_TYPE_F16 &&
-                                node->src1->type == GGML_TYPE_F32) {
+                        if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
@@ -9163,33 +9572,18 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #else
                             cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
 #endif
-                        } else if (node->src0->type == GGML_TYPE_F32 &&
-                                node->src1->type == GGML_TYPE_F32) {
+                        } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
-                        } else if (node->src0->type == GGML_TYPE_Q4_0 &&
-                                node->src1->type == GGML_TYPE_F32) {
+                        } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                            } else {
-                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
-                            }
-#else
-                            cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+                            } else
 #endif
-                        } else if (node->src0->type == GGML_TYPE_Q4_1 &&
-                                node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                node->n_tasks = 1;
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                            } else {
-                                cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+                            {
+                                cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
                             }
-#else
-                            cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
-#endif
                         } else {
                             GGML_ASSERT(false);
                         }
@@ -9201,6 +9595,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         node->n_tasks = n_threads;
                     } break;
                 case GGML_OP_CPY:
+                case GGML_OP_CONT:
                 case GGML_OP_RESHAPE:
                 case GGML_OP_VIEW:
                 case GGML_OP_PERMUTE:
@@ -9216,7 +9611,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_ROPE:
                     {
-                        node->n_tasks = 1;
+                        node->n_tasks = n_threads;
                     } break;
                 case GGML_OP_CONV_1D_1S:
                 case GGML_OP_CONV_1D_2S:
@@ -9254,7 +9649,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         size_t cur = 0;
 
-                        const int ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
 
                         if (node->src1->type == GGML_TYPE_F32) {
                             cur  = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
@@ -9513,7 +9908,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
 
         perf_total_per_op_us[node->op] += node->perf_time_us;
 
-        GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+        GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
                 GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -9527,7 +9922,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     for (int i = 0; i < cgraph->n_leafs; i++) {
         struct ggml_tensor * node = cgraph->leafs[i];
 
-        GGML_PRINT(" - %3d: [ %6d, %6d] %8s\n",
+        GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
                 i,
                 node->ne[0], node->ne[1],
                 GGML_OP_LABEL[node->op]);
@@ -9598,7 +9993,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
 
         fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
-label=\"%d [%d, %d] | <x>%s",
+label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
                 (void *) node, color,
                 i, node->ne[0], node->ne[1],
                 GGML_OP_SYMBOL[node->op]);
@@ -9619,11 +10014,11 @@ label=\"%d [%d, %d] | <x>%s",
             fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
 label=\"<x>%.1e\"; ]\n",
-                    (void *) node, color, ggml_get_f32_1d(node, 0));
+                    (void *) node, color, (double)ggml_get_f32_1d(node, 0));
         } else {
             fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
-label=\"<x>CONST %d [%d, %d]\"; ]\n",
+label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
                     (void *) node, color,
                     i, node->ne[0], node->ne[1]);
         }
@@ -9687,9 +10082,9 @@ label=\"<x>CONST %d [%d, %d]\"; ]\n",
 static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
-        const int ne = ggml_nelements(ps[p]) ;
+        const int64_t ne = ggml_nelements(ps[p]) ;
         // TODO: add function to set tensor from array
-        for (int j = 0; j < ne; ++j) {
+        for (int64_t j = 0; j < ne; ++j) {
             ggml_set_f32_1d(ps[p], j, x[i++]);
         }
     }
@@ -9698,9 +10093,9 @@ static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const f
 static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
-        const int ne = ggml_nelements(ps[p]) ;
+        const int64_t ne = ggml_nelements(ps[p]) ;
         // TODO: add function to get all elements at once
-        for (int j = 0; j < ne; ++j) {
+        for (int64_t j = 0; j < ne; ++j) {
             x[i++] = ggml_get_f32_1d(ps[p], j);
         }
     }
@@ -9709,9 +10104,9 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
 static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
     int i = 0;
     for (int p = 0; p < np; ++p) {
-        const int ne = ggml_nelements(ps[p]) ;
+        const int64_t ne = ggml_nelements(ps[p]) ;
         // TODO: add function to get all elements at once
-        for (int j = 0; j < ne; ++j) {
+        for (int64_t j = 0; j < ne; ++j) {
             g[i++] = ggml_get_f32_1d(ps[p]->grad, j);
         }
     }
@@ -9857,7 +10252,7 @@ static enum ggml_opt_result ggml_opt_adam(
             if (params.past <= t) {
                 const float rate = (pf[t%params.past] - fx)/fx;
 
-                if (fabs(rate) < params.delta) {
+                if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
@@ -9936,7 +10331,7 @@ static enum ggml_opt_result linesearch_backtracking(
     const float dec = 0.5f;
     const float inc = 2.1f;
 
-    if (*step <= 0.) {
+    if (*step <= 0.f) {
         return GGML_LINESEARCH_INVALID_PARAMETERS;
     }
 
@@ -10024,7 +10419,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_cgraph * gb) {
     if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
         params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
-        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
+        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
             return GGML_OPT_INVALID_WOLFE;
         }
     }
@@ -10145,8 +10540,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
 
         GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
 
-        if (xnorm < 1.0) {
-            xnorm = 1.0;
+        if (xnorm < 1.0f) {
+            xnorm = 1.0f;
         }
         if (gnorm/xnorm <= params.lbfgs.eps) {
             // converged
@@ -10159,7 +10554,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             if (params.past <= k) {
                 const float rate = (pf[k%params.past] - fx)/fx;
 
-                if (fabs(rate) < params.delta) {
+                if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
@@ -10309,6 +10704,7 @@ enum ggml_opt_result ggml_opt(
         struct ggml_init_params params_ctx = {
             .mem_size   = 16*1024*1024,
             .mem_buffer = NULL,
+            .no_alloc   = false,
         };
 
         ctx = ggml_init(params_ctx);
@@ -10355,64 +10751,50 @@ enum ggml_opt_result ggml_opt(
 
 ////////////////////////////////////////////////////////////////////////////////
 
-size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist) {
-    const int nb = k / qk;
-    const size_t bs = (sizeof(float) + sizeof(uint8_t)*qk/2);
-    const size_t row_size = nb*bs;
-
-    assert(k % qk == 0);
-
-    char * pdst = (char *) dst;
+size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
 
     for (int j = 0; j < n; j += k) {
-        uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
-        uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
+        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
 
-        quantize_row_q4_0_reference(src + j, pd, k);
+        quantize_row_q4_0_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < qk; l += 2) {
-                const uint8_t vi0 = pb[l/2] & 0xF;
-                const uint8_t vi1 = pb[l/2] >> 4;
+            for (int l = 0; l < QK; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
                 hist[vi0]++;
                 hist[vi1]++;
             }
-            pb += bs;
         }
     }
 
-    return (n/k)*row_size;
+    return (n/QK*sizeof(block_q4_0));
 }
 
-size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist) {
-    const int nb = k / qk;
-    const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2);
-    const size_t row_size = nb*bs;
-
-    assert(k % qk == 0);
-
-    char * pdst = (char *) dst;
+size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK == 0);
+    const int nb = k / QK;
 
     for (int j = 0; j < n; j += k) {
-        uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
-        uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
+        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
 
-        quantize_row_q4_1(src + j, pd, k);
+        quantize_row_q4_1_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < qk; l += 2) {
-                const uint8_t vi0 = pb[l/2] & 0xF;
-                const uint8_t vi1 = pb[l/2] >> 4;
+            for (int l = 0; l < QK; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
                 hist[vi0]++;
                 hist[vi1]++;
             }
-            pb += bs;
         }
     }
 
-    return (n/k)*row_size;
+    return (n/QK*sizeof(block_q4_1));
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/ggml.h b/ggml.h
index ddb97318b33916243b5c0fcee6707c8890b18a86..a5245a8ae6256c0bee449cff0d9112e24df152dc 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -236,6 +236,7 @@ enum ggml_op {
 
     GGML_OP_SCALE,
     GGML_OP_CPY,
+    GGML_OP_CONT,
     GGML_OP_RESHAPE,
     GGML_OP_VIEW,
     GGML_OP_PERMUTE,
@@ -253,16 +254,29 @@ enum ggml_op {
     GGML_OP_COUNT,
 };
 
+
+// ggml object
+struct ggml_object {
+    size_t offs;
+    size_t size;
+
+    struct ggml_object * next;
+
+    char padding[8];
+};
+
+static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
 // n-dimensional tensor
 struct ggml_tensor {
     enum ggml_type type;
 
     int    n_dims;
-    int    ne[GGML_MAX_DIMS]; // number of elements
-    size_t nb[GGML_MAX_DIMS]; // stride in bytes:
-                              // nb[0] = sizeof(type)
-                              // nb[1] = nb[0]   * ne[0] + padding
-                              // nb[i] = nb[i-1] * ne[i-1]
+    int64_t ne[GGML_MAX_DIMS]; // number of elements
+    size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
+                               // nb[0] = sizeof(type)
+                               // nb[1] = nb[0]   * ne[0] + padding
+                               // nb[i] = nb[i-1] * ne[i-1]
 
     // compute data
     enum ggml_op op;
@@ -316,6 +330,7 @@ struct ggml_init_params {
     // memory pool
     size_t mem_size;   // bytes
     void * mem_buffer; // if NULL, memory will be allocated internally
+    bool   no_alloc;   // don't allocate memory for the tensor data
 };
 
 void    ggml_time_init(void); // call this once at the beginning of the program
@@ -327,8 +342,8 @@ int64_t ggml_cycles_per_ms(void);
 void ggml_print_object (const struct ggml_object * obj);
 void ggml_print_objects(const struct ggml_context * ctx);
 
-int    ggml_nelements(const struct ggml_tensor * tensor);
-size_t ggml_nbytes   (const struct ggml_tensor * tensor);
+int64_t ggml_nelements(const struct ggml_tensor * tensor);
+size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
 
 int    ggml_blck_size (enum ggml_type type);
 size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
@@ -343,40 +358,37 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
 
 size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
 
-bool ggml_mlock_supported(void);
-bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
-
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
         int    n_dims,
-        const int *ne);
+        const int64_t *ne);
 
 struct ggml_tensor * ggml_new_tensor_1d(
         struct ggml_context * ctx,
         enum   ggml_type type,
-        int    ne0);
+        int64_t ne0);
 
 struct ggml_tensor * ggml_new_tensor_2d(
         struct ggml_context * ctx,
         enum   ggml_type type,
-        int    ne0,
-        int    ne1);
+        int64_t ne0,
+        int64_t ne1);
 
 struct ggml_tensor * ggml_new_tensor_3d(
         struct ggml_context * ctx,
         enum   ggml_type type,
-        int    ne0,
-        int    ne1,
-        int    ne2);
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2);
 
 struct ggml_tensor * ggml_new_tensor_4d(
         struct ggml_context * ctx,
         enum   ggml_type type,
-        int    ne0,
-        int    ne1,
-        int    ne2,
-        int    ne3);
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3);
 
 struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
 struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
@@ -514,6 +526,11 @@ struct ggml_tensor * ggml_cpy(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b);
 
+// make contiguous
+struct ggml_tensor * ggml_cont(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 // return view(a), b specifies the new shape
 // TODO: when we start computing gradient, make a copy instead of view
 struct ggml_tensor * ggml_reshape(
@@ -526,33 +543,43 @@ struct ggml_tensor * ggml_reshape(
 struct ggml_tensor * ggml_reshape_2d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   ne0,
-        int                   ne1);
+        int64_t               ne0,
+        int64_t               ne1);
 
 // return view(a)
 // TODO: when we start computing gradient, make a copy instead of view
 struct ggml_tensor * ggml_reshape_3d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   ne0,
-        int                   ne1,
-        int                   ne2);
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2);
 
 // offset in bytes
 struct ggml_tensor * ggml_view_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   ne0,
+        int64_t               ne0,
         size_t                offset);
 
 struct ggml_tensor * ggml_view_2d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   ne0,
-        int                   ne1,
+        int64_t               ne0,
+        int64_t               ne1,
         size_t                nb1, // row stride in bytes
         size_t                offset);
 
+struct ggml_tensor * ggml_view_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        size_t                nb1, // row   stride in bytes
+        size_t                nb2, // slice stride in bytes
+        size_t                offset);
+
 struct ggml_tensor * ggml_permute(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -748,8 +775,8 @@ enum ggml_opt_result ggml_opt(
 // quantization
 //
 
-size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
-size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
+size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
+size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
 
 //
 // system info
@@ -768,6 +795,30 @@ int ggml_cpu_has_blas(void);
 int ggml_cpu_has_sse3(void);
 int ggml_cpu_has_vsx(void);
 
+
+//
+// Internal types and functions exposed for tests and benchmarks
+//
+
+#ifdef  __cplusplus
+// restrict not standard in C++
+#define GGML_RESTRICT
+#else
+#define GGML_RESTRICT restrict
+#endif
+typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+typedef void (*quantize_row_q_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+typedef void (*vec_dot_q_t)(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
+
+typedef struct {
+    dequantize_row_q_t dequantize_row_q;
+    quantize_row_q_t   quantize_row_q;
+    quantize_row_q_t   quantize_row_q_reference;
+    vec_dot_q_t        vec_dot_q;
+} quantize_fns_t;
+
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
+
 #ifdef  __cplusplus
 }
 #endif
index 95b6d33905d16c487c0c1d495a6f76443be44821..24b9e5d8bd042c5ee9969da44058f68d1a675a09 100644 (file)
@@ -654,9 +654,11 @@ static bool kv_cache_init(
                                  int   n_ctx) {
     cache.buf.resize(mem_bytes);
 
-    struct ggml_init_params params;
-    params.mem_size   = cache.buf.size();
-    params.mem_buffer = cache.buf.data();
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ cache.buf.size(),
+        /*.mem_buffer =*/ cache.buf.data(),
+        /*.no_alloc   =*/ false,
+    };
 
     cache.ctx = ggml_init(params);
 
@@ -688,9 +690,11 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
 
     WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
 
-    struct ggml_init_params params;
-    params.mem_size   = cache.buf.size();
-    params.mem_buffer = cache.buf.data();
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ cache.buf.size(),
+        /*.mem_buffer =*/ cache.buf.data(),
+        /*.no_alloc   =*/ false,
+    };
 
     cache.ctx = ggml_init(params);
 
@@ -1028,9 +1032,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
     // create the ggml context
     {
-        struct ggml_init_params params;
-        params.mem_size   = wctx.model.buf->size();
-        params.mem_buffer = wctx.model.buf->data();
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ wctx.model.buf->size(),
+            /*.mem_buffer =*/ wctx.model.buf->data(),
+            /*.no_alloc   =*/ false,
+        };
 
         model.ctx = ggml_init(params);
         if (!model.ctx) {
@@ -1254,10 +1260,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 break;
             }
 
-            int32_t nelements = 1;
-            int32_t ne[3] = { 1, 1, 1 };
+            int64_t nelements = 1;
+            int64_t ne[3] = { 1, 1, 1 };
             for (int i = 0; i < n_dims; ++i) {
-                read_safe(loader, ne[i]);
+                int32_t ne_cur;
+                read_safe(loader, ne_cur);
+                ne[i] = ne_cur;
                 nelements *= ne[i];
             }
 
@@ -1278,7 +1286,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             }
 
             if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%lld, %lld, %lld], expected [%lld, %lld, %lld]\n",
                         __func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
                 return false;
             }
@@ -1286,7 +1294,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
 
             if (nelements*bpe != ggml_nbytes(tensor)) {
-                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %llu\n",
                         __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                 return false;
             }
@@ -1344,9 +1352,11 @@ static bool whisper_encode_internal(
     const int n_mels = hparams.n_mels;
     assert(mel_inp.n_mel == n_mels);
 
-    struct ggml_init_params params;
-    params.mem_size   = wstate.buf_compute.size();
-    params.mem_buffer = wstate.buf_compute.data();
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ wstate.buf_compute.size(),
+        /*.mem_buffer =*/ wstate.buf_compute.data(),
+        /*.no_alloc   =*/ false,
+    };
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1501,8 +1511,7 @@ static bool whisper_encode_internal(
                                 Vcur,
                                 n_state/n_head, n_head, n_ctx),
                             1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
-                        );
+                        ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head));
 
             struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
 #else
@@ -1726,10 +1735,12 @@ static bool whisper_encode_internal(
 
             wstate.use_buf(ctx0, -1);
 
-            //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx));
-            struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
-            struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx));
+            Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
+
+            struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
+            struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
+                    (   n_ctx)*ggml_element_size(wstate.kv_cross.v),
+                    (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
 
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
@@ -1797,9 +1808,11 @@ static bool whisper_decode_internal(
 
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
-    struct ggml_init_params params;
-    params.mem_size   = wstate.buf_compute.size();
-    params.mem_buffer = wstate.buf_compute.data();
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ wstate.buf_compute.size(),
+        /*.mem_buffer =*/ wstate.buf_compute.data(),
+        /*.no_alloc   =*/ false,
+    };
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1862,20 +1875,24 @@ static bool whisper_decode_internal(
 
             Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
-                    layer.attn_v_w,
-                    cur);
-
-            Vcur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        layer.attn_v_b,
-                        Vcur),
-                    Vcur);
-
             // store key and value to memory
             {
+                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+                        layer.attn_v_w,
+                        cur);
+
+                Vcur = ggml_add(ctx0,
+                        ggml_repeat(ctx0,
+                            layer.attn_v_b,
+                            Vcur),
+                        Vcur);
+
+                Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+
                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+                        (   n_ctx)*ggml_element_size(kv_self.v),
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
@@ -1914,16 +1931,14 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            struct ggml_tensor * V_trans =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
-                                n_state/n_head, n_head, n_past + N),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head));
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, kv_self.v,
+                        n_past + N, n_state/n_head, n_head,
+                        n_ctx*ggml_element_size(kv_self.v),
+                        n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
+                        il*n_ctx*ggml_element_size(kv_self.v)*n_state);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
@@ -1986,15 +2001,22 @@ static bool whisper_decode_internal(
                         ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
                         n_state/n_head, n_head, M);
 
-            struct ggml_tensor * Vcross =
-                ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
-                        n_state/n_head, n_head, M);
+            //struct ggml_tensor * Vcross =
+            //    ggml_reshape_3d(ctx0,
+            //            ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
+            //            n_state/n_head, n_head, M);
 
-            struct ggml_tensor * V_trans =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+            //struct ggml_tensor * V_trans =
+            //    ggml_cpy(ctx0,
+            //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
+            //            ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, wstate.kv_cross.v,
+                        M, n_state/n_head, n_head,
+                        M*ggml_element_size(wstate.kv_cross.v),
+                        M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+                        il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
 
             // ------
 
@@ -2021,7 +2043,7 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
@@ -4726,6 +4748,7 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
             struct ggml_init_params gparams = {
                 /*.mem_size   =*/ buf.size(),
                 /*.mem_buffer =*/ buf.data(),
+                /*.no_alloc   =*/ false,
             };
 
             struct ggml_context * ctx0 = ggml_init(gparams);