]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : sync latest repo (mostly refactoring changes)
authorGeorgi Gerganov <redacted>
Sun, 2 Jul 2023 18:45:27 +0000 (21:45 +0300)
committerGeorgi Gerganov <redacted>
Sun, 2 Jul 2023 18:46:09 +0000 (21:46 +0300)
examples/common.cpp
examples/common.h
examples/quantize/quantize.cpp
ggml-cuda.cu
ggml-cuda.h
ggml-metal.m
ggml-metal.metal
ggml-opencl.cpp
ggml.c
ggml.h
whisper.cpp

index fe00278c2584e1c49d243420d3f69bab8c2262ca..960656def7aa9707793fbeef275d69fd14d04cbd 100644 (file)
@@ -39,6 +39,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
             params.temp = std::stof(argv[++i]);
+        } else if (arg == "--repeat-last-n") {
+            params.repeat_last_n = std::stof(argv[++i]);
+        } else if (arg == "--repeat-penalty") {
+            params.repeat_penalty = std::stof(argv[++i]);
         } else if (arg == "-b" || arg == "--batch_size") {
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
@@ -90,6 +94,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat-penalty N    penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
index 7e9b867d3d6f1f8af3a00140a153aa321c68d72f..12b2b339d670533476d5c3e8acfab140db04c0bb 100644 (file)
@@ -23,6 +23,8 @@ struct gpt_params {
     int32_t top_k = 40;
     float   top_p = 0.9f;
     float   temp  = 0.9f;
+    int32_t repeat_last_n  = 64;
+    float   repeat_penalty = 1.00f;
 
     int32_t n_batch = 8; // batch size for prompt processing
 
index 3df7b1c71232dd68175d330d1dffbf5da6e0e7a6..64e8f35c3863a6be0cfee2211c8be922f95699a7 100644 (file)
@@ -57,7 +57,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
     {
         uint32_t magic;
         finp.read((char *) &magic, sizeof(magic));
-        if (magic != 0x67676d6c) {
+        if (magic != GGML_FILE_MAGIC) {
             fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
             return false;
         }
index 36a251ecce973bddee7df398cd61931fbed22cf9..50df20edd7a7b211324e0af1c18e970bc250251d 100644 (file)
@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 //================================= k-quants
 
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
 #define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -128,13 +134,25 @@ typedef struct {
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
 typedef struct {
-    uint8_t hmask[QK_K/8];
-    uint8_t qs[QK_K/4]; // nibbles / quants
-    uint8_t scales[3*QK_K/64];
-    half d;
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+#ifdef GGML_QKK_64
+    uint8_t scales[2]; // scales, quantized with 8 bits
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;             // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
 
+#ifdef GGML_QKK_64
+typedef struct {
+    half    d[2];              // super-block scales/mins
+    uint8_t scales[2];         // 4-bit block scales/mins
+    uint8_t qs[QK_K/2];        // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
 typedef struct {
     half d;                    // super-block scale for quantized scales
     half dmin;                 // super-block scale for quantized mins
@@ -142,15 +160,26 @@ typedef struct {
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
+#ifdef GGML_QKK_64
 typedef struct {
-    half    d;                   // super-block scale for quantized scales
-    half    dmin;                // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64];   // scales, quantized with 6 bits
+    half d;                  // super-block scale
+    int8_t scales[QK_K/16];  // block scales
+    uint8_t qh[QK_K/8];      // quants, high bit
+    uint8_t qs[QK_K/2];      // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+    half d;               // super-block scale for quantized scales
+    half dmin;            // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE];   // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2];   // quants, lower 4 bits
@@ -185,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+};
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -194,6 +228,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i];
 }
 
+static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __hadd(x[i], __float2half(y[i]));
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -349,13 +392,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
     const int i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
     const int tid = threadIdx.x;
+#if QK_K == 256
     const int n   = tid/32;
     const int l   = tid - 32*n;
     const int is  = 8*n + l/16;
 
-    const block_q2_K * x = (const block_q2_K *) vx;
-
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
@@ -365,21 +409,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16;  // 0 or 1
+    const int il = tid%16;  // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    float * y = yy + i*QK_K + 16*is + il;
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
 
 }
 
 static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
-    int r = threadIdx.x/4;
-    int i = blockIdx.x;
-    int tid = r/2;
-    int is0 = r%2;
-    int l0 = 16*is0 + 4*(threadIdx.x%4);
-    int n = tid / 4;
-    int j = tid - 4*n;
-
+    const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
 
+#if QK_K == 256
+    const int r = threadIdx.x/4;
+    const int tid = r/2;
+    const int is0 = r%2;
+    const int l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int n = tid / 4;
+    const int j = tid - 4*n;
+
     uint8_t m = 1 << (4*n + j);
     int is = 8*n + 2*j + is0;
     int shift = 2*j;
@@ -396,9 +451,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
     const uint8_t * hm = x[i].hmask;
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+    const int tid = threadIdx.x;
+    const int is  = tid/16;  // 0 or 1
+    const int il  = tid%16;  // 0...15
+    const int im  = il/8;    // 0...1
+    const int in  = il%8;    // 0...7
+
+    float * y = yy + i*QK_K + 16*is + il;
+
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    const uint8_t h = x[i].hmask[in] >> (2*is + im);
+    const float   d = (float)x[i].d;
+
+    if (is == 0) {
+        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    } else {
+        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    }
+#endif
 
 }
 
+#if QK_K == 256
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
         d = q[j] & 63; m = q[j + 4] & 63;
@@ -407,19 +484,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
         m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
     }
 }
+#endif
 
 static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
 
-    //// assume 64 threads - this is very slightly better than the one below
-    //const int tid = threadIdx.x;
-    //const int il  = tid/16;
-    //const int ir  = tid%16;
-    //const int is  = 2*il;
-    //const int n   = 2;
-
+#if QK_K == 256
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il  = tid/8;
@@ -443,6 +515,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
         y[l + 0] = d1 * (q[l] & 0xF) - m1;
         y[l +32] = d2 * (q[l] >>  4) - m2;
     }
+#else
+    const int tid = threadIdx.x;
+    const uint8_t * q = x[i].qs;
+    float * y = yy + i*QK_K;
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
+#endif
 }
 
 static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -450,6 +531,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 
     const int i = blockIdx.x;
 
+#if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
     const int il  = tid/16;   // il is in 0...3
@@ -476,12 +558,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
     hm <<= 1;
     y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+    const int tid = threadIdx.x;
+    const uint8_t q = x[i].qs[tid];
+    const int im = tid/8;  // 0...3
+    const int in = tid%8;  // 0...7
+    const int is = tid/16; // 0 or 1
+    const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
+    float * y = yy + i*QK_K + tid;
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
 }
 
 static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
+#if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
@@ -501,6 +596,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
     y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
     y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int ip  = tid/16;         // 0 or 1
+    const int il  = tid - 16*ip;    // 0...15
+
+    float * y = yy + i*QK_K + 16*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t   ql = x[i].ql[16*ip + il];
+    const uint8_t   qh = x[i].qh[il] >> (2*ip);
+    const int8_t  * sc = x[i].scales;
+
+    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[ip+2] * ((int8_t)((ql  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
 }
 
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
@@ -515,6 +628,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
 
     const block_q2_K * x = (const block_q2_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -528,8 +644,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
 
-    float tmp = 0; // partial sum for thread in warp
-
     uint32_t aux[4];
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -565,6 +679,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += dall * sum1 - dmin * sum2;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;
+
+    uint32_t uaux[2];
+    const uint8_t * d = (const uint8_t *)uaux;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint32_t * s = (const uint32_t *)x[i].scales;
+
+        uaux[0] = s[0] & 0x0f0f0f0f;
+        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+        const half2 * dh = (const half2 *)&x[i].d;
+
+        const float2 dall = __half22float2(dh[0]);
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t ql = q[l];
+            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                  + y[l+16] * d[1] * ((ql >> 2) & 3)
+                  + y[l+32] * d[2] * ((ql >> 4) & 3)
+                  + y[l+48] * d[3] * ((ql >> 6) & 3);
+            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+        }
+        tmp += dall.x * sum1 - dall.y * sum2;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -573,16 +720,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
@@ -591,6 +735,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const block_q3_K * x = (const block_q3_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -610,8 +761,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const uint16_t s_shift = 4*im;
 
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
         const float   * y  = yy + i * QK_K + y_offset;
@@ -640,6 +789,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += d * sum;
 
     }
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14
+    const int in = offset/8;                                 // 0 or 1
+    const int im = offset%8;                                 // 0...7
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint8_t * s = x[i].scales;
+
+        const float dall = (float)x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t hl = x[i].hmask[im+l] >> in;
+            const uint8_t ql = q[l];
+            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+                 + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+                 + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        }
+        tmp += sum;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -648,22 +825,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
@@ -683,8 +863,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
     float tmp = 0; // partial sum for thread in warp
 
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -713,6 +891,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
         tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    float tmp = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const float   * y = yy + i*QK_K + step;
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
+        }
+        tmp += sum;
+    }
+
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -728,15 +936,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 
 static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/2;  // 0...15
     const int ix  = threadIdx.x%2;
 
@@ -757,10 +969,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const uint8_t * ql1 = x[i].qs + q_offset;
@@ -793,8 +1001,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
         tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+    }
 
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int im = step/8;
+    const int in = step%8;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const int8_t  * s = x[i].scales;
+        const float   * y = yy + i*QK_K + step;
+        const float     d = x[i].d;
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            const uint8_t h = x[i].qh[in+j] >> im;
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
+        }
+        tmp += sum;
     }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -803,7 +1034,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
@@ -820,6 +1051,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     const block_q6_K * x = (const block_q6_K *)vx + ib0;
 
+#if QK_K == 256
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
     const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
 
@@ -874,6 +1107,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     }
 
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...7
+    const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...3
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float   * y  = yy + i * QK_K + step;
+        const uint8_t * ql = x[i].ql + step;
+        const uint8_t * qh = x[i].qh + step;
+        const int8_t  * s  = x[i].scales;
+
+        const float d = x[i+0].d;
+
+        float sum = 0;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
+                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
+        }
+        tmp += sum;
+
+    }
+
+#endif
+
     // sum up partial sums and write back result
     __syncthreads();
 #pragma unroll
@@ -985,7 +1249,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
 }
 
 static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
-    const half * x = (half *) vx;
+    const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -1033,9 +1297,9 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
+    const int row_stride_x, const int channel_stride_x) {
 
-    const half * x = (half *) vx;
+    const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -1078,14 +1342,14 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
 }
 
 static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (float *) cxi;
+    const float * xi = (const float *) cxi;
     float * dsti = (float *) cdsti;
 
     *dsti = *xi;
 }
 
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (float *) cxi;
+    const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
 
     *dsti = __float2half(*xi);
@@ -1209,6 +1473,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -1252,12 +1521,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1267,12 +1544,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -1418,7 +1703,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
     const dim3 block_nums(1, nrows_x, nchannels_x);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
 }
 
 static void ggml_cpy_f32_f32_cuda(
@@ -1675,7 +1960,7 @@ inline void ggml_cuda_op_add(
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
     cudaStream_t & cudaStream_main){
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
@@ -1683,8 +1968,13 @@ inline void ggml_cuda_op_add(
     const int64_t i01_diff = i01_high - i01_low;
 
     // compute
-    add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
 
     (void) src1;
     (void) dst;
@@ -1716,7 +2006,6 @@ inline void ggml_cuda_op_mul(
 
         // compute
         mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-        CUDA_CHECK(cudaGetLastError());
     }
 
     (void) dst;
@@ -1737,7 +2026,6 @@ inline void ggml_cuda_op_silu(
 
     // compute
     silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -1760,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm(
 
     // compute
     rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -1839,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
             GGML_ASSERT(false);
             break;
     }
-    CUDA_CHECK(cudaGetLastError());
 
 #ifdef GGML_CUDA_DMMV_F16
     if (src1_convert_f16) {
@@ -1916,7 +2202,6 @@ inline void ggml_cuda_op_rope(
 
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -1940,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf(
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -1962,7 +2246,6 @@ inline void ggml_cuda_op_soft_max(
 
     // compute
     soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2058,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
-    // if multiple GPUs are used they need to wait for the main GPU to finish
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signifies that the main device has finished calculating the input data
     if (split && g_device_count > 1) {
         CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(cudaDeviceSynchronize());
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
     }
 
     for (int id = 0; id < g_device_count; ++id) {
@@ -2087,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         int64_t row_diff = row_high - row_low;
 
         cudaSetDevice(id);
+        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+        // wait for main GPU data if necessary
+        if (split && id != g_main_device) {
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
+        }
 
         if (src0_on_device && src0_is_contiguous) {
             if (src0_is_f32) {
@@ -2162,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                 float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2223,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2252,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                     }
                 }
+
+                // signify to main device that other device is done
+                if (split && g_device_count > 1 && id != g_main_device) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+                }
             }
         }
     }
@@ -2263,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         }
 
         CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaDeviceSynchronize());
 
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2278,11 +2571,32 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
         }
     }
+
+    // main device waits for all other devices to be finished
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        for (int id = 0; id < g_device_count; ++id) {
+            if (id != g_main_device) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+            }
+        }
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
+    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
+    // Due to flatten_rows == true this does in practice not make a difference however.
+    // Better solution would be nice but right now that would require disproportionate changes.
+    GGML_ASSERT(
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+        src1->type == GGML_TYPE_F32 &&
+        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2511,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
 
         extra->data_device[id] = buf;
+
+        if (backend == GGML_BACKEND_GPU_SPLIT) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+        }
     }
 
     tensor->extra = extra;
