]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (NUMA + thread improvements + k-quants)
authorGeorgi Gerganov <redacted>
Mon, 26 Jun 2023 18:10:24 +0000 (21:10 +0300)
committerGeorgi Gerganov <redacted>
Mon, 26 Jun 2023 18:10:24 +0000 (21:10 +0300)
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c

index 5ebd9c46c3d5267df4e2a879cfcb633bd23c6085..6b106b1c39f6349dbdac1f08b18bfe988b990360 100644 (file)
@@ -469,6 +469,9 @@ extern "C" {
     GGML_API int64_t ggml_cycles(void);
     GGML_API int64_t ggml_cycles_per_ms(void);
 
+    GGML_API void    ggml_numa_init(void); // call once for better performance on NUMA systems
+    GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+
     GGML_API void    ggml_print_object (const struct ggml_object * obj);
     GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
index 010682edb703cbbf1282b0ab76e5298132662f49..c34e96abfd08c105407164f244d11d6622af4420 100644 (file)
@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 //================================= k-quants
 
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
 #define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -128,13 +134,25 @@ typedef struct {
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
 typedef struct {
-    uint8_t hmask[QK_K/8];
-    uint8_t qs[QK_K/4]; // nibbles / quants
-    uint8_t scales[3*QK_K/64];
-    half d;
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+#ifdef GGML_QKK_64
+    uint8_t scales[2]; // scales, quantized with 8 bits
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;             // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
 
+#ifdef GGML_QKK_64
+typedef struct {
+    half    d[2];              // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
 typedef struct {
     half d;                    // super-block scale for quantized scales
     half dmin;                 // super-block scale for quantized mins
@@ -142,15 +160,26 @@ typedef struct {
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
+#ifdef GGML_QKK_64
 typedef struct {
-    half    d;                   // super-block scale for quantized scales
-    half    dmin;                // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64];   // scales, quantized with 6 bits
+    half d;                  // super-block scale
+    int8_t scales[QK_K/16];  // block scales
+    uint8_t qh[QK_K/8];      // quants, high bit
+    uint8_t qs[QK_K/2];      // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+    half d;               // super-block scale for quantized scales
+    half dmin;            // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE];   // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2];   // quants, lower 4 bits
@@ -349,13 +378,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
     const int i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
     const int tid = threadIdx.x;
+#if QK_K == 256
     const int n   = tid/32;
     const int l   = tid - 32*n;
     const int is  = 8*n + l/16;
 
-    const block_q2_K * x = (const block_q2_K *) vx;
-
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
@@ -365,21 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16;  // 0 or 1
+    const int il = tid%16;  // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    float * y = yy + i*QK_K + 16*is + il;
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
 
 }
 
 static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
-    int r = threadIdx.x/4;
-    int i = blockIdx.x;
-    int tid = r/2;
-    int is0 = r%2;
-    int l0 = 16*is0 + 4*(threadIdx.x%4);
-    int n = tid / 4;
-    int j = tid - 4*n;
-
+    const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
 
+#if QK_K == 256
+    const int r = threadIdx.x/4;
+    const int tid = r/2;
+    const int is0 = r%2;
+    const int l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int n = tid / 4;
+    const int j = tid - 4*n;
+
     uint8_t m = 1 << (4*n + j);
     int is = 8*n + 2*j + is0;
     int shift = 2*j;
@@ -396,9 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
     const uint8_t * hm = x[i].hmask;
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+    const int tid = threadIdx.x;
+    const int is  = tid/16;  // 0 or 1
+    const int il  = tid%16;  // 0...15
+    const int im  = il/8;    // 0...1
+    const int in  = il%8;    // 0...7
+
+    float * y = yy + i*QK_K + 16*is + il;
+
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    const uint8_t h = x[i].hmask[in] >> (2*is + im);
+    const float   d = (float)x[i].d;
+
+    if (is == 0) {
+        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    } else {
+        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    }
+#endif
 
 }
 
+#if QK_K == 256
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
         d = q[j] & 63; m = q[j + 4] & 63;
@@ -407,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
         m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
     }
 }
+#endif
 
 static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
 
-    //// assume 64 threads - this is very slightly better than the one below
-    //const int tid = threadIdx.x;
-    //const int il  = tid/16;
-    //const int ir  = tid%16;
-    //const int is  = 2*il;
-    //const int n   = 2;
-
+#if QK_K == 256
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il  = tid/8;
@@ -443,6 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
         y[l + 0] = d1 * (q[l] & 0xF) - m1;
         y[l +32] = d2 * (q[l] >>  4) - m2;
     }
+#else
+    const int tid = threadIdx.x;
+    const uint8_t * q = x[i].qs;
+    float * y = yy + i*QK_K;
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
+#endif
 }
 
 static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -450,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 
     const int i = blockIdx.x;
 