@@ -2524,18 +2842,21 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     for (int id = 0; id < g_device_count; ++id) {
-        if (extra->data_device[id] == nullptr) {
-            continue;
+        if (extra->data_device[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
-        CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaFree(extra->data_device[id]));
+        if (extra->events[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        }
     }
 
     delete extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -2544,22 +2865,24 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
-            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
+    memset(extra, 0, sizeof(*extra));
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW;
+        tensor->op == GGML_OP_VIEW ||
+        force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+    if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
@@ -2598,11 +2921,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 }
 
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true);
+    ggml_cuda_assign_buffers_impl(tensor, true, false);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false);
+}
+
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, true);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
@@ -2635,7 +2962,7 @@ void ggml_cuda_free_scratch() {
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
-        || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
+        || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
         || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
 
     switch (tensor->op) {
index d32b4484267ab5a3f9028db07330302432c1fd93..3c1e8deb6a6ddbfbb43e22bd01cc64e560fe313d 100644 (file)
@@ -8,10 +8,6 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
-struct ggml_tensor_extra_gpu {
-    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
-};
-
 void   ggml_init_cublas(void);
 void   ggml_cuda_set_tensor_split(const float * tensor_split);
 
@@ -29,6 +25,7 @@ void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 void   ggml_cuda_free_data(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void   ggml_cuda_set_main_device(int main_device);
 void   ggml_cuda_set_scratch_size(size_t scratch_size);
 void   ggml_cuda_free_scratch(void);
index a7e104dc76fcafbe7bb405dfd235778eb79d0388..fd69c41fe357d6e1636f2640186c7908352450b9 100644 (file)
@@ -51,21 +51,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
-    GGML_METAL_DECL_KERNEL(get_rows_q2_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q3_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q4_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q5_k);
-    GGML_METAL_DECL_KERNEL(get_rows_q6_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q2_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q3_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_K);
+    GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
             exit(1);
         }
 
+#ifdef GGML_QKK_64
+        MTLCompileOptions* options = [MTLCompileOptions new];
+        options.preprocessorMacros = @{ @"QK_K" : @(64) };
+        ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+#else
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+#endif
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
             exit(1);
@@ -159,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
-        GGML_METAL_ADD_KERNEL(get_rows_q2_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q3_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q4_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q5_k);
-        GGML_METAL_ADD_KERNEL(get_rows_q6_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q2_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q3_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_K);
+        GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -196,7 +202,9 @@ struct ggml_metal_context * ggml_metal_init(void) {
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
     fprintf(stderr, "%s: deallocating\n", __func__);
-
+    for (int i = 0; i < ctx->n_buffers; ++i) {
+        [ctx->buffers[i].metal release];
+    }
     free(ctx);
 }
 
@@ -662,7 +670,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
@@ -671,7 +679,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
@@ -680,7 +688,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
@@ -689,7 +697,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
@@ -698,7 +706,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                         } break;
                                     default:
                                         {
@@ -750,11 +758,11 @@ void ggml_metal_graph_compute(
                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
-                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
-                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
-                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
-                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
-                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
+                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
+                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
+                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
+                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
+                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
index d1e49222db2eb6c9c7614978fc88c349b86ed5e9..e62fe6842ea72bf950619935c74e9fecf1d10799 100644 (file)
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
 }
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
 }
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
 
 //============================================ k-quants ======================================================
 
+#ifndef QK_K
 #define QK_K 256
+#else
+static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
+#endif
+
+#if QK_K == 256
+#define K_SCALE_SIZE 12
+#else
+#define K_SCALE_SIZE 4
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
     uint8_t qs[QK_K/4];      // quants
     half d;           // super-block scale for quantized scales
     half dmin;        // super-block scale for quantized mins
-} block_q2_k;
+} block_q2_K;
 // 84 bytes / block
 
 typedef struct {
     uint8_t hmask[QK_K/8];     // quants - high bit
     uint8_t qs[QK_K/4];        // quants - low 2 bits
-    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
-    half d;                    // super-block scale
-} block_q3_k;
-// 110 bytes / block
-
+#if QK_K == 64
+    uint8_t scales[2];
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;             // super-block scale
+} block_q3_K;
+
+#if QK_K == 64
+typedef struct {
+    half    d[2];          // super-block scales/mins
+    uint8_t scales[2];
+    uint8_t qs[QK_K/2];    // 4-bit quants
+} block_q4_K;
+#else
 typedef struct {
     half d;             // super-block scale for quantized scales
     half dmin;          // super-block scale for quantized mins
-    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];        // 4--bit quants
-} block_q4_k;
-// 144 bytes / block
+} block_q4_K;
+#endif
 
+#if QK_K == 64
+typedef struct {
+    half  d;                     // super-block scales/mins
+    int8_t  scales[QK_K/16];     // 8-bit block scales
+    uint8_t qh[QK_K/8];          // quants, high bit
+    uint8_t qs[QK_K/2];          // quants, low 4 bits
+} block_q5_K;
+#else
 typedef struct {
     half d;                      // super-block scale for quantized scales
     half dmin;                   // super-block scale for quantized mins
     uint8_t scales[3*QK_K/64];   // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];          // quants, high bit
     uint8_t qs[QK_K/2];          // quants, low 4 bits
-} block_q5_k;
+} block_q5_K;
 // 176 bytes / block
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2];      // quants, lower 4 bits
     uint8_t qh[QK_K/4];      // quants, upper 2 bits
     int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
     half d;                  // super-block scale
-} block_q6_k;
+} block_q6_K;
 // 210 bytes / block
 
 static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //========================================== dequantization =============================
 
-static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
 
         device const uint8_t * q = x[i].qs;
 
+#if QK_K == 256
         int is = 0;
         float dl, ml;
         for (int n = 0; n < QK_K; n += 128) {
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
             }
             q += 32;
         }
+#else
+        float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
+        float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
+        float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
+        float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
+        for (int l = 0; l < 16; ++l) {
+            y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
+            y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
+            y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
+            y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
+        }
+        y += QK_K;
+#endif
 
     }
 }
 
-static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
+static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
+
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
             }
             q += 32;
         }
+    }
+#else
+    for (int i = 0; i < nb; i++) {
 
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * hm = x[i].hmask;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l = 0; l < 8; ++l) {
+            uint8_t h = hm[l];
+            y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
+            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
+            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
+            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
+            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
+            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
+            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
+            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
+        }
+        y += QK_K;
     }
+#endif
 
 }
 
-static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
-
     for (int i = 0; i < nb; i++) {
 
+        device const uint8_t * q = x[i].qs;
+
+#if QK_K == 256
         const float d = x[i].d;
         const float min = x[i].dmin;
 
-        device const uint8_t * q = x[i].qs;
         device const uint8_t * scales = x[i].scales;
 
         int is = 0;
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
             for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
             q += 32; is += 2;
         }
+#else
+        device const uint8_t * s = x[i].scales;
+        device const half2 * dh = (device const half2 *)x[i].d;
+        const float2 d = (float2)dh[0];
+        const float d1 = d[0] * (s[0] & 0xF);
+        const float d2 = d[0] * (s[1] & 0xF);
+        const float m1 = d[1] * (s[0] >>  4);
+        const float m2 = d[1] * (s[1] >>  4);
+        for (int l = 0; l < 32; ++l) {
+            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
+            y[l+32] = d2 * (q[l] >>  4) - m2;
+        }
+        y += QK_K;
+#endif
 
     }
 }
 
-static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
+static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
+#if QK_K == 256
    for (int i = 0; i < nb; i++) {
 
         const float d = (float)(x[i].d);
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
             u1 <<= 2; u2 <<= 2;
         }
     }