+#if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
     const int il  = tid/16;   // il is in 0...3
@@ -476,12 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
     hm <<= 1;
     y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+    const int tid = threadIdx.x;
+    const uint8_t q = x[i].qs[tid];
+    const int im = tid/8;  // 0...3
+    const int in = tid%8;  // 0...7
+    const int is = tid/16; // 0 or 1
+    const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
+    float * y = yy + i*QK_K + tid;
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
 }
 
 static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
+#if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
@@ -501,6 +582,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
     y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
     y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int ip  = tid/16;         // 0 or 1
+    const int il  = tid - 16*ip;    // 0...15
+
+    float * y = yy + i*QK_K + 16*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t   ql = x[i].ql[16*ip + il];
+    const uint8_t   qh = x[i].qh[il] >> (2*ip);
+    const int8_t  * sc = x[i].scales;
+
+    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[ip+2] * ((int8_t)((ql  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
 }
 
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
@@ -515,6 +614,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
 
     const block_q2_K * x = (const block_q2_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -528,8 +630,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
 
-    float tmp = 0; // partial sum for thread in warp
-
     uint32_t aux[4];
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -565,6 +665,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += dall * sum1 - dmin * sum2;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;
+
+    uint32_t uaux[2];
+    const uint8_t * d = (const uint8_t *)uaux;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint32_t * s = (const uint32_t *)x[i].scales;
+
+        uaux[0] = s[0] & 0x0f0f0f0f;
+        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+        const half2 * dh = (const half2 *)&x[i].d;
+
+        const float2 dall = __half22float2(dh[0]);
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t ql = q[l];
+            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                  + y[l+16] * d[1] * ((ql >> 2) & 3)
+                  + y[l+32] * d[2] * ((ql >> 4) & 3)
+                  + y[l+48] * d[3] * ((ql >> 6) & 3);
+            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+        }
+        tmp += dall.x * sum1 - dall.y * sum2;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -573,16 +706,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
@@ -591,6 +721,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const block_q3_K * x = (const block_q3_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -610,8 +747,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const uint16_t s_shift = 4*im;
 
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
         const float   * y  = yy + i * QK_K + y_offset;
@@ -640,6 +775,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += d * sum;
 
     }
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14
+    const int in = offset/8;                                 // 0 or 1
+    const int im = offset%8;                                 // 0...7
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint8_t * s = x[i].scales;
+
+        const float dall = (float)x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t hl = x[i].hmask[im+l] >> in;
+            const uint8_t ql = q[l];
+            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+                 + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+                 + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        }
+        tmp += sum;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -648,22 +811,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -683,8 +849,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
     float tmp = 0; // partial sum for thread in warp
 
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -713,6 +877,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
         tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    float tmp = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const float   * y = yy + i*QK_K + step;
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
+        }
+        tmp += sum;
+    }
+
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -728,15 +922,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 
 static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/2;  // 0...15
     const int ix  = threadIdx.x%2;
 
@@ -757,10 +955,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const uint8_t * ql1 = x[i].qs + q_offset;
@@ -793,8 +987,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
         tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+    }
 
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int im = step/8;
+    const int in = step%8;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const int8_t  * s = x[i].scales;
+        const float   * y = yy + i*QK_K + step;
+        const float     d = x[i].d;
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            const uint8_t h = x[i].qh[in+j] >> im;
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
+        }
+        tmp += sum;
     }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -803,7 +1020,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
@@ -820,6 +1037,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     const block_q6_K * x = (const block_q6_K *)vx + ib0;
 
+#if QK_K == 256
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
 
@@ -874,6 +1093,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     }
 
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...3
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + step;
+        const uint8_t * ql = x[i].ql + step;
+        const uint8_t * qh = x[i].qh + step;
+        const int8_t  * s  = x[i].scales;
+
+        const float d = x[i+0].d;
+
+        float sum = 0;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
+                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
+        }
+        tmp += sum;
+
+    }
+
+#endif
+
     // sum up partial sums and write back result
     __syncthreads();
 #pragma unroll
@@ -1252,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1267,12 +1525,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -2553,6 +2819,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
+    memset(extra, 0, sizeof(*extra));
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
         tensor->op == GGML_OP_VIEW;
index a7e104dc76fcafbe7bb405dfd235778eb79d0388..7551231b9cf32cf2705de66fce9a15f339943370 100644 (file)
@@ -51,21 +51,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
-    GGML_METAL_DECL_KERNEL(get_rows_q2_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q3_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q4_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q5_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q6_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
             exit(1);
         }
 
+#ifdef GGML_QKK_64
+        MTLCompileOptions* options = [MTLCompileOptions new];
+        options.preprocessorMacros = @{ @"QK_K" : @(64) };
+        ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+#else
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+#endif
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
             exit(1);
@@ -159,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
-        GGML_METAL_ADD_KERNEL(get_rows_q2_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q3_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q4_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q5_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q6_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q2_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q3_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -662,7 +668,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
@@ -671,7 +677,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
@@ -680,7 +686,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
@@ -689,7 +695,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
@@ -698,7 +704,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                         } break;
                                     default:
                                         {
@@ -750,11 +756,11 @@ void ggml_metal_graph_compute(
                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
-                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
-                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
-                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
-                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
-                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
+                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
+                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
+                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
+                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
index d1e49222db2eb6c9c7614978fc88c349b86ed5e9..e62fe6842ea72bf950619935c74e9fecf1d10799 100644 (file)
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
 }
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
 }
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
 
 //============================================ k-quants ======================================================
 
+#ifndef QK_K
 #define QK_K 256
+#else
+static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
+#endif
+
+#if QK_K == 256
+#define K_SCALE_SIZE 12
+#else
+#define K_SCALE_SIZE 4
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
     half d;           // super-block scale for quantized scales
     half dmin;        // super-block scale for quantized mins
-} block_q2_k;
+} block_q2_K;
 // 84 bytes / block
 
 typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
-    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
-    half d;                    // super-block scale
-} block_q3_k;
-// 110 bytes / block
-
+#if QK_K == 64
+    uint8_t scales[2];
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;             // super-block scale
+} block_q3_K;
+
+#if QK_K == 64
+typedef struct {
+    half    d[2];          // super-block scales/mins
+    uint8_t scales[2];
+    uint8_t qs[QK_K/2];    // 4-bit quants
+} block_q4_K;
+#else
 typedef struct {
     half d;             // super-block scale for quantized scales
     half dmin;          // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
-} block_q4_k;
-// 144 bytes / block
+} block_q4_K;
+#endif
 
+#if QK_K == 64
+typedef struct {
+    half  d;                     // super-block scales/mins
+    int8_t  scales[QK_K/16];     // 8-bit block scales
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+#else
 typedef struct {
     half d;                      // super-block scale for quantized scales
     half dmin;                   // super-block scale for quantized mins
     uint8_t scales[3*QK_K/64];   // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
-} block_q5_k;
+} block_q5_K;
 // 176 bytes / block
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2];      // quants, lower 4 bits
     uint8_t qh[QK_K/4];      // quants, upper 2 bits
     int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
     half d;                  // super-block scale
-} block_q6_k;
+} block_q6_K;
 // 210 bytes / block
 
 static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //========================================== dequantization =============================
 
-static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
 
         device const uint8_t * q = x[i].qs;
 
+#if QK_K == 256
         int is = 0;
         float dl, ml;
         for (int n = 0; n < QK_K; n += 128) {
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
             }
             q += 32;
         }
+#else
+        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+        for (int l = 0; l < 16; ++l) {
+            y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
+            y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
+            y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
+            y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
+        }
+        y += QK_K;
+#endif
 
     }
 }
 
-static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
+
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
             }
             q += 32;
         }
+    }
+#else
+    for (int i = 0; i < nb; i++) {
 
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * hm = x[i].hmask;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l = 0; l < 8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
     }
+#endif
 
 }
 
-static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
-
     for (int i = 0; i < nb; i++) {
 
+        device const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
         const float d = x[i].d;
         const float min = x[i].dmin;
 
-        device const uint8_t * q = x[i].qs;
         device const uint8_t * scales = x[i].scales;
 
         int is = 0;
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
             for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
             q += 32; is += 2;
         }
+#else
+        device const uint8_t * s = x[i].scales;
+        device const half2 * dh = (device const half2 *)x[i].d;
+        const float2 d = (float2)dh[0];
+        const float d1 = d[0] * (s[0] & 0xF);
+        const float d2 = d[0] * (s[1] & 0xF);
+        const float m1 = d[1] * (s[0] >>  4);
+        const float m2 = d[1] * (s[1] >>  4);
+        for (int l = 0; l < 32; ++l) {
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >>  4) - m2;
+        }
+        y += QK_K;
+#endif
 
     }
 }
 