+#else
+    for (int i = 0; i < nb; i++) {
+
+        const float d = (float)x[i].d;
+
+        device const uint8_t * ql = x[i].qs;
+        device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;
+
+        for (int l = 0; l < 8; ++l) {
+            y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
+            y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
+            y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
+            y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
+            y[l+32] = d * sc[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
+            y[l+40] = d * sc[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
+            y[l+48] = d * sc[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
+            y[l+56] = d * sc[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
+        }
+        y += QK_K;
+    }
+#endif
 
 }
 
-static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
 
         const float d = x[i].d;
 
+#if QK_K == 256
         for (int n = 0; n < QK_K; n += 128) {
             for (int l = 0; l < 32; ++l) {
                 int is = l/16;
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
             qh += 32;
             sc += 8;
         }
+#else
+        for (int l = 0; l < 16; ++l) {
+            const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+            const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+            const int8_t q3 = (int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+            const int8_t q4 = (int8_t)((ql[l+16]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            y[l+ 0] = d * sc[0] * q1;
+            y[l+16] = d * sc[1] * q2;
+            y[l+32] = d * sc[2] * q3;
+            y[l+48] = d * sc[3] * q4;
+        }
+        y  += 64;
+#endif
     }
 }
 
-kernel void kernel_get_rows_q2_k(
+kernel void kernel_get_rows_q2_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q2_k(
-            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q2_K(
+            (device const block_q2_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q3_k(
+kernel void kernel_get_rows_q3_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q3_k(
-            (device const block_q3_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q3_K(
+            (device const block_q3_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q4_k(
+kernel void kernel_get_rows_q4_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q4_k(
-            (device const block_q4_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q4_K(
+            (device const block_q4_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q5_k(
+kernel void kernel_get_rows_q5_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q5_k(
-            (device const block_q5_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q5_K(
+            (device const block_q5_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
-kernel void kernel_get_rows_q6_k(
+kernel void kernel_get_rows_q6_K(
         device const  void * src0,
         device const   int * src1,
         device       float * dst,
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
     const int i = tpig;
     const int r = ((device int32_t *) src1)[i];
 
-    dequantize_row_q6_k(
-            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+    dequantize_row_q6_K(
+            (device const block_q6_K *) ((device char *) src0 + r*nb01),
                        (device float *) ((device char *)  dst + i*nb1), ne00);
 }
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_k_f32(
+kernel void kernel_mul_mat_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
     const int tid = tpitg.y;    // 0...16
     const int il  = tid/4;      // 0...3
     const int ir  = tid%4;      // 0...3
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
     const int y_offset = 64*il + n*ir;
     const int q_offset = 32*ip + n*ir;
 
-    sum[ith] = 0.0f;
-
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q = x[i].qs + q_offset;
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
 
         device const float   * y = yy + i*QK_K + y_offset;
 
-        //float4 s = {0.f, 0.f, 0.f, 0.f};
         float2 s = {0.f, 0.f};
         float smin = 0;
         for (int l = 0; l < n; ++l) {
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
         sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
 
     }
-    sum[ith] = sumf;
+#else
+    const int il = 4 * tpitg.x;
 
-    //int mask1 = (ith%4 == 0);
-    //int mask2 = (ith%16 == 0);
+    uint32_t aux[2];
+    thread const uint8_t * d = (thread const uint8_t *)aux;
+    thread const uint8_t * m = (thread const uint8_t *)aux + 4;
 
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //if (ith == 0) {
-    //    for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        device const uint8_t * q = x[i].qs + il;
+        device const float   * y = yy + i*QK_K + il;
+
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        device const uint32_t * a = (device const uint32_t *)x[i].scales;
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+
+        for (int l = 0; l < 4; ++l) {
+            sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
+                  + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
+                  + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
+                  + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+        }
+    }
+#endif
+
+    sum[ith] = sumf;
 
     //
     // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
     //
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith%4 == 0) {
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
     }
 }
 
-kernel void kernel_mul_mat_q3_k_f32(
+kernel void kernel_mul_mat_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
+    device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+#if QK_K == 256
+
+    const uint8_t m3 = 3;
+    const int8_t  m4 = 4;
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = tpitg.y;        // expecting 16
     const int ip  = tid/8;          // 0 or 1
     const int il  = tid/2 - 4*ip;   // 0...3
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
 
     //sum[ith] = sumf;
     sum[ith] = sumf1 - 32.f*sumf2;
+#else
+    const int il = 4 * tpitg.x;  // 0, 4, 8, 12
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4
+
+    float sumf = 0;
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint8_t * q = x[i].qs + il;
+        device const uint8_t * h = x[i].hmask + in;
+        device const float   * y = yy + i * QK_K + il;
+
+        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+        for (int l = 0; l < 4; ++l) {
+            const uint8_t hm = h[l] >> im;
+            sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
+                  + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
+                  + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
+                  + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
+        }
+
+    }
+
+    sum[ith] = sumf;
+
+#endif
 
     //
     // Accumulate the sum from all threads in the threadgroup
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
 
 }
 
-kernel void kernel_mul_mat_q4_k_f32(
+kernel void kernel_mul_mat_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    float sumf = 0;
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
     const int ir  = tid - 4*il;// 0...3
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
     const int q_offset = 32*im + l0;
     const int y_offset = 64*im + l0;
 
-    sum[ith] = 0.0f;
-
     uchar2 sc1, sc2, sc3, sc4;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
         sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+#else
+    uint16_t aux16[2];
+    thread const uint8_t * scales = (thread const uint8_t *)aux16;
+
+    const int il  = 4*tpitg.x;
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        device const uint8_t * q = x[i].qs + il;
+        device const float   * y = yy + i * QK_K + il;
+
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+
+        device const uint16_t * a = (device const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        for (int l = 0; l < 4; ++l) {
+            sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
+                  + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
+        }
+    }
+#endif
 
     sum[ith] = sumf;
 
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
     //}
 }
 
-kernel void kernel_mul_mat_q5_k_f32(
+kernel void kernel_mul_mat_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
         uint2 tpitg[[thread_position_in_threadgroup]],
         uint2  tptg[[threads_per_threadgroup]]) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
+    device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = tpitg.y;   // 0...16
     const int il  = tid/4;     // 0...3
     const int ir  = tid - 4*il;// 0...3
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
 
     uchar2 sc1, sc2, sc3, sc4;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
         sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
     }
+#else
+    const int il  = 4 * tpitg.x;  // 0, 4, 8, 12
+    const int im  = il/8;         // 0, 0, 1, 1
+    const int in  = il%8;         // 0, 4, 0, 4
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+        const float d = (float)x[i].d;
+        device const uint8_t * q = x[i].qs + il;
+        device const uint8_t * h = x[i].qh + in;
+        device const int8_t  * s = x[i].scales;
+        device const float   * y = yy + i*QK_K + il;
+
+        for (int l = 0; l < 4; ++l) {
+            const uint8_t hl = h[l] >> im;
+            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
+                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
+                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
+                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
+        }
+    }
+#endif
     sum[ith] = sumf;
 
     //
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_k_f32(
+kernel void kernel_mul_mat_q6_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+    device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     const int nth = tptg.x*tptg.y;
     const int ith = tptg.y*tpitg.x + tpitg.y;
 
+    float sumf = 0;
+
+#if QK_K == 256
     // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
     const int iqs  = 16 * tpitg.y;
     const int ip   = iqs / 128;         // 0 or 1
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;
 
-    float sumf = 0;
     for (int i = tpitg.x; i < nb; i += tptg.x) {
 
         device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
     }
+#else
+    const int il  = 4*tpitg.x;    // 0, 4, 8, 12
+
+    for (int i = tpitg.y; i < nb; i += tptg.y) {
+        device const float * y = yy + i * QK_K + il;
+        device const uint8_t * ql = x[i].ql + il;
+        device const uint8_t * qh = x[i].qh + il;
+        device const int8_t  * s  = x[i].scales;
+
+        const float d = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 4; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >>  4) | ((qh[l] & kmask3) >> 0)) - 32);
+            sums[3] += y[l+48] * ((int8_t)((ql[l+16] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+        sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
+    }
+
+#endif
 
     sum[ith] = sumf;
 
index 95f4cec6dd59cdc7b7ece51549420e355dfe3607..fed4ffb0ccd0538b56aa552023581974b3fe2332 100644 (file)
 
 #define CL_DMMV_BLOCK_SIZE 32
 
+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 1
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
 #define MULTILINE_QUOTE(...) #__VA_ARGS__
 static std::string program_source = MULTILINE_QUOTE(
 
 typedef char int8_t;
 typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
 typedef int int32_t;
 typedef uint uint32_t;
 
@@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
     *v0 = vload_half(0, &x[ib + 0]);
     *v1 = vload_half(0, &x[ib + 1]);
 }
+);
 
+static std::string k_quants_source = MULTILINE_QUOTE(
 inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
 {
     if (j < 4)
@@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
     const int is = 8 * n + l / 16;
 
     const uint8_t q = x[i].qs[32 * n + l];
-    __global float *y = yy + i * 256 + 128 * n;
+    __global float *y = yy + i * QK_K + 128 * n;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
     float d_all = vload_half(0, &x[i].d);
     float dl = d_all * (us - 32);
 
-    __global float *y = yy + i * 256 + 128 * n + 32 * j;
+    __global float *y = yy + i * QK_K + 128 * n + 32 * j;
     const __global uint8_t *q = x[i].qs + 32 * n;
     const __global uint8_t *hm = x[i].hmask;
 
@@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
     const int is = 2 * il;
     const int n = 4;
 
-    __global float *y = yy + i * 256 + 64 * il + n * ir;
+    __global float *y = yy + i * QK_K + 64 * il + n * ir;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
     const int ir = tid % 16;
     const int is = 2 * il;
 
-    __global float *y = yy + i * 256 + 64 * il + 2 * ir;
+    __global float *y = yy + i * QK_K + 64 * il + 2 * ir;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
     const int il = tid - 32 * ip;
     const int is = 8 * ip + il / 16;
 
-    __global float *y = yy + i * 256 + 128 * ip + il;
+    __global float *y = yy + i * QK_K + 128 * ip + il;
 
     const float d = vload_half(0, &x[i].d);
 
@@ -323,161 +333,383 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
     y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
 }
 
+__kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
 
-void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+    const int row = get_group_id(0);
 
-    int n = iqs / 128;
-    int r = iqs - 128 * n;
-    int l = r / 8;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    __global const float *y = yy + 128 * n + l;
-    __global const uint8_t *q = x[ib].qs + 32 * n + l;
-    __global const uint8_t *s = x[ib].scales + 8 * n;
+    __global const struct block_q2_K * x = xx + ib0;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    float sum = y[  0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
-              + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
-              + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
-              + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
-              + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
-              + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
-              + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
-              + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
+    const int step = 16/K_QUANTS_PER_ITERATION;
 
-    *result = sum;
-}
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
 
-void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
 
-    const uint32_t kmask1 = 0x03030303;
-    const uint32_t kmask2 = 0x0f0f0f0f;
+    tmp[16 * ix + tid] = 0;
 
-    uint32_t aux[3];
-    uint32_t utmp[4];
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
 
-    int n = iqs/128;
-    int r = iqs - 128*n;
-    int l = r/8;
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
-    __global const float   * y = yy + 128*n + l;
-    __global const uint8_t * q = x[ib].qs + 32*n + l;
-    __global const uint8_t * hm = x[ib].hmask + l;
-    const int8_t * s = (const int8_t *)utmp + 8*n;
+        __global const float   * y = yy + i * QK_K + y_offset;
+        __global const uint8_t * q = x[i].qs + q_offset;
 
-    aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
-    aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
-    aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
 
-    utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
-    utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
-    utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
-    utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+        __global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const uint8_t m = 1 << (4*n);
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
 
-    float sum = y[  0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
-              + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
-              + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
-              + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
-              + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
-              + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
-              + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
-              + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
+        }
+        tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
 
-    *result = sum * dall;
+    }
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int row = get_group_id(0);
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const int j  = iqs / 64;        // j  is in 0...3
-    const int ir = (iqs - 64*j)/2;  // ir is in 0...28 in steps of 4
-    const int is = 2*j;             // is is in 0...6 in steps of 2
+    __global const struct block_q3_K * x = xx + ib0;
 
-    __global const float   * y = yy + 64*j + ir;
-    __global const uint8_t * q = x[ib].qs + 32*j + ir;
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0,1
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
 
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y  = yy + i * QK_K + y_offset;
+        __global const uint8_t * q = x[i].qs + q_offset;
+        __global const uint8_t * h = x[i].hmask + l0;
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = vload_half(0, &x[i].d);
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp[16 * ix + tid] += d * sum;
 
-    float sum = 0;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k +  0] * (d1 * (q[k] & 0xF) - m1);
-        sum += y[k + 32] * (d2 * (q[k] >>  4) - m2);
     }
 
-    *result = sum;
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
 
-    const int j  = iqs / 64;
-    const int ir = (iqs - 64*j)/2;
-    const int is = 2*j;
+    //to rename it later, just to test now
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
 
-    __global const float   * y  = yy + 64*j + ir;
-    __global const uint8_t * ql = x[ib].qs + 32*j + ir;
-    __global const uint8_t * qh = x[ib].qh + ir;
+    const int row = get_group_id(0);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const float dall = vload_half(0, &x[ib].d);
-    const float dmin = vload_half(0, &x[ib].dmin);
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...15
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;
 
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
+    const int step = 8/K_QUANTS_PER_ITERATION;
+
+    const int il  = tid/step;     // 0...3
+    const int ir  = tid - step*il;// 0...3
+    const int n   = 2*K_QUANTS_PER_ITERATION;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    __global const struct block_q4_K * x = xx + ib0;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const uint8_t * q1 = x[i].qs + q_offset;
+        __global const uint8_t * q2 = q1 + 64;
+        __global const float   * y1 = yy + i*QK_K + y_offset;
+        __global const float   * y2 = y1 + 128;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 s = (float4)(0.f);
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
+            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
-    uint8_t hm  = 1 << is;
-    float sum = 0;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k +  0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
     }
-    hm <<= 1;
-    for (int k = 0; k < 4; ++k) {
-        sum += y[k + 32] * (d2 * ((ql[k] >>  4) + (qh[k] & hm ? 16 : 0)) - m2);
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+
+__kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int row = get_group_id(0);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = get_local_id(0)/2;  // 0...15
+    const int ix  = get_local_id(0)%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    __global const struct block_q5_K * x = xx + ib0;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        __global const uint8_t * ql1 = x[i].qs + q_offset;
+        __global const uint8_t * ql2 = ql1 + 64;
+        __global const uint8_t * qh  = x[i].qh + l0;
+        __global const float   * y1  = yy + i*QK_K + y_offset;
+        __global const float   * y2  = y1 + 128;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = (float4)(0.f);
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+
     }
-    *result = sum;
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
-void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+__kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {
+
+    const int row = get_group_id(0);
 
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
 
-    const int ip = iqs / 128;        // 0 or 1
-    const int il = (iqs - 128*ip)/8; // 0...15
-    const int is = 8*ip;
+    __global const struct block_q6_K * x = xx + ib0;
 
-    __global const float * y = yy + 128*ip + il;
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
 
-    const float d = vload_half(0, &x[ib].d);
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
 
-    __global const uint8_t * ql = x[ib].ql + 64*ip + il;
-    __global const uint8_t * qh = x[ib].qh + 32*ip + il;
-    __global const int8_t  * sc = x[ib].scales + is;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset = 128*im + l0;
+
+    tmp[16 * ix + tid] = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y  = yy + i * QK_K + y_offset;
+        __global const uint8_t * ql = x[i].ql + ql_offset;
+        __global const uint8_t * qh = x[i].qh + qh_offset;
+        __global const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = vload_half(0, &x[i].d);
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp[16 * ix + tid] += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp[16 * ix + tid] += sum;
+#endif
 
-    *result = y[  0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
-           + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
-           + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
-           + y[ 96] * d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
-           + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
-           + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
-           + y[ 80] * d * sc[5] * ((int8_t)((ql[16]  >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
-           + y[112] * d * sc[7] * ((int8_t)((ql[48]  >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
+    }
 
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
 }
 
 );
@@ -549,44 +781,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
 }
 );
 
-std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
-__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
-    const int block_size = get_local_size(0);
-    const int row = get_group_id(0);
-    const int tid = get_local_id(0);
-
-    const int iter_stride = 256;
-    const int vals_per_iter = iter_stride / block_size;
-    const int num_blocks_per_row = ncols / 256;
-    const int ib0 = row*num_blocks_per_row;
-
-    tmp[tid] = 0;
-
-    for (int i = 0; i < ncols; i += iter_stride) {
-        const int col = i + vals_per_iter*tid;
-        const int ib = ib0 + col/256; // x block index
-        const int iqs = col%256; // x quant index
-        const int iybs = col - col%256; // y block start index
-
-        // dequantize
-        float v;
-        DOT_KERNEL(x, ib, iqs, y + iybs, &v);
-        tmp[tid] += v;
-    }
-
-    // sum up partial sums and write back result
-    barrier(CLK_LOCAL_MEM_FENCE);
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if (tid == 0) {
-        dst[row] = tmp[0];
-    }
-}
-);
 
 std::string mul_template = MULTILINE_QUOTE(
 __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
@@ -649,18 +843,6 @@ std::array<std::string, 2> mul_str_values = {
     "mul_f32", "float"
 };
 
-std::array<std::string, 3> dmmv_k_str_keys = {
-    "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
-};
-
-std::array<std::string, 15> dmmv_k_str_values = {
-    "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
-    "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
-    "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
-    "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
-    "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
-};
-
 std::string& replace(std::string& s, const std::string& from, const std::string& to) {
     size_t pos = 0;
     while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -673,6 +855,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
 std::string generate_kernels() {
     std::stringstream src;
     src << program_source << '\n';
+    src << k_quants_source << '\n';
     for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
         std::string dequant_kernel = dequant_template;
         std::string dmmv_kernel = dequant_mul_mat_vec_template;
@@ -690,13 +873,6 @@ std::string generate_kernels() {
         }
         src << mul_kernel << '\n';
     }
-    for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
-        std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
-        for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
-            replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
-        }
-        src << dmmv_k_kernel << '\n';
-    }
 
     return src.str();
 }
@@ -729,10 +905,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
         exit(1);
     }
 
-    const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
-                               "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
+    std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
+                               "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
+                               "-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
 
-    err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
+    err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
     if(err < 0) {
 
         clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
diff --git a/ggml.c b/ggml.c
index 9dcaafa9605b63b3a64cf27c2f195e8bbdbd2038..b7cfcf4ec82c93a6200948cda3447e7358b0ada3 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -90,6 +90,11 @@ static int sched_yield (void) {
 #include <stdatomic.h>
 
 typedef void* thread_ret_t;
+
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
 #endif
 
 // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
@@ -118,6 +123,30 @@ typedef void* thread_ret_t;
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL  2
 
+//
+// logging
+//
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+#define GGML_PRINT(...) printf(__VA_ARGS__)
+
 #ifdef GGML_USE_ACCELERATE
 // uncomment to use vDSP for soft max computation
 // note: not sure if it is actually faster
@@ -190,9 +219,27 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #define GGML_ALIGNED_FREE(ptr)     free(ptr)
 #endif
 
-#define UNUSED(x) (void)(x)
+#define UNUSED GGML_UNUSED
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
+//
+// tensor access macros
+//
+
+#define GGML_TENSOR_UNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb); \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne); \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb);
+
+#define GGML_TENSOR_BINARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb); \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb); \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne); \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb);
+
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
 #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
@@ -458,7 +505,6 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
     }
 }
 
-
 //
 // timing
 //
@@ -521,6 +567,7 @@ int64_t ggml_cycles_per_ms(void) {
 #define ggml_perf_cycles_per_ms() 0
 #endif
 
+
 //
 // cache line
 //
@@ -3417,6 +3464,8 @@ inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) {
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
+inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
 static const float GELU_COEF_A    = 0.044715f;
@@ -3568,6 +3617,16 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
     *s = 1.f/(*s);
 }
 
+inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
+    float max = -INFINITY;
+    int idx = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+        if (max == x[i]) { idx = i; }
+    }
+    *s = idx;
+}
+
 //
 // data types
 //
@@ -3677,12 +3736,15 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM",
     "SUM_ROWS",
     "MEAN",
+    "ARGMAX",
     "REPEAT",
     "REPEAT_BACK",
     "ABS",
     "SGN",
     "NEG",
     "STEP",
+    "TANH",
+    "ELU",
     "RELU",
     "GELU",
     "GELU_QUICK",
@@ -3714,9 +3776,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D_S1_PH",
-    "CONV_1D_S2_PH",
-    "CONV_2D_SK_P0",
+    "CONV_1D",
+    "CONV_2D",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -3735,7 +3796,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
+static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3753,12 +3814,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx",
     "Σx_k",
     "Σx/n",
+    "argmax(x)",
     "repeat(x)",
     "repeat_back(x)",
     "abs(x)",
     "sgn(x)",
     "-x",
     "step(x)",
+    "tanh(x)",
+    "elu(x)",
     "relu(x)",
     "gelu(x)",
     "gelu_quick(x)",
@@ -3790,9 +3854,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d_s1_ph(x)",
-    "conv_1d_s2_ph(x)",
-    "conv_2d_sk_p0(x)",
+    "conv_1d(x)",
+    "conv_2d(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -3811,11 +3874,45 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
+static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
+// WARN:
+// Mis-confguration can lead to problem that's hard to reason about:
+// * At best  it crash or talks nosense.
+// * At worst it talks slightly difference but hard to perceive.
+//
+// An op has to enable INIT or FINALIZE when any of it's branch needs that pass.
+// Take care about compile options (e.g., GGML_USE_xxx).
+static bool GGML_OP_HAS_INIT    [GGML_OP_COUNT] = { 0 };
+static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 };
+
+static void ggml_setup_op_has_task_pass(void) {
+    {   // INIT
+        bool * p = GGML_OP_HAS_INIT;
+
+        p[GGML_OP_ACC                    ] = true;
+        p[GGML_OP_MUL_MAT                ] = true;
+        p[GGML_OP_OUT_PROD               ] = true;
+        p[GGML_OP_SET                    ] = true;
+        p[GGML_OP_GET_ROWS_BACK          ] = true;
+        p[GGML_OP_DIAG_MASK_INF          ] = true;
+        p[GGML_OP_DIAG_MASK_ZERO         ] = true;
+        p[GGML_OP_CONV_1D                ] = true;
+        p[GGML_OP_CONV_2D                ] = true;
+        p[GGML_OP_FLASH_ATTN_BACK        ] = true;
+        p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
+    }
+
+    {   // FINALIZE
+        bool * p = GGML_OP_HAS_FINALIZE;
+
+        p[GGML_OP_CROSS_ENTROPY_LOSS     ] = true;
+    }
+}
+
 //
 // ggml context
 //
@@ -3842,12 +3939,31 @@ struct ggml_context_container {
     struct ggml_context context;
 };
 
+//
+// NUMA support
+//
+
+#define GGML_NUMA_MAX_NODES 8
+#define GGML_NUMA_MAX_CPUS 512
+
+struct ggml_numa_node {
+    uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
+    uint32_t n_cpus;
+};
+
+struct ggml_numa_nodes {
+    struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
+    uint32_t n_nodes;
+    uint32_t total_cpus; // hardware threads on system
+};
+
 //
 // ggml state
 //
 
 struct ggml_state {
     struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
+    struct ggml_numa_nodes numa;
 };
 
 // global state
@@ -3872,6 +3988,75 @@ inline static void ggml_critical_section_end(void) {
     atomic_fetch_sub(&g_state_barrier, 1);
 }
 
+void ggml_numa_init(void) {
+    if (g_state.numa.n_nodes > 0) {
+        fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
+
+        return;
+    }
+
+#ifdef __linux__
+    struct stat st;
+    char path[256];
+    int rv;
+
+    // enumerate nodes
+    while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
+        rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
+        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+        if (stat(path, &st) != 0) { break; }
+        ++g_state.numa.n_nodes;
+    }
+
+    // enumerate CPUs
+    while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
+        rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
+        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+        if (stat(path, &st) != 0) { break; }
+        ++g_state.numa.total_cpus;
+    }
+
+    GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
+
+    if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) {
+        g_state.numa.n_nodes = 0;
+        return;
+    }
+
+    for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
+        struct ggml_numa_node * node = &g_state.numa.nodes[n];
+        GGML_PRINT_DEBUG("CPUs on node %u:", n);
+        node->n_cpus = 0;
+        for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
+            rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
+            GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+            if (stat(path, &st) == 0) {
+                node->cpus[node->n_cpus++] = c;
+                GGML_PRINT_DEBUG(" %u", c);
+            }
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+
+    if (ggml_is_numa()) {
+        FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
+        if (fptr != NULL) {
+            char buf[42];
+            if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
+                GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
+            }
+            fclose(fptr);
+        }
+    }
+#else
+    // TODO
+#endif
+}
+
+bool ggml_is_numa(void) {
+    return g_state.numa.n_nodes > 1;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_print_object(const struct ggml_object * obj) {
@@ -4128,6 +4313,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
             g_state = (struct ggml_state) {
                 /*.contexts =*/ { { 0 } },
+                /*.numa =*/ {
+                    .n_nodes = 0,
+                    .total_cpus = 0,
+                },
             };
 
             for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
@@ -4145,6 +4334,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         ggml_cl_init();
 #endif
 
+        ggml_setup_op_has_task_pass();
+
         is_first_call = false;
     }
 
@@ -5281,6 +5472,30 @@ struct ggml_tensor * ggml_mean(
     return result;
 }
 
+// ggml_argmax
+
+struct ggml_tensor * ggml_argmax(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    GGML_ASSERT(ggml_is_matrix(a));
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false);
+        is_node = true;
+    }
+
+    int64_t ne[GGML_MAX_DIMS] = { a->ne[1], 1, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, ne);
+
+    result->op   = GGML_OP_ARGMAX;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
 // ggml_repeat
 
 struct ggml_tensor * ggml_repeat(
@@ -5474,6 +5689,74 @@ struct ggml_tensor * ggml_step_inplace(
     return ggml_step_impl(ctx, a, true);
 }
 
+// ggml_tanh
+
+struct ggml_tensor * ggml_tanh_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_TANH;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_tanh(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_tanh_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_tanh_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_tanh_impl(ctx, a, true);
+}
+
+// ggml_elu
+
+struct ggml_tensor * ggml_elu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_ELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_elu(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a) {
+    return ggml_elu_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_elu_inplace(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a) {
+    return ggml_elu_impl(ctx, a, true);
+}
+
 // ggml_relu
 
 struct ggml_tensor * ggml_relu_impl(
@@ -6656,6 +6939,7 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6668,11 +6952,12 @@ struct ggml_tensor * ggml_rope_impl(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ((int32_t *) b->data)[3] = n_ctx;
 
     ggml_scratch_load(ctx);
 
@@ -6689,8 +6974,9 @@ struct ggml_tensor * ggml_rope(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
+        int                   mode,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -6698,8 +6984,9 @@ struct ggml_tensor * ggml_rope_inplace(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
+        int                   mode,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
 }
 
 // ggml_rope_back
@@ -6711,6 +6998,8 @@ struct ggml_tensor * ggml_rope_back(
         int                   n_dims,
         int                   mode) {
     GGML_ASSERT(n_past >= 0);
+    GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6811,15 +7100,21 @@ struct ggml_tensor * ggml_clamp(
     return result;
 }
 
-// ggml_conv_1d_s1_ph
+// ggml_conv_1d
+
+static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+}
 
-struct ggml_tensor * ggml_conv_1d_s1_ph(
+GGML_API struct ggml_tensor * ggml_conv_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
     GGML_ASSERT(ggml_is_matrix(b));
     GGML_ASSERT(a->ne[1] == b->ne[1]);
-    GGML_ASSERT(a->ne[3] == 1);
     bool is_node = false;
 
     if (a->grad || b->grad) {
@@ -6827,26 +7122,43 @@ struct ggml_tensor * ggml_conv_1d_s1_ph(
         is_node = true;
     }
 
-    const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+    const int64_t ne[4] = {
+        ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
+        a->ne[2], 1, 1,
+    };
+    struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+    ggml_scratch_save(ctx);
+    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    ((int32_t*)c->data)[0] = s0;
+    ((int32_t*)c->data)[1] = p0;
+    ((int32_t*)c->data)[2] = d0;
+    ggml_scratch_load(ctx);
 
-    result->op   = GGML_OP_CONV_1D_S1_PH;
+    result->op = GGML_OP_CONV_1D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
+    result->opt[0] = c;
 
     return result;
 }
 
-// ggml_conv_1d_s2_ph
+// ggml_conv_2d
 
-struct ggml_tensor * ggml_conv_1d_s2_ph(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    GGML_ASSERT(ggml_is_matrix(b));
-    GGML_ASSERT(a->ne[1] == b->ne[1]);
-    GGML_ASSERT(a->ne[3] == 1);
+struct ggml_tensor* ggml_conv_2d(
+    struct ggml_context* ctx,
+    struct ggml_tensor * a,
+    struct ggml_tensor * b,
+    int                  s0,
+    int                  s1,
+    int                  p0,
+    int                  p1,
+    int                  d0,
+    int                  d1) {
+
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
     bool is_node = false;
 
     if (a->grad || b->grad) {
@@ -6854,43 +7166,42 @@ struct ggml_tensor * ggml_conv_1d_s2_ph(
         is_node = true;
     }
 
-    const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+    const int64_t ne[4] = {
+        ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
+        ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
+        a->ne[3], 1,
+    };
+    struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_scratch_save(ctx);
+    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
+    ((int32_t*)c->data)[0] = s0;
+    ((int32_t*)c->data)[1] = s1;
+    ((int32_t*)c->data)[2] = p0;
+    ((int32_t*)c->data)[3] = p1;
+    ((int32_t*)c->data)[4] = d0;
+    ((int32_t*)c->data)[5] = d1;
+    ggml_scratch_load(ctx);
 
-    result->op   = GGML_OP_CONV_1D_S2_PH;
+    result->op = GGML_OP_CONV_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
     result->src1 = b;
+    result->opt[0] = c;
 
     return result;
+
 }
 
-// ggml_conv_2d_sk_p0
+// ggml_conv_1d_ph
 
-struct ggml_tensor * ggml_conv_2d_sk_p0(
+struct ggml_tensor* ggml_conv_1d_ph(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    GGML_ASSERT(b->ne[3] == 1);
-    GGML_ASSERT(a->ne[2] == b->ne[2]);
-    GGML_ASSERT(b->ne[0] %  a->ne[0] == 0);
-    GGML_ASSERT(b->ne[1] %  a->ne[1] == 0);
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op   = GGML_OP_CONV_2D_SK_P0;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src0 = a;
-    result->src1 = b;
-
-    return result;
+        struct ggml_tensor  * b,
+        int                   s,
+        int                   d) {
+    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
 }
 
 // ggml_flash_attn
@@ -7440,25 +7751,7 @@ static void ggml_compute_forward_dup_f16(
         return;
     }
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -7729,25 +8022,7 @@ static void ggml_compute_forward_dup_f32(
         return;
     }
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -8045,24 +8320,8 @@ static void ggml_compute_forward_add_f32(
     const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -8131,28 +8390,12 @@ static void ggml_compute_forward_add_f16_f32(
     const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
 
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
     GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
@@ -8201,24 +8444,8 @@ static void ggml_compute_forward_add_f16_f16(
     const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
 
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
@@ -8268,25 +8495,8 @@ static void ggml_compute_forward_add_q_f32(
     }
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    //const int64_t ne03 = src0->ne[3];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
 
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
-
-    const size_t nb0  = dst->nb[0];
-    const size_t nb1  = dst->nb[1];
-    const size_t nb2  = dst->nb[2];
-    const size_t nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8407,19 +8617,8 @@ static void ggml_compute_forward_add1_f32(
     const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
 
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -8473,23 +8672,12 @@ static void ggml_compute_forward_add1_f16_f32(
     const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
     GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
@@ -8534,23 +8722,12 @@ static void ggml_compute_forward_add1_f16_f16(
     const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
 
     GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
@@ -8595,19 +8772,8 @@ static void ggml_compute_forward_add1_q_f32(
     const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     const enum ggml_type type = src0->type;
     dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
@@ -8739,15 +8905,8 @@ static void ggml_compute_forward_acc_f32(
     const int nr = ggml_nrows(src1);
     const int nc = src1->ne[0];
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
 
     // src0 and dst as viewed during acc
     const size_t nb0 = ggml_element_size(src0);
@@ -8836,24 +8995,8 @@ static void ggml_compute_forward_sub_f32(
     }
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
 
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -8943,29 +9086,7 @@ static void ggml_compute_forward_mul_f32(
 
     const int64_t nr = ggml_nrows(src0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9053,24 +9174,8 @@ static void ggml_compute_forward_div_f32(
     }
 
     const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
 
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
@@ -9277,14 +9382,8 @@ static void ggml_compute_forward_sum_f32(
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb);
 
     ggml_float sum     = 0;
     ggml_float row_sum = 0;
@@ -9333,29 +9432,13 @@ static void ggml_compute_forward_sum_rows_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(dst->nb[0] == sizeof(float));
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     GGML_ASSERT(ne0 == 1);
     GGML_ASSERT(ne1 == ne01);
     GGML_ASSERT(ne2 == ne02);
     GGML_ASSERT(ne3 == ne03);
 
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
-
     for (int64_t i3 = 0; i3 < ne03; i3++) {
         for (int64_t i2 = 0; i2 < ne02; i2++) {
             for (int64_t i1 = 0; i1 < ne01; i1++) {
@@ -9399,19 +9482,7 @@ static void ggml_compute_forward_mean_f32(
 
     assert(src0->nb[0] == sizeof(float));
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     assert(ne0 == 1);
     assert(ne1 == ne01);
@@ -9423,10 +9494,6 @@ static void ggml_compute_forward_mean_f32(
     UNUSED(ne2);
     UNUSED(ne3);
 
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
-
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = 0; i01 < ne01; i01++) {
@@ -9456,38 +9523,66 @@ static void ggml_compute_forward_mean(
     }
 }
 
-// ggml_compute_forward_repeat
+// ggml_compute_forward_argmax
 
-static void ggml_compute_forward_repeat_f32(
+static void ggml_compute_forward_argmax_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(params->ith == 0);
-    GGML_ASSERT(ggml_can_repeat(src0, dst));
+    assert(params->ith == 0);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    const int64_t ne2  = dst->ne[2];
-    const int64_t ne3  = dst->ne[3];
+    assert(src0->nb[0] == sizeof(float));
+    assert(dst->nb[0] == sizeof(float));
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
 
-    const size_t nb0  = dst->nb[0];
-    const size_t nb1  = dst->nb[1];
-    const size_t nb2  = dst->nb[2];
-    const size_t nb3  = dst->nb[3];
-
-    const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
+    const size_t nb0 = dst->nb[0];
+
+    for (int64_t i1 = 0; i1 < ne01; i1++) {
+        float * src = (float *) ((char *) src0->data + i1*nb01);
+        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
+        int v = 0;
+        ggml_vec_argmax_f32(ne00, &v, src);
+        dst_[0] = v;
+    }
+}
+
+static void ggml_compute_forward_argmax(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argmax_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_repeat
+
+static void ggml_compute_forward_repeat_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(params->ith == 0);
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne0/ne00);
@@ -9548,25 +9643,7 @@ static void ggml_compute_forward_repeat_back_f32(
         return;
     }
 
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    const int64_t ne2  = dst->ne[2];
-    const int64_t ne3  = dst->ne[3];
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb0  = dst->nb[0];
-    const size_t nb1  = dst->nb[1];
-    const size_t nb2  = dst->nb[2];
-    const size_t nb3  = dst->nb[3];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne00/ne0);
@@ -9796,6 +9873,90 @@ static void ggml_compute_forward_step(
     }
 }
 
+// ggml_compute_forward_tanh
+
+static void ggml_compute_forward_tanh_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_tanh_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_tanh(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_tanh_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_elu
+
+static void ggml_compute_forward_elu_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_elu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_elu(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_elu_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_relu
 
 static void ggml_compute_forward_relu_f32(
@@ -10097,18 +10258,7 @@ static void ggml_compute_forward_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     const float eps = 1e-5f; // TODO: make this a parameter
 
@@ -10174,18 +10324,7 @@ static void ggml_compute_forward_rms_norm_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     const float eps = 1e-6f; // TODO: make this a parameter
 
@@ -10250,22 +10389,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
-
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const float eps = 1e-6f; // TODO: make this a parameter
 
@@ -10461,41 +10585,7 @@ static void ggml_compute_forward_mul_mat_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-    const int64_t ne10 = src1->ne[0];
-#endif
-    const int64_t ne11 = src1->ne[1];
-#ifndef NDEBUG
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    const int64_t ne2  = dst->ne[2];
-    const int64_t ne3  = dst->ne[3];
-
-    const int nb00 = src0->nb[0];
-#endif
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-
-#ifndef NDEBUG
-    const int nb10 = src1->nb[0];
-#endif
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10632,37 +10722,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    const int64_t ne2  = dst->ne[2];
-    const int64_t ne3  = dst->ne[3];
     //const int64_t ne   = ne0*ne1*ne2*ne3;
 
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -10832,35 +10895,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    const int64_t ne2  = dst->ne[2];
-    const int64_t ne3  = dst->ne[3];
-
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10876,7 +10911,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     enum ggml_type   const vec_dot_type       = quantize_fns[type].vec_dot_type;
 
     // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
     GGML_ASSERT(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
@@ -11070,35 +11105,7 @@ static void ggml_compute_forward_out_prod_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    //const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    const int64_t ne2  = dst->ne[2];
-    const int64_t ne3  = dst->ne[3];
-
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -11333,15 +11340,8 @@ static void ggml_compute_forward_set_f32(
     const int nr = ggml_nrows(src1);
     const int nc = src1->ne[0];
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const size_t nb10 = src1->nb[0];
-    const size_t nb11 = src1->nb[1];
-    const size_t nb12 = src1->nb[2];
-    const size_t nb13 = src1->nb[3];
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
 
     // src0 and dst as viewed during set
     const size_t nb0 = ggml_element_size(src0);
@@ -11732,29 +11732,14 @@ static void ggml_compute_forward_diag_f32(
 
     // TODO: handle transposed/permuted matrices
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    const int ne03 = src0->ne[3];
-    const int ne0 = dst->ne[0];
-    const int ne1 = dst->ne[1];
-    const int ne2 = dst->ne[2];
-    const int ne3 = dst->ne[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
     GGML_ASSERT(ne00 == ne0);
     GGML_ASSERT(ne00 == ne1);
     GGML_ASSERT(ne01 == 1);
     GGML_ASSERT(ne02 == ne2);
     GGML_ASSERT(ne03 == ne3);
 
-    const int nb00 = src0->nb[0];
-    //const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-    const int nb0 = dst->nb[0];
-    const int nb1 = dst->nb[1];
-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
-
     GGML_ASSERT(nb00 == sizeof(float));
     GGML_ASSERT(nb0  == sizeof(float));
 
@@ -12318,7 +12303,7 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_nelements(src1) == 4);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12327,23 +12312,11 @@ static void ggml_compute_forward_rope_f32(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
+    const int n_ctx  = ((int32_t *) src1->data)[3];
 
     assert(n_past >= 0);
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12371,6 +12344,7 @@ static void ggml_compute_forward_rope_f32(
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
@@ -12381,7 +12355,32 @@ static void ggml_compute_forward_rope_f32(
 
                 float theta = (float)p;
 
-                if (!is_neox) {
+                if (is_glm) {
+                    theta = MIN(p, n_ctx - 2);
+                    float block_theta = MAX(p - (n_ctx - 2), 0);
+                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
+                        const float cos_block_theta = cosf(block_theta);
+                        const float sin_block_theta = sinf(block_theta);
+
+                        theta *= theta_scale;
+                        block_theta *= theta_scale;
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+                        const float x2 = src[n_dims];
+                        const float x3 = src[n_dims/2*3];
+
+                        dst_data[0]          = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2]   = x0*sin_theta + x1*cos_theta;
+                        dst_data[n_dims]     = x2*cos_block_theta - x3*sin_block_theta;
+                        dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta;
+                    }
+                } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
@@ -12431,7 +12430,7 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
+    GGML_ASSERT(ggml_nelements(src1) == 4);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12440,23 +12439,11 @@ static void ggml_compute_forward_rope_f16(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
+    const int n_ctx  = ((int32_t *) src1->data)[3];
 
     assert(n_past >= 0);
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12484,6 +12471,7 @@ static void ggml_compute_forward_rope_f16(
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
@@ -12494,7 +12482,32 @@ static void ggml_compute_forward_rope_f16(
 
                 float theta = (float)p;
 
-                if (!is_neox) {
+                if (is_glm) {
+                    theta = MIN(p, n_ctx - 2);
+                    float block_theta = MAX(p - (n_ctx - 2), 0);
+                    for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
+                        const float cos_block_theta = cosf(block_theta);
+                        const float sin_block_theta = sinf(block_theta);
+
+                        theta *= theta_scale;
+                        block_theta *= theta_scale;
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                        const float x2 = GGML_FP16_TO_FP32(src[n_dims]);
+                        const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]);
+
+                        dst_data[0]          = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2]   = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        dst_data[n_dims]     = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
+                        dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
+                    }
+                } if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
@@ -12577,27 +12590,13 @@ static void ggml_compute_forward_rope_back_f32(
     // dx = rope_back(dy, src1)
     // src0 is dy, src1 contains options
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-
-    assert(n_past >= 0);
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_dims = ((int32_t *) src1->data)[1];
+    const int mode   = ((int32_t *) src1->data)[2];
+
+    assert(n_past >= 0);
 
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12696,21 +12695,7 @@ static void ggml_compute_forward_rope_back_f16(
 
     assert(n_past >= 0);
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
-
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -12808,7 +12793,7 @@ static void ggml_compute_forward_rope_back(
     }
 }
 
-// ggml_compute_forward_conv_1d_s1_ph
+// ggml_compute_forward_conv_1d
 
 static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
         const struct ggml_compute_params * params,
@@ -12822,36 +12807,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    //const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    //const int64_t ne12 = src1->ne[2];
-    //const int64_t ne13 = src1->ne[3];
-
-    //const int64_t ne0  = dst->ne[0];
-    //const int64_t ne1  = dst->ne[1];
-    //const int64_t ne2  = dst->ne[2];
-    //const int64_t ne3  = dst->ne[3];
-    //const int64_t ne   = ne0*ne1*ne2*ne3;
-
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    //const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    //const int nb12 = src1->nb[2];
-    //const int nb13 = src1->nb[3];
-
-    //const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    //const int nb2  = dst->nb[2];
-    //const int nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -12942,36 +12898,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    //const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    //const int64_t ne12 = src1->ne[2];
-    //const int64_t ne13 = src1->ne[3];
-
-    //const int64_t ne0  = dst->ne[0];
-    //const int64_t ne1  = dst->ne[1];
-    //const int64_t ne2  = dst->ne[2];
-    //const int64_t ne3  = dst->ne[3];
-    //const int64_t ne   = ne0*ne1*ne2*ne3;
-
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    //const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    //const int nb12 = src1->nb[2];
-    //const int nb13 = src1->nb[3];
-
-    //const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    //const int nb2  = dst->nb[2];
-    //const int nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13071,8 +12998,6 @@ static void ggml_compute_forward_conv_1d_s1_ph(
     }
 }
 
-// ggml_compute_forward_conv_1d_s2_ph
-
 static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -13085,36 +13010,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    //const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    //const int64_t ne12 = src1->ne[2];
-    //const int64_t ne13 = src1->ne[3];
-
-    //const int64_t ne0  = dst->ne[0];
-    //const int64_t ne1  = dst->ne[1];
-    //const int64_t ne2  = dst->ne[2];
-    //const int64_t ne3  = dst->ne[3];
-    //const int64_t ne   = ne0*ne1*ne2*ne3;
-
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    //const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    //const int nb12 = src1->nb[2];
-    //const int nb13 = src1->nb[3];
-
-    //const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    //const int nb2  = dst->nb[2];
-    //const int nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13205,36 +13101,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    //const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    //const int64_t ne12 = src1->ne[2];
-    //const int64_t ne13 = src1->ne[3];
-
-    //const int64_t ne0  = dst->ne[0];
-    //const int64_t ne1  = dst->ne[1];
-    //const int64_t ne2  = dst->ne[2];
-    //const int64_t ne3  = dst->ne[3];
-    //const int64_t ne   = ne0*ne1*ne2*ne3;
-
-    const int nb00 = src0->nb[0];
-    const int nb01 = src0->nb[1];
-    const int nb02 = src0->nb[2];
-    //const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    //const int nb12 = src1->nb[2];
-    //const int nb13 = src1->nb[3];
-
-    //const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    //const int nb2  = dst->nb[2];
-    //const int nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13334,6 +13201,28 @@ static void ggml_compute_forward_conv_1d_s2_ph(
     }
 }
 
+// ggml_compute_forward_conv_1d
+
+static void ggml_compute_forward_conv_1d(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    const struct ggml_tensor * opt0,
+    struct ggml_tensor * dst) {
+    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
+    const int32_t p0 = ((const int32_t*)(opt0->data))[1];
+    const int32_t d0 = ((const int32_t*)(opt0->data))[2];
+    GGML_ASSERT(d0 == 1); // dilation not supported
+    GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
+    if (s0 == 1) {
+        ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst);
+    } else if (s0 == 2) {
+        ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst);
+    } else {
+        GGML_ASSERT(false); // only stride 1 and 2 supported
+    };
+}
+
 // ggml_compute_forward_conv_2d_sk_p0
 
 static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
@@ -13348,36 +13237,7 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int ne00 = src0->ne[0];
-    const int ne01 = src0->ne[1];
-    const int ne02 = src0->ne[2];
-    //const int ne03 = src0->ne[3];
-
-    const int ne10 = src1->ne[0];
-    //const int ne11 = src1->ne[1];
-    const int ne12 = src1->ne[2];
-    //const int ne13 = src1->ne[3];
-
-    const int ne0  = dst->ne[0];
-    const int ne1  = dst->ne[1];
-    const int ne2  = dst->ne[2];
-    //const int ne3  = dst->ne[3];
-    //const int ne   = ne0*ne1*ne2*ne3;
-
-    const int nb00 = src0->nb[0];
-    //const int nb01 = src0->nb[1];
-    //const int nb02 = src0->nb[2];
-    const int nb03 = src0->nb[3];
-
-    const int nb10 = src1->nb[0];
-    //const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    //const int nb13 = src1->nb[3];
-
-    //const int nb0  = dst->nb[0];
-    //const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    //const int nb3  = dst->nb[3];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13386,8 +13246,7 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
     const int nk1 = ne01;
 
     // size of the convolution row - the kernel size unrolled across all channels
-    // round-up so it is more suitable for SIMD
-    const int ew0 = ggml_up32(nk0*nk1*ne02);
+    const int ew0 = nk0*nk1*ne02;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
@@ -13471,6 +13330,34 @@ static void ggml_compute_forward_conv_2d_sk_p0(
     }
 }
 
+// ggml_compute_forward_conv_2d
+
+static void ggml_compute_forward_conv_2d(
+    const struct ggml_compute_params* params,
+    const struct ggml_tensor* src0,
+    const struct ggml_tensor* src1,
+    const struct ggml_tensor* opt0,
+    struct ggml_tensor* dst) {
+    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
+    const int32_t s1 = ((const int32_t*)(opt0->data))[1];
+    const int32_t p0 = ((const int32_t*)(opt0->data))[2];
+    const int32_t p1 = ((const int32_t*)(opt0->data))[3];
+    const int32_t d0 = ((const int32_t*)(opt0->data))[4];
+    const int32_t d1 = ((const int32_t*)(opt0->data))[5];
+    GGML_ASSERT(d0 == 1); // dilation not supported
+    GGML_ASSERT(d1 == 1);
+    GGML_ASSERT(p0 == 0); // padding not supported
+    GGML_ASSERT(p1 == 0);
+
+    if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
+        ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
+    }
+    else {
+        GGML_ASSERT(false); // only stride equal to kernel size is supported
+    };
+}
+
+
 // ggml_compute_forward_flash_attn
 
 static void ggml_compute_forward_flash_attn_f32(
@@ -13483,45 +13370,14 @@ static void ggml_compute_forward_flash_attn_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t neq0 = q->ne[0];
-    const int64_t neq1 = q->ne[1];
-    const int64_t neq2 = q->ne[2];
-    const int64_t neq3 = q->ne[3];
-
-    const int64_t nek0 = k->ne[0];
-    const int64_t nek1 = k->ne[1];
-    //const int64_t nek2 = k->ne[2];
-    //const int64_t nek3 = k->ne[3];
-
-    //const int64_t nev0 = v->ne[0];
-    const int64_t nev1 = v->ne[1];
-    //const int64_t nev2 = v->ne[2];
-    //const int64_t nev3 = v->ne[3];
-
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    //const int64_t ne2  = dst->ne[2];
-    //const int64_t ne3  = dst->ne[3];
-
-    const int nbk0 = k->nb[0];
-    const int nbk1 = k->nb[1];
-    const int nbk2 = k->nb[2];
-    const int nbk3 = k->nb[3];
-
-    const int nbq0 = q->nb[0];
-    const int nbq1 = q->nb[1];
-    const int nbq2 = q->nb[2];
-    const int nbq3 = q->nb[3];
-
-    const int nbv0 = v->nb[0];
-    const int nbv1 = v->nb[1];
-    const int nbv2 = v->nb[2];
-    const int nbv3 = v->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb);
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13692,45 +13548,14 @@ static void ggml_compute_forward_flash_attn_f16(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t neq0 = q->ne[0];
-    const int64_t neq1 = q->ne[1];
-    const int64_t neq2 = q->ne[2];
-    const int64_t neq3 = q->ne[3];
-
-    const int64_t nek0 = k->ne[0];
-    const int64_t nek1 = k->ne[1];
-    //const int64_t nek2 = k->ne[2];
-    //const int64_t nek3 = k->ne[3];
-
-    //const int64_t nev0 = v->ne[0];
-    const int64_t nev1 = v->ne[1];
-    //const int64_t nev2 = v->ne[2];
-    //const int64_t nev3 = v->ne[3];
-
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    //const int64_t ne2  = dst->ne[2];
-    //const int64_t ne3  = dst->ne[3];
-
-    const int nbk0 = k->nb[0];
-    const int nbk1 = k->nb[1];
-    const int nbk2 = k->nb[2];
-    const int nbk3 = k->nb[3];
-
-    const int nbq0 = q->nb[0];
-    const int nbq1 = q->nb[1];
-    const int nbq2 = q->nb[2];
-    const int nbq3 = q->nb[3];
-
-    const int nbv0 = v->nb[0];
-    const int nbv1 = v->nb[1];
-    const int nbv2 = v->nb[2];
-    const int nbv3 = v->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb);
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -13964,65 +13789,18 @@ static void ggml_compute_forward_flash_ff_f16(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t nea0 = a->ne[0];
-    const int64_t nea1 = a->ne[1];
-    const int64_t nea2 = a->ne[2];
-    const int64_t nea3 = a->ne[3];
-
-    const int64_t neb00 = b0->ne[0];
-    const int64_t neb01 = b0->ne[1];
-    //const int64_t neb02 = b0->ne[2];
-    //const int64_t neb03 = b0->ne[3];
-
-    const int64_t neb10 = b1->ne[0];
-    const int64_t neb11 = b1->ne[1];
-    //const int64_t neb12 = b1->ne[2];
-    //const int64_t neb13 = b1->ne[3];
-
-    const int64_t nec00 = c0->ne[0];
-    const int64_t nec01 = c0->ne[1];
-    //const int64_t nec02 = c0->ne[2];
-    //const int64_t nec03 = c0->ne[3];
-
-    const int64_t nec10 = c1->ne[0];
-    const int64_t nec11 = c1->ne[1];
-    //const int64_t nec12 = c1->ne[2];
-    //const int64_t nec13 = c1->ne[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    //const int64_t ne3 = dst->ne[3];
-
-    const int nba0 = a->nb[0];
-    const int nba1 = a->nb[1];
-    const int nba2 = a->nb[2];
-    const int nba3 = a->nb[3];
-
-    const int nbb00 = b0->nb[0];
-    const int nbb01 = b0->nb[1];
-    const int nbb02 = b0->nb[2];
-    const int nbb03 = b0->nb[3];
-
-    const int nbb10 = b1->nb[0];
-    //const int nbb11 = b1->nb[1];
-    //const int nbb12 = b1->nb[2];
-    //const int nbb13 = b1->nb[3];
-
-    const int nbc00 = c0->nb[0];
-    const int nbc01 = c0->nb[1];
-    const int nbc02 = c0->nb[2];
-    const int nbc03 = c0->nb[3];
-
-    const int nbc10 = c1->nb[0];
-    //const int nbc11 = c1->nb[1];
-    //const int nbc12 = c1->nb[2];
-    //const int nbc13 = c1->nb[3];
-
-    const int nb0 = dst->nb[0];
-    const int nb1 = dst->nb[1];
-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
+    GGML_TENSOR_LOCALS(int64_t, nea,  a,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nba,  a,   nb);
+    GGML_TENSOR_LOCALS(int64_t, neb0, b0,  ne);
+    GGML_TENSOR_LOCALS(size_t,  nbb0, b0,  nb);
+    GGML_TENSOR_LOCALS(int64_t, neb1, b1,  ne);
+    GGML_TENSOR_LOCALS(size_t,  nbb1, b1,  nb);
+    GGML_TENSOR_LOCALS(int64_t, nec0, c0,  ne);
+    GGML_TENSOR_LOCALS(size_t,  nbc0, c0,  nb);
+    GGML_TENSOR_LOCALS(int64_t, nec1, c1,  ne);
+    GGML_TENSOR_LOCALS(size_t,  nbc1, c1,  nb);
+    GGML_TENSOR_LOCALS(int64_t, ne,   dst, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb,   dst, nb);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14170,55 +13948,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
-    const int64_t neq0 = q->ne[0];
-    const int64_t neq1 = q->ne[1];
-    const int64_t neq2 = q->ne[2];
-    const int64_t neq3 = q->ne[3];
-
-    const int64_t nek0 = k->ne[0];
-    const int64_t nek1 = k->ne[1];
-    //const int64_t nek2 = k->ne[2];
-    //const int64_t nek3 = k->ne[3];
-
-    const int64_t nev0 = v->ne[0];
-    const int64_t nev1 = v->ne[1];
-    //const int64_t nev2 = v->ne[2];
-    //const int64_t nev3 = v->ne[3];
-
-    const int64_t ned0 = d->ne[0];
-    const int64_t ned1 = d->ne[1];
-    //const int64_t ned2 = d->ne[2];
-    //const int64_t ned3 = d->ne[3];
-
-    const int64_t ne0  = dst->ne[0];
-    const int64_t ne1  = dst->ne[1];
-    const int64_t ne2  = dst->ne[2];
-    const int64_t ne3  = dst->ne[3];
-
-    const int nbk0 = k->nb[0];
-    const int nbk1 = k->nb[1];
-    const int nbk2 = k->nb[2];
-    const int nbk3 = k->nb[3];
-
-    const int nbq0 = q->nb[0];
-    const int nbq1 = q->nb[1];
-    const int nbq2 = q->nb[2];
-    const int nbq3 = q->nb[3];
-
-    const int nbv0 = v->nb[0];
-    const int nbv1 = v->nb[1];
-    const int nbv2 = v->nb[2];
-    const int nbv3 = v->nb[3];
-
-    const int nbd0 = d->nb[0];
-    const int nbd1 = d->nb[1];
-    const int nbd2 = d->nb[2];
-    const int nbd3 = d->nb[3];
-
-    const int nb0  = dst->nb[0];
-    const int nb1  = dst->nb[1];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb);
+    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne);
+    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb);
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -14576,15 +14315,8 @@ static void ggml_compute_forward_win_part_f32(
         return;
     }
 
-    const int64_t ne00 = src0->ne[0]; UNUSED(ne00);
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3]; UNUSED(ne3);
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne);
 
     const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
     const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
@@ -14647,14 +14379,8 @@ static void ggml_compute_forward_win_unpart_f32(
         return;
     }
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    //const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne);
 
     const int32_t w = ((const int32_t *)(opt0->data))[0];
 
@@ -15195,7 +14921,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     if (skip_cpu) {
         return;
     }
-    GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU);
+    GGML_ASSERT(tensor->src0 == NULL || tensor->src0->backend == GGML_BACKEND_CPU);
     GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
 #endif // GGML_USE_CUBLAS
 
@@ -15252,6 +14978,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mean(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_ARGMAX:
+            {
+                ggml_compute_forward_argmax(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_REPEAT:
             {
                 ggml_compute_forward_repeat(params, tensor->src0, tensor);
@@ -15276,6 +15006,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_step(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_TANH:
+            {
+                ggml_compute_forward_tanh(params, tensor->src0, tensor);
+            } break;
+        case GGML_OP_ELU:
+            {
+                ggml_compute_forward_elu(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_RELU:
             {
                 ggml_compute_forward_relu(params, tensor->src0, tensor);
@@ -15392,17 +15130,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
             } break;
-        case GGML_OP_CONV_1D_S1_PH:
+        case GGML_OP_CONV_1D:
             {
-                ggml_compute_forward_conv_1d_s1_ph(params, tensor->src0, tensor->src1, tensor);
+                ggml_compute_forward_conv_1d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
             } break;
-        case GGML_OP_CONV_1D_S2_PH:
+        case GGML_OP_CONV_2D:
             {
-                ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor);
-            } break;
-        case GGML_OP_CONV_2D_SK_P0:
-            {
-                ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor);
+                ggml_compute_forward_conv_2d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
             } break;
         case GGML_OP_FLASH_ATTN:
             {
@@ -15651,6 +15385,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 }
             } break;
         case GGML_OP_MEAN:
+        case GGML_OP_ARGMAX:
             {
                 GGML_ASSERT(false); // TODO: implement
             } break;
@@ -15704,6 +15439,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     // noop
                 }
             } break;
+        case GGML_OP_TANH:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_ELU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_RELU:
             {
                 if (src0->grad) {
@@ -15723,14 +15466,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_ALIBI:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CLAMP:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
         case GGML_OP_SILU:
             {
                 // necessary for llama
@@ -16047,7 +15782,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
+                    assert(ggml_nelements(src1) == 4);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
@@ -16068,32 +15803,38 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
+                    assert(ggml_nelements(src1) == 4);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
+                    const int n_ctx  = ((int32_t *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                n_ctx),
                             inplace);
                 }
                 if (src1->grad) {
                     // noop
                 }
             } break;
-        case GGML_OP_CONV_1D_S1_PH:
+        case GGML_OP_ALIBI:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_CLAMP:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_1D_S2_PH:
+        case GGML_OP_CONV_1D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_2D_SK_P0:
+        case GGML_OP_CONV_2D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
@@ -16503,68 +16244,180 @@ typedef pthread_t ggml_thread_t;
 
 #endif
 
+// Android's libc implementation "bionic" does not support setting affinity
+#if defined(__linux__) && !defined(__BIONIC__)
+void set_numa_thread_affinity(int thread_n, int n_threads) {
+    if (!ggml_is_numa()) {
+        return;
+    }
+
+    // run thread on node_num thread_n / (threads per node)
+    const int node_num = thread_n / ((n_threads + g_state.numa.n_nodes - 1) / g_state.numa.n_nodes);
+    struct ggml_numa_node * node = &g_state.numa.nodes[node_num];
+    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+    CPU_ZERO_S(setsize, cpus);
+    for (size_t i = 0; i < node->n_cpus; ++i) {
+        CPU_SET_S(node->cpus[i], setsize, cpus);
+    }
+
+    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+    if (rv) {
+            fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",
+                    strerror(rv));
+    }
+
+    CPU_FREE(cpus);
+}
+
+void clear_numa_thread_affinity(void) {
+    if (!ggml_is_numa()) {
+        return;
+    }
+
+    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+    CPU_ZERO_S(setsize, cpus);
+    for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
+        CPU_SET_S(i, setsize, cpus);
+    }
+
+    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+    if (rv) {
+        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",
+            strerror(rv));
+    }
+
+    CPU_FREE(cpus);
+}
+#else
+// TODO: Windows etc.
+// (the linux implementation may also work on BSD, someone should test)
+void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads);  }
+void clear_numa_thread_affinity(void) {}
+#endif
+
 struct ggml_compute_state_shared {
-    ggml_lock_t spin;
+    struct ggml_cgraph * cgraph;
+
+    int64_t perf_node_start_cycles;
+    int64_t perf_node_start_time_us;
 
     int n_threads;
 
     // synchronization primitives
-    atomic_int  n_ready;
-    atomic_bool has_work;
-    atomic_bool stop; // stop all threads
+    atomic_int n_active; // num active threads
+    atomic_int node_n;   // active graph node
 };
 
 struct ggml_compute_state {
     ggml_thread_t thrd;
-
-    struct ggml_compute_params params;
-    struct ggml_tensor * node;
-
+    int ith;
     struct ggml_compute_state_shared * shared;
 };
 
+static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
+    int64_t cycles_cur  = ggml_perf_cycles()  - st->perf_node_start_cycles;
+    int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
+
+    node->perf_runs++;
+    node->perf_cycles  += cycles_cur;
+    node->perf_time_us += time_us_cur;
+}
+
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+    struct ggml_cgraph * cgraph = state->shared->cgraph;
 
     const int n_threads = state->shared->n_threads;
+    set_numa_thread_affinity(state->ith, n_threads);
+
+    int node_n = -1;
 
     while (true) {
-        if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
-            atomic_store(&state->shared->has_work, false);
-        } else {
-            while (atomic_load(&state->shared->has_work)) {
-                if (atomic_load(&state->shared->stop)) {
-                    return 0;
+        if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
+            // all other threads are finished and spinning
+            // do finalize and init here so we don't have synchronize again
+            struct ggml_compute_params params = {
+                /*.type  =*/ GGML_TASK_FINALIZE,
+                /*.ith   =*/ 0,
+                /*.nth   =*/ 0,
+                /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+            };
+
+            if (node_n != -1) {
+                /* FINALIZE */
+                struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
+                if (GGML_OP_HAS_FINALIZE[node->op]) {
+                    params.nth = node->n_tasks;
+                    ggml_compute_forward(&params, node);
+                    ggml_graph_compute_perf_stats_node(node, state->shared);
                 }
-                ggml_lock_lock  (&state->shared->spin);
-                ggml_lock_unlock(&state->shared->spin);
             }
-        }
 
-        atomic_fetch_sub(&state->shared->n_ready, 1);
+            // distribute new work or execute it direct if 1T
+            while (++node_n < cgraph->n_nodes) {
+                GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
+
+                struct ggml_tensor * node = cgraph->nodes[node_n];
+
+                state->shared->perf_node_start_cycles  = ggml_perf_cycles();
+                state->shared->perf_node_start_time_us = ggml_perf_time_us();
+
+                params.nth = node->n_tasks;
+
+                /* INIT */
+                if (GGML_OP_HAS_INIT[node->op]) {
+                    params.type = GGML_TASK_INIT;
+                    ggml_compute_forward(&params, node);
+                }
+
+                if (node->n_tasks == 1) {
+                    // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
+                    // they do something more efficient than spinning (?)
+                    params.type = GGML_TASK_COMPUTE;
+                    ggml_compute_forward(&params, node);
 
-        // wait for work
-        while (!atomic_load(&state->shared->has_work)) {
-            if (atomic_load(&state->shared->stop)) {
-                return 0;
+                    if (GGML_OP_HAS_FINALIZE[node->op]) {
+                        params.type = GGML_TASK_FINALIZE;
+                        ggml_compute_forward(&params, node);
+                        ggml_graph_compute_perf_stats_node(node, state->shared);
+                    }
+                } else {
+                    break;
+                }
             }
-            ggml_lock_lock  (&state->shared->spin);
-            ggml_lock_unlock(&state->shared->spin);
+
+            atomic_store(&state->shared->n_active, n_threads);
+            atomic_store(&state->shared->node_n,   node_n);
+        } else {
+            // wait for other threads to finish
+            const int last = node_n;
+            do {
+                sched_yield();
+                node_n = atomic_load(&state->shared->node_n);
+            } while (node_n == last);
         }
 
         // check if we should stop
-        if (atomic_load(&state->shared->stop)) {
-            break;
-        }
+        if (node_n >= cgraph->n_nodes) break;
 
-        if (state->node) {
-            if (state->params.ith < state->params.nth) {
-                ggml_compute_forward(&state->params, state->node);
-            }
+        /* COMPUTE */
+        struct ggml_tensor * node = cgraph->nodes[node_n];
 
-            state->node = NULL;
-        } else {
-            break;
+        struct ggml_compute_params params = {
+            /*.type  =*/ GGML_TASK_COMPUTE,
+            /*.ith   =*/ state->ith,
+            /*.nth   =*/ node->n_tasks,
+            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+        };
+
+        if (state->ith < node->n_tasks) {
+            ggml_compute_forward(&params, node);
         }
     }
 
@@ -16575,39 +16428,14 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
     const int n_threads = cgraph->n_threads;
 
     struct ggml_compute_state_shared state_shared = {
-        /*.spin      =*/ GGML_LOCK_INITIALIZER,
-        /*.n_threads =*/ n_threads,
-        /*.n_ready   =*/ 0,
-        /*.has_work  =*/ false,
-        /*.stop      =*/ false,
+        /*.cgraph                  =*/ cgraph,
+        /*.perf_node_start_cycles  =*/ 0,
+        /*.perf_node_start_time_us =*/ 0,
+        /*.n_threads               =*/ n_threads,
+        /*.n_active                =*/ n_threads,
+        /*.node_n                  =*/ -1,
     };
-    struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
-
-    // create thread pool
-    if (n_threads > 1) {
-        ggml_lock_init(&state_shared.spin);
-
-        atomic_store(&state_shared.has_work, true);
-
-        for (int j = 0; j < n_threads - 1; j++) {
-            workers[j] = (struct ggml_compute_state) {
-                .thrd   = 0,
-                .params = {
-                    .type  = GGML_TASK_COMPUTE,
-                    .ith   = j + 1,
-                    .nth   = n_threads,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                },
-                .node   = NULL,
-                .shared = &state_shared,
-            };
-
-            int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
-            GGML_ASSERT(rc == 0);
-            UNUSED(rc);
-        }
-    }
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
 
     // initialize tasks + work buffer
     {
@@ -16663,12 +16491,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 case GGML_OP_SUM:
                 case GGML_OP_SUM_ROWS:
                 case GGML_OP_MEAN:
+                case GGML_OP_ARGMAX:
                 case GGML_OP_REPEAT:
                 case GGML_OP_REPEAT_BACK:
                 case GGML_OP_ABS:
                 case GGML_OP_SGN:
                 case GGML_OP_NEG:
                 case GGML_OP_STEP:
+                case GGML_OP_TANH:
+                case GGML_OP_ELU:
                 case GGML_OP_RELU:
                     {
                         node->n_tasks = 1;
@@ -16751,7 +16582,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_SCALE:
                     {
-                        node->n_tasks = n_threads;
+                        node->n_tasks = 1;
                     } break;
                 case GGML_OP_SET:
                 case GGML_OP_CONT:
@@ -16782,8 +16613,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = 1; //TODO
                     } break;
-                case GGML_OP_CONV_1D_S1_PH:
-                case GGML_OP_CONV_1D_S2_PH:
+                case GGML_OP_CONV_1D:
                     {
                         node->n_tasks = n_threads;
 
@@ -16812,7 +16642,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         work_size = MAX(work_size, cur);
                     } break;
-                case GGML_OP_CONV_2D_SK_P0:
+                case GGML_OP_CONV_2D:
                     {
                         node->n_tasks = n_threads;
 
@@ -16955,166 +16785,37 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         }
     }
 
-    const int64_t perf_start_cycles  = ggml_perf_cycles();
-    const int64_t perf_start_time_us = ggml_perf_time_us();
-
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
-
-        struct ggml_tensor * node = cgraph->nodes[i];
-
-        // TODO: this could be used to avoid unnecessary computations, but it needs to be improved
-        //if (node->grad == NULL && node->perf_runs > 0) {
-        //    continue;
-        //}
-
-        const int64_t perf_node_start_cycles  = ggml_perf_cycles();
-        const int64_t perf_node_start_time_us = ggml_perf_time_us();
-
-        // INIT
-        struct ggml_compute_params params = {
-            /*.type  =*/ GGML_TASK_INIT,
-            /*.ith   =*/ 0,
-            /*.nth   =*/ node->n_tasks,
-            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
-        };
-
-        ggml_compute_forward(&params, node);
-
-        // COMPUTE
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
-
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            // launch thread pool
-            for (int j = 0; j < n_threads - 1; j++) {
-                workers[j].params = (struct ggml_compute_params) {
-                    .type  = GGML_TASK_COMPUTE,
-                    .ith   = j + 1,
-                    .nth   = node->n_tasks,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                };
-                workers[j].node = node;
-            }
-
-            atomic_fetch_sub(&state_shared.n_ready, 1);
-
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            atomic_store(&state_shared.has_work, true);
-        }
-
-        params.type = GGML_TASK_COMPUTE;
-        ggml_compute_forward(&params, node);
-
-        // wait for thread pool
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
-
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            atomic_fetch_sub(&state_shared.n_ready, 1);
-
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-        }
-
-        // FINALIZE
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
-
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            // launch thread pool
-            for (int j = 0; j < n_threads - 1; j++) {
-                workers[j].params = (struct ggml_compute_params) {
-                    .type  = GGML_TASK_FINALIZE,
-                    .ith   = j + 1,
-                    .nth   = node->n_tasks,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                };
-                workers[j].node = node;
-            }
-
-            atomic_fetch_sub(&state_shared.n_ready, 1);
-
-            while (atomic_load(&state_shared.n_ready) > 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
+    // create thread pool
+    if (n_threads > 1) {
+        for (int j = 1; j < n_threads; ++j) {
+            workers[j] = (struct ggml_compute_state) {
+                .thrd   = 0,
+                .ith = j,
+                .shared = &state_shared,
+            };
 
-            atomic_store(&state_shared.has_work, true);
+            const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
+            GGML_ASSERT(rc == 0);
         }
+    }
+    workers[0].ith = 0;
+    workers[0].shared = &state_shared;
 
-        params.type = GGML_TASK_FINALIZE;
-        ggml_compute_forward(&params, node);
-
-        // wait for thread pool
-        if (node->n_tasks > 1) {
-            if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
-                atomic_store(&state_shared.has_work, false);
-            }
-
-            while (atomic_load(&state_shared.has_work)) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-
-            atomic_fetch_sub(&state_shared.n_ready, 1);
-
-            while (atomic_load(&state_shared.n_ready) != 0) {
-                ggml_lock_lock  (&state_shared.spin);
-                ggml_lock_unlock(&state_shared.spin);
-            }
-        }
+    const int64_t perf_start_cycles  = ggml_perf_cycles();
+    const int64_t perf_start_time_us = ggml_perf_time_us();
 
-        // performance stats (node)
-        {
-            int64_t perf_cycles_cur  = ggml_perf_cycles()  - perf_node_start_cycles;
-            int64_t perf_time_us_cur = ggml_perf_time_us() - perf_node_start_time_us;
+    // this is a work thread too
+    ggml_graph_compute_thread(&workers[0]);
 
-            node->perf_runs++;
-            node->perf_cycles  += perf_cycles_cur;
-            node->perf_time_us += perf_time_us_cur;
-        }
-    }
+    // don't leave affinity set on the main thread
+    clear_numa_thread_affinity();
 
     // join thread pool
     if (n_threads > 1) {
-        atomic_store(&state_shared.stop, true);
-        atomic_store(&state_shared.has_work, true);
-
-        for (int j = 0; j < n_threads - 1; j++) {
-            int rc = ggml_thread_join(workers[j].thrd, NULL);
+        for (int j = 1; j < n_threads; j++) {
+            const int rc = ggml_thread_join(workers[j].thrd, NULL);
             GGML_ASSERT(rc == 0);
-            UNUSED(rc);
         }
-
-        ggml_lock_destroy(&state_shared.spin);
     }
 
     // performance stats (graph)
@@ -17303,13 +17004,6 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
                     fwrite(&nb, sizeof(uint64_t), 1, fout);
                 }
 
-                // store the pointer address
-                {
-                    const uint64_t ptr = (uint64_t) tensor->data;
-
-                    fwrite(&ptr, sizeof(uint64_t), 1, fout);
-                }
-
                 fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
 
                 // dump the data
@@ -17343,13 +17037,6 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
                     fwrite(&nb, sizeof(uint64_t), 1, fout);
                 }
 
-                // store the pointer address
-                {
-                    const uint64_t ptr = (uint64_t) tensor->data;
-
-                    fwrite(&ptr, sizeof(uint64_t), 1, fout);
-                }
-
                 fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
 
                 // output the op arguments
@@ -17534,8 +17221,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
 
                 tensor->op = (enum ggml_op) op;
 
-                uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur);
-
                 memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
 
                 tensor->data = (void *) ptr;
@@ -17581,8 +17266,6 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                     nb[j] = nb_cur;
                 }
 
-                uint64_t ptr_cur = *(const uint64_t *) ptr; ptr += sizeof(ptr_cur); // TODO: not yet used
-
                 const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
 
                 const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
diff --git a/ggml.h b/ggml.h
index 5ebd9c46c3d5267df4e2a879cfcb633bd23c6085..0af96c76b6d0ecd4951c6dc6b8b5e3959bfc3d0b 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_MAX_PARAMS        256
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_OPT           4
-#define GGML_MAX_NAME          32
+#define GGML_MAX_NAME          48
 #define GGML_DEFAULT_N_THREADS 4
 
+#define GGML_UNUSED(x) (void)(x)
+
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
         } \
     } while (0)
 
+// used to copy the number of elements and stride in bytes of tensors into local variables.
+// main purpose is to reduce code duplication and improve readability.
+//
+// example:
+//
+//    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+//    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
+//
+#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
+    const type prefix##0 = (pointer)->array[0]; \
+    GGML_UNUSED(prefix##0);
+#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_1    (type, prefix, pointer, array) \
+    const type prefix##1 = (pointer)->array[1]; \
+    GGML_UNUSED(prefix##1);
+#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_2    (type, prefix, pointer, array) \
+    const type prefix##2 = (pointer)->array[2]; \
+    GGML_UNUSED(prefix##2);
+#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_3  (type, prefix, pointer, array) \
+    const type prefix##3 = (pointer)->array[3]; \
+    GGML_UNUSED(prefix##3);
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
@@ -295,12 +321,15 @@ extern "C" {
         GGML_OP_SUM,
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
+        GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
         GGML_OP_ABS,
         GGML_OP_SGN,
         GGML_OP_NEG,
         GGML_OP_STEP,
+        GGML_OP_TANH,
+        GGML_OP_ELU,
         GGML_OP_RELU,
         GGML_OP_GELU,
         GGML_OP_GELU_QUICK,
@@ -332,9 +361,8 @@ extern "C" {
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
         GGML_OP_CLAMP,
-        GGML_OP_CONV_1D_S1_PH,
-        GGML_OP_CONV_1D_S2_PH,
-        GGML_OP_CONV_2D_SK_P0,
+        GGML_OP_CONV_1D,
+        GGML_OP_CONV_2D,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -444,6 +472,9 @@ extern "C" {
 
 
     // compute types
+
+    // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
+    // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
     enum ggml_task_type {
         GGML_TASK_INIT = 0,
         GGML_TASK_COMPUTE,
@@ -469,6 +500,9 @@ extern "C" {
     GGML_API int64_t ggml_cycles(void);
     GGML_API int64_t ggml_cycles_per_ms(void);
 
+    GGML_API void    ggml_numa_init(void); // call once for better performance on NUMA systems
+    GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+
     GGML_API void    ggml_print_object (const struct ggml_object * obj);
     GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
@@ -684,6 +718,11 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // argmax along rows
+    GGML_API struct ggml_tensor * ggml_argmax(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // if a is the same shape as b, and a is not parameter, return a
     // otherwise, return a new tensor: repeat(a) to fit in b
     GGML_API struct ggml_tensor * ggml_repeat(
@@ -728,6 +767,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_tanh(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_tanh_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_elu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_elu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_relu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -1033,13 +1088,15 @@ extern "C" {
     // rotary position embedding
     // if mode & 1 == 1, skip n_past elements
     // if mode & 2 == 1, GPT-NeoX style
+    // if mode & 4 == 1, ChatGLM style
     // TODO: avoid creating a new tensor every time
     GGML_API struct ggml_tensor * ggml_rope(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_dims,
-            int                   mode);
+            int                   mode,
+            int                   n_ctx);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_inplace(
@@ -1047,7 +1104,8 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_dims,
-            int                   mode);
+            int                   mode,
+            int                   n_ctx);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1075,58 +1133,33 @@ extern "C" {
             float                 min,
             float                 max);
 
-    // TODO: implement general-purpose convolutions
-    // GGML_API struct ggml_tensor * ggml_conv_1d(
-    //        struct ggml_context * ctx,
-    //        struct ggml_tensor  * a,
-    //        struct ggml_tensor  * b,
-    //        int                   s0
-    //        int                   p0,
-    //        int                   d0);
-    //
-    // GGML_API struct ggml_tensor * ggml_conv_2d(
-    //        struct ggml_context * ctx,
-    //        struct ggml_tensor  * a,
-    //        struct ggml_tensor  * b,
-    //        int                   s0,
-    //        int                   s1,
-    //        int                   p0,
-    //        int                   p1,
-    //        int                   d0,
-    //        int                   d1);
-
-    // padding = half
-    // TODO: we don't support extra parameters for now
-    //       that's why we are hard-coding the stride, padding, and dilation
-    //       not great ..
-    // example:
-    // a:      3   80  768    1
-    // b:   3000   80    1    1
-    // res: 3000  768    1    1
-    // used in whisper
-    GGML_API struct ggml_tensor * ggml_conv_1d_s1_ph(
+    GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            int                   s0,  // stride
+            int                   p0,  // padding
+            int                   d0); // dilation
 
-    // used in whisper
-    GGML_API struct ggml_tensor * ggml_conv_1d_s2_ph(
+    GGML_API struct ggml_tensor * ggml_conv_2d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            int                   s0,
+            int                   s1,
+            int                   p0,
+            int                   p1,
+            int                   d0,
+            int                   d1);
 
-    // kernel size is a->ne[0] x a->ne[1]
-    // stride is equal to kernel size
-    // padding is zero
-    // example:
-    // a:     16   16    3  768
-    // b:   1024 1024    3    1
-    // res:   64   64  768    1
-    // used in sam
-    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+    // conv_1d with padding = half
+    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            int                   s,
+            int                   d);
 
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
index 5f3888c79165a33d622592397ab1bda1b89637a6..932ae6fe32d964a3f5cfd064bc944ce5252400fb 100644 (file)
@@ -812,7 +812,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     {
         uint32_t magic;
         read_safe(loader, magic);
-        if (magic != 0x67676d6c) {
+        if (magic != GGML_FILE_MAGIC) {
             fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
             return false;
         }
@@ -1472,7 +1472,7 @@ static bool whisper_encode_internal(
         {
             wstate.use_buf(ctx0, 1);
 
-            cur = ggml_conv_1d_s1_ph(ctx0, model.e_conv_1_w, mel);
+            cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0,
                         model.e_conv_1_b,
@@ -1483,7 +1483,7 @@ static bool whisper_encode_internal(
 
             wstate.use_buf(ctx0, 0);
 
-            cur = ggml_conv_1d_s2_ph(ctx0, model.e_conv_2_w, cur);
+            cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
             cur = ggml_add(ctx0,
                     ggml_repeat(ctx0,
                         model.e_conv_2_b,