-static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
    for (int i = 0; i < nb; i++) {
 
         const float d = (float)(x[i].d);
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
             u1 <<= 2; u2 <<= 2;
         }
     }
+#else
+    for (int i = 0; i < nb; i++) {
+
+        const float d = (float)x[i].d;
+
+        device const uint8_t * ql = x[i].qs;
+        device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;
+
+        for (int l = 0; l < 8; ++l) {
+            y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * sc[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * sc[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * sc[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * sc[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
+        }
+        y += QK_K;
+    }
+#endif
 
 }
 
-static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
 
         const float d = x[i].d;
 
+#if QK_K == 256
         for (int n = 0; n < QK_K; n += 128) {
             for (int l = 0; l < 32; ++l) {
                 int is = l/16;
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
             qh += 32;
             sc += 8;
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            y[l+ 0] = d * sc[0] * q1;
+            y[l+16] = d * sc[1] * q2;
+            y[l+32] = d * sc[2] * q3;
+            y[l+48] = d * sc[3] * q4;
+        }
+        y  += 64;
+#endif
     }
 }
 
-kernel void kernel_get_rows_q2_k(
+kernel void kernel_get_rows_q2_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q2_k(
-            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q2_K(
+            (device const block_q2_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q3_k(
+kernel void kernel_get_rows_q3_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q3_k(
-            (device const block_q3_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q3_K(
+            (device const block_q3_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q4_k(
+kernel void kernel_get_rows_q4_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q4_k(
-            (device const block_q4_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q4_K(
+            (device const block_q4_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q5_k(
+kernel void kernel_get_rows_q5_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q5_k(
-            (device const block_q5_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q5_K(
+            (device const block_q5_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q6_k(
+kernel void kernel_get_rows_q6_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q6_k(
-            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q6_K(
+            (device const block_q6_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_k_f32(
+kernel void kernel_mul_mat_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
     const int tid = tpitg.y;    // 0...16
     const int il  = tid/4;      // 0...3
     const int ir  = tid%4;      // 0...3
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int y_offset = 64*il + n*ir;
     const int q_offset = 32*ip + n*ir;
 
-    sum[ith] = 0.0f;
-
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q = x[i].qs + q_offset;
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
 
         device const float   * y = yy + i*QK_K + y_offset;
 
-        //float4 s = {0.f, 0.f, 0.f, 0.f};
         float2 s = {0.f, 0.f};
         float smin = 0;
         for (int l = 0; l < n; ++l) {
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
         sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
 
     }
-    sum[ith] = sumf;
+#else
+    const int il = 4 * tpitg.x;
 
-    //int mask1 = (ith%4 == 0);
-    //int mask2 = (ith%16 == 0);
+    uint32_t aux[2];
+    thread const uint8_t * d = (thread const uint8_t *)aux;
+    thread const uint8_t * m = (thread const uint8_t *)aux + 4;
 
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //if (ith == 0) {
-    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        device const uint8_t * q = x[i].qs + il;
+        device const float   * y = yy + i*QK_K + il;
+
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        device const uint32_t * a = (device const uint32_t *)x[i].scales;
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+
+        for (int l = 0; l < 4; ++l) {
+            sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
+                  + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
+                  + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
+                  + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+        }
+    }
+#endif
+
+    sum[ith] = sumf;
 
     //
     // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
     }
 }
 
-kernel void kernel_mul_mat_q3_k_f32(
+kernel void kernel_mul_mat_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+#if QK_K == 256
+
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = tpitg.y;        // expecting 16
     const int ip  = tid/8;          // 0 or 1
     const int il  = tid/2 - 4*ip;   // 0...3
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
 
     //sum[ith] = sumf;
     sum[ith] = sumf1 - 32.f*sumf2;
+#else
+    const int il = 4 * tpitg.x;  // 0, 4, 8, 12
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4
+
+    float sumf = 0;
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + il;
+        device const uint8_t * h = x[i].hmask + in;
+        device const float   * y = yy + i * QK_K + il;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l = 0; l < 4; ++l) {
+            const uint8_t hm = h[l] >> im;
+            sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
+                  + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
+                  + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
+                  + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
+        }
+
+    }
+
+    sum[ith] = sumf;
+
+#endif
 
     //
     // Accumulate the sum from all threads in the threadgroup
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
 
 }
 
-kernel void kernel_mul_mat_q4_k_f32(
+kernel void kernel_mul_mat_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    float sumf = 0;
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
     const int ir  = tid - 4*il;// 0...3
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
     const int q_offset = 32*im + l0;
     const int y_offset = 64*im + l0;
 
-    sum[ith] = 0.0f;
-
     uchar2 sc1, sc2, sc3, sc4;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
         sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+#else
+    uint16_t aux16[2];
+    thread const uint8_t * scales = (thread const uint8_t *)aux16;
+
+    const int il  = 4*tpitg.x;
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        device const uint8_t * q = x[i].qs + il;
+        device const float   * y = yy + i * QK_K + il;
+
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        for (int l = 0; l < 4; ++l) {
+            sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
+                  + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
+        }
+    }
+#endif
 
     sum[ith] = sumf;
 
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
     //}
 }
 
-kernel void kernel_mul_mat_q5_k_f32(
+kernel void kernel_mul_mat_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
     const int ir  = tid - 4*il;// 0...3
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
 
     uchar2 sc1, sc2, sc3, sc4;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
         sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+#else
+    const int il  = 4 * tpitg.x;  // 0, 4, 8, 12
+    const int im  = il/8;         // 0, 0, 1, 1
+    const int in  = il%8;         // 0, 4, 0, 4
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        const float d = (float)x[i].d;
+        device const uint8_t * q = x[i].qs + il;
+        device const uint8_t * h = x[i].qh + in;
+        device const int8_t  * s = x[i].scales;
+        device const float   * y = yy + i*QK_K + il;
+
+        for (int l = 0; l < 4; ++l) {
+            const uint8_t hl = h[l] >> im;
+            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
+                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
+                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
+                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
+        }
+    }
+#endif
     sum[ith] = sumf;
 
     //
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_k_f32(
+kernel void kernel_mul_mat_q6_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+    device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
     // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
     const int iqs  = 16 * tpitg.y;
     const int ip   = iqs / 128;         // 0 or 1
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
     }
+#else
+    const int il  = 4*tpitg.x;    // 0, 4, 8, 12
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+        device const float * y = yy + i * QK_K + il;
+        device const uint8_t * ql = x[i].ql + il;
+        device const uint8_t * qh = x[i].qh + il;
+        device const int8_t  * s  = x[i].scales;
+
+        const float d = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 4; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >>  4) | ((qh[l] & kmask3) >> 0)) - 32);
+            sums[3] += y[l+48] * ((int8_t)((ql[l+16] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+        sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
+    }
+
+#endif
 
     sum[ith] = sumf;
 
index 1a441eb9874df068402b4a78d86f65c4a0cbdd63..c179bee9383880d4346c7d3875fefb52350a25cd 100644 (file)
@@ -1,3 +1,4 @@
+#define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
 
 #include "ggml.h"
@@ -90,6 +91,11 @@ static int sched_yield (void) {
 #include <stdatomic.h>
 
 typedef void* thread_ret_t;
+
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
 #endif
 
 // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
@@ -118,6 +124,30 @@ typedef void* thread_ret_t;
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL  2
 
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
 #ifdef GGML_USE_ACCELERATE
 // uncomment to use vDSP for soft max computation
 // note: not sure if it is actually faster
@@ -458,7 +488,6 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
     }
 }
 
-
 //
 // timing
 //
@@ -521,6 +550,7 @@ int64_t ggml_cycles_per_ms(void) {
 #define ggml_perf_cycles_per_ms() 0
 #endif
 
+
 //
 // cache line
 //
@@ -3842,12 +3872,31 @@ struct ggml_context_container {
     struct ggml_context context;
 };
 
+//
+// NUMA support
+//
+
+#define GGML_NUMA_MAX_NODES 8
+#define GGML_NUMA_MAX_CPUS 512
+
+struct ggml_numa_node {
+    uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
+    uint32_t n_cpus;
+};
+
+struct ggml_numa_nodes {
+    struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
+    uint32_t n_nodes;
+    uint32_t total_cpus; // hardware threads on system
+};
+
 //
 // ggml state
 //
 
 struct ggml_state {
     struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
+    struct ggml_numa_nodes numa;
 };
 
 // global state
@@ -3872,6 +3921,75 @@ inline static void ggml_critical_section_end(void) {
     atomic_fetch_sub(&g_state_barrier, 1);
 }
 
+void ggml_numa_init(void) {
+    if (g_state.numa.n_nodes > 0) {
+        fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
+
+        return;
+    }
+
+#ifdef __linux__
+    struct stat st;
+    char path[256];
+    int rv;
+
+    // enumerate nodes
+    while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
+        rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
+        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+        if (stat(path, &st) != 0) { break; }
+        ++g_state.numa.n_nodes;
+    }
+
+    // enumerate CPUs
+    while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
+        rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
+        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+        if (stat(path, &st) != 0) { break; }
+        ++g_state.numa.total_cpus;
+    }
+
+    GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
+
+    if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) {
+        g_state.numa.n_nodes = 0;
+        return;
+    }
+
+    for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
+        struct ggml_numa_node * node = &g_state.numa.nodes[n];
+        GGML_PRINT_DEBUG("CPUs on node %u:", n);
+        node->n_cpus = 0;
+        for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
+            rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
+            GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+            if (stat(path, &st) == 0) {
+                node->cpus[node->n_cpus++] = c;
+                GGML_PRINT_DEBUG(" %u", c);
+            }
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+
+    if (ggml_is_numa()) {
+        FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
+        if (fptr != NULL) {
+            char buf[42];
+            if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
+                GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
+            }
+            fclose(fptr);
+        }
+    }
+#else
+    // TODO
+#endif
+}
+
+bool ggml_is_numa(void) {
+    return g_state.numa.n_nodes > 1;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_print_object(const struct ggml_object * obj) {
@@ -4128,6 +4246,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
             g_state = (struct ggml_state) {
                 /*.contexts =*/ { { 0 } },
+                /*.numa =*/ {
+                    .n_nodes = 0,
+                    .total_cpus = 0,
+                },
             };
 
             for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
@@ -16502,68 +16624,172 @@ typedef pthread_t ggml_thread_t;
 
 #endif
 
+#ifdef __linux__
+void set_numa_thread_affinity(int thread_n, int n_threads) {
+    if (!ggml_is_numa()) {
+        return;
+    }
+
+    // run thread on node_num thread_n / (threads per node)
+    const int node_num = thread_n / ((n_threads + g_state.numa.n_nodes - 1) / g_state.numa.n_nodes);
+    struct ggml_numa_node * node = &g_state.numa.nodes[node_num];
+    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+    CPU_ZERO_S(setsize, cpus);
+    for (size_t i = 0; i < node->n_cpus; ++i) {
+        CPU_SET_S(node->cpus[i], setsize, cpus);
+    }
+
+    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+    if (rv) {
+            fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",
+                    strerror(rv));
+    }
+
+    CPU_FREE(cpus);
+}
+
+void clear_numa_thread_affinity(void) {
+    if (!ggml_is_numa()) {
+        return;
+    }
+
+    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+    CPU_ZERO_S(setsize, cpus);
+    for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
+        CPU_SET_S(i, setsize, cpus);
+    }
+
+    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+    if (rv) {
+        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",
+            strerror(rv));
+    }
+
+    CPU_FREE(cpus);
+}
+#else
+// TODO: Windows etc.
+// (the linux implementation may also work on BSD, someone should test)
+void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads);  }
+void clear_numa_thread_affinity(void) {}
+#endif
+
 struct ggml_compute_state_shared {
-    ggml_lock_t spin;
+    struct ggml_cgraph * cgraph;
+
+    int64_t perf_node_start_cycles;
+    int64_t perf_node_start_time_us;
 
     int n_threads;
 
     // synchronization primitives
-    atomic_int  n_ready;
-    atomic_bool has_work;
-    atomic_bool stop; // stop all threads
+    atomic_int n_active; // num active threads
+    atomic_int node_n;   // active graph node
 };
 
 struct ggml_compute_state {
     ggml_thread_t thrd;
-
-    struct ggml_compute_params params;
-    struct ggml_tensor * node;
-
+    int ith;
     struct ggml_compute_state_shared * shared;
 };
 
+static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
+    int64_t cycles_cur  = ggml_perf_cycles()  - st->perf_node_start_cycles;
+    int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
+
+    node->perf_runs++;
+    node->perf_cycles  += cycles_cur;
+    node->perf_time_us += time_us_cur;
+}
+
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+    struct ggml_cgraph * cgraph = state->shared->cgraph;
 
     const int n_threads = state->shared->n_threads;
+    set_numa_thread_affinity(state->ith, n_threads);
+
+    int node_n = -1;
 
     while (true) {
-        if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
-            atomic_store(&state->shared->has_work, false);
-        } else {
-            while (atomic_load(&state->shared->has_work)) {
-                if (atomic_load(&state->shared->stop)) {
-                    return 0;
-                }
-                ggml_lock_lock  (&state->shared->spin);
-                ggml_lock_unlock(&state->shared->spin);
+        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
+            // all other threads are finished and spinning
+            // do finalize and init here so we don't have synchronize again
+            struct ggml_compute_params params = {
+                /*.type  =*/ GGML_TASK_FINALIZE,
+                /*.ith   =*/ 0,
+                /*.nth   =*/ 0,
+                /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+            };
+
+            if (node_n != -1) {
+                /* FINALIZE */
+                struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
+                params.nth = node->n_tasks;
+                ggml_compute_forward(&params, node);
+                ggml_graph_compute_perf_stats_node(node, state->shared);
             }
-        }
 
-        atomic_fetch_sub(&state->shared->n_ready, 1);
+            // distribute new work or execute it direct if 1T
+            while (++node_n < cgraph->n_nodes) {
+                GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
+
+                struct ggml_tensor * node = cgraph->nodes[node_n];
+
+                state->shared->perf_node_start_cycles  = ggml_perf_cycles();
+                state->shared->perf_node_start_time_us = ggml_perf_time_us();
+
+                /* INIT */
+                params.type = GGML_TASK_INIT;
+                params.nth  = node->n_tasks;
+                ggml_compute_forward(&params, node);
 
-        // wait for work
-        while (!atomic_load(&state->shared->has_work)) {
-            if (atomic_load(&state->shared->stop)) {
-                return 0;
+                if (node->n_tasks == 1) {
+                    // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
+                    // they do something more efficient than spinning (?)
+                    params.type = GGML_TASK_COMPUTE;
+                    ggml_compute_forward(&params, node);
+
+                    params.type = GGML_TASK_FINALIZE;
+                    ggml_compute_forward(&params, node);
+                    ggml_graph_compute_perf_stats_node(node, state->shared);
+                } else {
+                    break;
+                }
             }
-            ggml_lock_lock  (&state->shared->spin);
-            ggml_lock_unlock(&state->shared->spin);
+
+            atomic_store(&state->shared->n_active, n_threads);
+            atomic_store(&state->shared->node_n,   node_n);
+        } else {
+            // wait for other threads to finish
+            const int last = node_n;
+            do {
+                sched_yield();
+                node_n = atomic_load(&state->shared->node_n);
+            } while (node_n == last);
         }
 
         // check if we should stop
-        if (atomic_load(&state->shared->stop)) {
-            break;
-        }
+        if (node_n >= cgraph->n_nodes) break;
 
-        if (state->node) {
-            if (state->params.ith < state->params.nth) {
-                ggml_compute_forward(&state->params, state->node);
-            }
+        /* COMPUTE */
+        struct ggml_tensor * node = cgraph->nodes[node_n];
 
-            state->node = NULL;
-        } else {
-            break;
+        struct ggml_compute_params params = {
+            /*.type  =*/ GGML_TASK_COMPUTE,
+            /*.ith   =*/ state->ith,
+            /*.nth   =*/ node->n_tasks,
+            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+        };
+
+        if (state->ith < node->n_tasks) {
+            ggml_compute_forward(&params, node);
         }
     }
 
@@ -16574,39 +16800,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
     const int n_threads = cgraph->n_threads;
 
     struct ggml_compute_state_shared state_shared = {
-        /*.spin      =*/ GGML_LOCK_INITIALIZER,
-        /*.n_threads =*/ n_threads,
-        /*.n_ready   =*/ 0,
-        /*.has_work  =*/ false,
-        /*.stop      =*/ false,
+        /*.cgraph                  =*/ cgraph,
+        /*.perf_node_start_cycles  =*/ 0,
+        /*.perf_node_start_time_us =*/ 0,
+        /*.n_threads               =*/ n_threads,
+        /*.n_active                =*/ n_threads,
+        /*.node_n                  =*/ -1,
     };
-    struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
-
-    // create thread pool
-    if (n_threads > 1) {
-        ggml_lock_init(&state_shared.spin);
-
-        atomic_store(&state_shared.has_work, true);
-
-        for (int j = 0; j < n_threads - 1; j++) {
-            workers[j] = (struct ggml_compute_state) {
-                .thrd   = 0,
-                .params = {
-                    .type  = GGML_TASK_COMPUTE,
-                    .ith   = j + 1,
-                    .nth   = n_threads,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                },
-                .node   = NULL,
-                .shared = &state_shared,
-            };
-
-            int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
-        }
-    }
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
 
     // initialize tasks + work buffer
     {
@@ -16750,7 +16951,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_SCALE:
                     {
-                        node->n_tasks = n_threads;
+                        node->n_tasks = 1;
                     } break;
                 case GGML_OP_SET:
                 case GGML_OP_CONT:
@@ -16954,166 +17155,37 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         }
     }
 
-    const int64_t perf_start_cycles  = ggml_perf_cycles();
-    const int64_t perf_start_time_us = ggml_perf_time_us();
-
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
-
-        struct ggml_tensor * node = cgraph->nodes[i];
-
-        // TODO: this could be used to avoid unnecessary computations, but it needs to be improved
-        //if (node->grad == NULL && node->perf_runs > 0) {
-        //    continue;
-        //}
-
-        const int64_t perf_node_start_cycles  = ggml_perf_cycles();
-        const int64_t perf_node_start_time_us = ggml_perf_time_us();
-
-        // INIT
-        struct ggml_compute_params params = {
-            /*.type  =*/ GGML_TASK_INIT,
-            /*.ith   =*/ 0,
-            /*.nth   =*/ node->n_tasks,
-            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
-        };
-
-        ggml_compute_forward(&params, node);
-
-        // COMPUTE
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
-
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            // launch thread pool
-            for (int j = 0; j < n_threads - 1; j++) {
-                workers[j].params = (struct ggml_compute_params) {
-                    .type  = GGML_TASK_COMPUTE,
-                    .ith   = j + 1,
-                    .nth   = node->n_tasks,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                };
-                workers[j].node = node;
-            }
-
-            atomic_fetch_sub(&state_shared.n_ready, 1);
-
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            atomic_store(&state_shared.has_work, true);
-        }
-
-        params.type = GGML_TASK_COMPUTE;
-        ggml_compute_forward(&params, node);
-
-        // wait for thread pool
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
-
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            atomic_fetch_sub(&state_shared.n_ready, 1);
-
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-        }
-
-        // FINALIZE
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
-
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            // launch thread pool
-            for (int j = 0; j < n_threads - 1; j++) {
-                workers[j].params = (struct ggml_compute_params) {
-                    .type  = GGML_TASK_FINALIZE,
-                    .ith   = j + 1,
-                    .nth   = node->n_tasks,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                };
-                workers[j].node = node;
-            }
-
-            atomic_fetch_sub(&state_shared.n_ready, 1);
-
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
+    // create thread pool
+    if (n_threads > 1) {
+        for (int j = 1; j < n_threads; ++j) {
+            workers[j] = (struct ggml_compute_state) {
+                .thrd   = 0,
+                .ith = j,
+                .shared = &state_shared,
+            };
 
-            atomic_store(&state_shared.has_work, true);
+            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+            GGML_ASSERT(rc == 0);
         }
+    }
+    workers[0].ith = 0;
+    workers[0].shared = &state_shared;
 
-        params.type = GGML_TASK_FINALIZE;
-        ggml_compute_forward(&params, node);
-
-        // wait for thread pool
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
-
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            atomic_fetch_sub(&state_shared.n_ready, 1);
-
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-        }
+    const int64_t perf_start_cycles  = ggml_perf_cycles();
+    const int64_t perf_start_time_us = ggml_perf_time_us();
 
-        // performance stats (node)
-        {
-            int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_node_start_cycles;
-            int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
+    // this is a work thread too
+    ggml_graph_compute_thread(&workers[0]);
 
-            node->perf_runs++;
-            node->perf_cycles  += perf_cycles_cur;
-            node->perf_time_us += perf_time_us_cur;
-        }
-    }
+    // don't leave affinity set on the main thread
+    clear_numa_thread_affinity();
 
     // join thread pool
     if (n_threads > 1) {
-        atomic_store(&state_shared.stop, true);
-        atomic_store(&state_shared.has_work, true);
-
-        for (int j = 0; j < n_threads - 1; j++) {
-            int rc = ggml_thread_join(workers[j].thrd, NULL);
+        for (int j = 1; j < n_threads; j++) {
+            const int rc = ggml_thread_join(workers[j].thrd, NULL);
             GGML_ASSERT(rc == 0);
-            UNUSED(rc);
         }
-
-        ggml_lock_destroy(&state_shared.spin);
     }
 
     // performance stats (graph)