git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
sync : ggml (Metal fixes, new ops, tests) (#1633)
author    Georgi Gerganov <redacted>
          Wed, 13 Dec 2023 19:55:03 +0000 (21:55 +0200)
committer GitHub <redacted>
          Wed, 13 Dec 2023 19:55:03 +0000 (21:55 +0200)
* sync : ggml (Metal fixes, new ops, tests)

* cuda : fix bin bcast when src1 and dst have different types

ggml-alloc.h
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml-quants.c
ggml.c
ggml.h

index ad87cebc8873f4584ad99e2c67e110a603f4a71c..64a412468915b47c58ce100bb2dfcaf4d5502b32 100644 (file)
@@ -43,7 +43,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
 // ggml-backend v2 API
 //
 
-// Seperate tensor and graph allocator objects
+// Separate tensor and graph allocator objects
 // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
 // The original API is kept as a wrapper around the new API
 
index 85f7a293783bed95758dd729fb49fb6c97d27017..019648bddc4d9e23d363d289b432f10dbc481853 100644 (file)
@@ -1,13 +1,15 @@
 #include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
 #include <cstddef>
 #include <cstdint>
-#include <cinttypes>
 #include <float.h>
 #include <limits>
 #include <stdint.h>
 #include <stdio.h>
-#include <atomic>
-#include <assert.h>
+#include <vector>
+
 
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -437,6 +439,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
@@ -449,6 +452,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_UPSCALE_BLOCK_SIZE 256
+#define CUDA_CONCAT_BLOCK_SIZE 256
+#define CUDA_PAD_BLOCK_SIZE 256
+#define CUDA_ACC_BLOCK_SIZE 256
+#define CUDA_IM2COL_BLOCK_SIZE 256
 
 // dmmv = dequantize_mul_mat_vec
 #ifndef GGML_CUDA_DMMV_X
@@ -610,6 +618,24 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, int offset) {
+    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
+    int src1_idx = i - offset;
+    int oz = src1_idx / nb2;
+    int oy = (src1_idx - (oz * nb2)) / nb1;
+    int ox = src1_idx % nb1;
+    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
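The new acc_f32 kernel implements GGML_OP_ACC: dst is a copy of x (src0) with y (src1) added into a strided sub-view starting at a flat element offset. Note that nb1/nb2 arrive here already converted to element strides (the host divides the byte strides by 4 for f32 in ggml_cuda_op_acc below), despite the nb naming. A minimal CPU reference sketch of the same index math, assuming contiguous f32 buffers:

    // CPU reference for acc_f32 (illustrative names, strides in elements)
    void acc_f32_ref(const float * x, const float * y, float * dst, int ne,
                     int ne10, int ne11, int ne12, int nb1, int nb2, int offset) {
        for (int i = 0; i < ne; i++) {
            const int s  = i - offset;
            const int oz = s / nb2;
            const int oy = (s - oz*nb2) / nb1;
            const int ox = s % nb1;
            if (s >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
                dst[i] = x[i] + y[ox + oy*ne10 + oz*ne10*ne11]; // inside the view
            } else {
                dst[i] = x[i];                                  // outside: plain copy
            }
        }
    }
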
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -632,6 +658,23 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
+static __global__ void tanh_f32(const float *x, float *dst, int k) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = tanhf(x[i]);
+}
+
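gelu_quick_f32 is the sigmoid approximation of GELU, x * sigmoid(1.702 * x), a cheaper alternative to the tanh-based gelu_f32 above; tanh_f32 is a plain element-wise tanhf. The two GELU variants agree closely, e.g.:

    // x = 1:  gelu_quick(1) = 1/(1 + expf(-1.702f)) ≈ 0.846
    //         exact GELU(1) = 1 * Phi(1)            ≈ 0.841
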
 static __global__ void relu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -641,6 +684,14 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
+static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
+}
+
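leaky_relu_f32 evaluates f(x) = max(x, 0) + negative_slope * min(x, 0): identity for x >= 0, a scaled identity below zero, and plain ReLU when the slope is 0.

    // e.g. negative_slope = 0.1f:  f(3.0f) = 3.0f,  f(-2.0f) = -0.2f
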
 static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -686,6 +737,132 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     }
 }
 
+static __global__ void concat_f32(const float  *x,const float  *y, float *dst, const int ne0, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (blockIdx.z < ne02) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * gridDim.y;
+            dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            (blockIdx.z - ne02) * ne0 *  gridDim.y;
+            dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void upscale_f32(const float  *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
+    int ne0 = ne00 * scale_factor;
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    // operation
+    int i00 = nidx / scale_factor;
+    int i01 = blockIdx.y / scale_factor;
+    int offset_src =
+        i00 +
+        i01 * ne00 +
+        blockIdx.z * nb02;
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    dst[offset_dst] = x[offset_src];
+}
+
+static __global__ void pad_f32(const float  *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+            dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
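concat_f32, upscale_f32 and pad_f32 all map one thread to one destination element of an (ne0, gridDim.y, gridDim.z) volume: concat stacks src0 and src1 along dim 2 (blocks with blockIdx.z < ne02 read src0, the rest read src1 shifted by ne02), upscale does nearest-neighbor integer upscaling of dims 0 and 1, and pad copies the source region and zero-fills the rest. A CPU sketch of the pad index math, assuming contiguous f32 data:

    // CPU reference for pad_f32 (illustrative, contiguous f32)
    void pad_f32_ref(const float * x, float * dst,
                     int ne00, int ne01, int ne02,  // source dims
                     int ne0,  int ne1,  int ne2) { // padded dims
        for (int i2 = 0; i2 < ne2; i2++)
        for (int i1 = 0; i1 < ne1; i1++)
        for (int i0 = 0; i0 < ne0; i0++) {
            const int idst = i0 + i1*ne0 + i2*ne0*ne1;
            dst[idst] = (i0 < ne00 && i1 < ne01 && i2 < ne02)
                ? x[i0 + i1*ne00 + i2*ne00*ne01]
                : 0.0f;
        }
    }
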
+template <int block_size>
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+    int start = blockIdx.x * group_size;
+    int end = start + group_size;
+
+    start += threadIdx.x;
+
+    if (end >= ne_elements) {
+        end = ne_elements;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += block_size) {
+        tmp += x[j];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float mean = tmp / group_size;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += block_size) {
+        float xi = x[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float variance = tmp / group_size;
+    float scale = rsqrtf(variance + eps);
+    for (int j = start; j < end; j += block_size) {
+        dst[j] *= scale;
+    }
+}
+
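group_norm_f32 runs one block per group and normalizes group_size consecutive elements in two passes: a warp-reduced sum yields the mean, then a second pass writes the centered values while accumulating squared deviations for the variance, and a final sweep scales by 1/sqrt(variance + eps). A sequential CPU reference sketch, mirroring the kernel (including dividing by the full group_size even for a last group clipped by ne_elements):

    #include <math.h>

    // CPU reference for group_norm_f32 (contiguous groups, illustrative)
    void group_norm_f32_ref(const float * x, float * dst,
                            int num_groups, int group_size, int ne, float eps) {
        for (int g = 0; g < num_groups; g++) {
            const int start = g*group_size;
            const int end   = start + group_size < ne ? start + group_size : ne;
            float sum = 0.0f;
            for (int j = start; j < end; j++) sum += x[j];
            const float mean = sum / group_size;
            float sum2 = 0.0f;
            for (int j = start; j < end; j++) {
                const float xi = x[j] - mean;
                dst[j] = xi;          // store centered value
                sum2  += xi*xi;
            }
            const float scale = 1.0f/sqrtf(sum2/group_size + eps);
            for (int j = start; j < end; j++) dst[j] *= scale;
        }
    }
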
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -1684,31 +1861,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
 }
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
-    const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-
-    if (col >= ncols) {
+static __global__ void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
         return;
     }
 
-    const int r = y[row];
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    // copy x[r*ncols + col] to dst[row*ncols + col]
-    const int xi = r*ncols + col;
-    const int di = row*ncols + col;
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    const int ib = xi/qk; // block index
-    const int iqs = (xi%qk)/qr; // quant index
-    const int iybs = di - di%qk; // y block start index
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
-    dequantize_kernel(x, ib, iqs, v);
+    dequantize_kernel(src0_row, ib, iqs, v);
+
+    dst_row[iybs + iqs + 0]        = v.x;
+    dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+            const src0_t * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-    dst[iybs + iqs + 0]        = v.x;
-    dst[iybs + iqs + y_offset] = v.y;
+    dst_row[i00] = src0_row[i00];
 }
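Both get-rows kernels are now batched and stride-aware: the z grid dimension is unraveled into (i11, i12), the row index i01 is read from src1 through its own element strides s10..s12, and each output row lives at dst + i10*s1 + i11*s2 + i12*s3. The quantized variant still dequantizes two values per thread, hence the *2 on i00 and the ne00 % 2 == 0 assert in the launcher below. A CPU sketch for the float case, simplified by assuming all strides are given in elements:

    #include <stddef.h>
    #include <stdint.h>

    // CPU reference for get_rows on f32 (strides in elements, illustrative)
    void get_rows_f32_ref(const float * src0, const int32_t * src1, float * dst,
                          int ne00, int ne10, int ne11, int ne12,
                          size_t s1, size_t s2, size_t s3,        // dst strides
                          size_t nb01, size_t nb02, size_t nb03,  // src0 strides
                          size_t s10, size_t s11, size_t s12) {   // src1 strides
        for (int i12 = 0; i12 < ne12; i12++)
        for (int i11 = 0; i11 < ne11; i11++)
        for (int i10 = 0; i10 < ne10; i10++) {
            const int32_t i01 = src1[i10*s10 + i11*s11 + i12*s12]; // row to gather
            const float * src_row = src0 + i01*nb01 + i11*nb02 + i12*nb03;
            float       * dst_row = dst  + i10*s1   + i11*s2   + i12*s3;
            for (int i00 = 0; i00 < ne00; i00++) {
                dst_row[i00] = src_row[i00];
            }
        }
    }
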
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -5035,29 +5246,98 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
 
 static  __global__ void im2col_f32_f16(
         const float * x, half * dst,
-        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
-    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
-    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+    const int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= pelements) {
+        return;
+    }
+
+    const int ksize = OW * (KH > 1 ? KW : 1);
+    const int kx = i / ksize;
+    const int kd = kx * ksize;
+    const int ky = (i - kd) / OW;
+    const int ix = i % OW;
+
+    const int iiw = ix * s0 + kx * d0 - p0;
+    const int iih = blockIdx.y * s1 + ky * d1 - p1;
 
     const int offset_dst =
-        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
-        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+        (blockIdx.y * OW + ix) * CHW +
+        (blockIdx.z * (KW * KH) + ky * KW + kx);
 
     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
         dst[offset_dst] = __float2half(0.0f);
     } else {
-        const int offset_src =  threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        const int offset_src = blockIdx.z * offset_delta;
         dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
     }
 }
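The rewritten im2col kernel replaces the old thread-geometry encoding (ofs0/ofs1 plus a 3-D block) with a flat parallel index over pelements = OW*KW*KH, decoded into an output column ix and a kernel tap (kx, ky); blockIdx.y selects the output row and blockIdx.z the input channel, whose start in src is blockIdx.z * offset_delta. Decode example:

    // OW = 4, KW = KH = 3  ->  ksize = OW*KW = 12, pelements = 36
    // i = 17:  kx = 17/12 = 1,  ky = (17 - 1*12)/4 = 1,  ix = 17 % 4 = 1
    // thread 17 fills output column 1 for kernel tap (kx, ky) = (1, 1)
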
 
 template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(block_num_x, nrows, 1);
-    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    (void) dst;
+}
+
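The host wrappers convert ggml's byte strides (nb*) to the element strides the kernels index with by dividing by ggml_element_size; this is what permits the relaxed asserts in ggml_cuda_op_get_rows below, where only nb[0] must equal the type size and full contiguity is no longer required. For example:

    // contiguous f32 tensor with ne = {4, 3, 2, 1}:
    //   nb = {4, 16, 48, 96} bytes  ->  s = {1, 4, 12, 24} elements
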
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    (void) dst;
 }
 
 template<float (*bin_op)(const float, const float)>
@@ -5069,7 +5349,6 @@ struct bin_bcast_cuda {
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-
         int nr0 = ne10/ne0;
         int nr1 = ne11/ne1;
         int nr2 = ne12/ne2;
@@ -5117,26 +5396,28 @@ struct bin_bcast_cuda {
             int64_t ne12 = cne1[2];
             int64_t ne13 = cne1[3];
 
-            //size_t nb0 = cnb0[0];
+            size_t nb0 = cnb0[0];
             size_t nb1 = cnb0[1];
             size_t nb2 = cnb0[2];
             size_t nb3 = cnb0[3];
 
-            //size_t nb10 = cnb1[0];
+            size_t nb10 = cnb1[0];
             size_t nb11 = cnb1[1];
             size_t nb12 = cnb1[2];
             size_t nb13 = cnb1[3];
 
-            //size_t s0 = nb0 / sizeof(src1_t);
-            size_t s1 = nb1 / sizeof(src1_t);
-            size_t s2 = nb2 / sizeof(src1_t);
-            size_t s3 = nb3 / sizeof(src1_t);
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
 
-            //size_t s10 = nb10 / sizeof(src1_t);
+            size_t s10 = nb10 / sizeof(src1_t);
             size_t s11 = nb11 / sizeof(src1_t);
             size_t s12 = nb12 / sizeof(src1_t);
             size_t s13 = nb13 / sizeof(src1_t);
 
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
 
             const int block_size = 128;
 
@@ -5174,6 +5455,13 @@ struct bin_bcast_cuda {
     }
 };
 
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
+    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -5184,11 +5472,26 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
+    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
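All of these element-wise launchers use the same configuration: a 1-D grid of fixed 256-thread blocks sized by ceiling division, with the i >= k guard in the kernel absorbing the last block's overshoot.

    // num_blocks = (k + B - 1) / B;  k = 1000, B = 256:
    //   (1000 + 255)/256 = 4 blocks = 1024 threads; threads 1000..1023 return early
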
 static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+}
+
 static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -5205,6 +5508,38 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     }
 }
 
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
+    static const float eps = 1e-6f;
+    if (group_size < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    }
+}
+
+static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2);
+    concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+}
+
+static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
+    int ne0 = (ne00 * scale_factor);
+    int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
+    upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
+}
+
+static void pad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02,
+    const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
+}
+
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -6167,13 +6502,14 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
-static void im2col_f32_f16_cuda(const float * x, half * dst,
-    int OH, int IW, int IH, int OW, int IC,
-    int KH, int KW, int N,  int ofs0, int ofs1,
-    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
-    dim3 block_nums(IC, OH, OW);
-    dim3 block_dims(N,  KH, KW);
-    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+static void im2col_f32_f16_cuda(const float* x, half* dst,
+    int IW, int IH, int OW, int OH, int KW, int KH, int IC,
+    int offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+    const int parallel_elements = OW * KW * KH;
+    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OH, IC);
+    im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }
 
 // buffer pool for cuda
@@ -6447,36 +6783,34 @@ static void ggml_cuda_op_get_rows(
 
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
-    const int ncols = src0->ne[0];
-    const int nrows = ggml_nelements(src1);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
 
     const int32_t * src1_i32 = (const int32_t *) src1_d;
 
     switch (src0->type) {
         case GGML_TYPE_F16:
-            get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         default:
             // TODO: k-quants
@@ -6522,6 +6856,25 @@ inline void ggml_cuda_op_add(
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
+inline void ggml_cuda_op_acc(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // offset in bytes
+
+    acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
+
+    (void) dst;
+}
+
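GGML_OP_ACC stores the view's byte strides and byte offset in dst->op_params; since everything here is f32, dividing by 4 yields the element strides and element offset that acc_f32_cuda expects (so after the division, offset is in elements, not bytes). A hedged sketch of how such a node is built, assuming the usual ggml_acc signature:

    // accumulate b into a view of a starting at element (1,1), arguments in bytes:
    // struct ggml_tensor * r = ggml_acc(ctx, a, b,
    //         a->nb[1], a->nb[2], a->nb[3],
    //         1*a->nb[1] + 1*a->nb[0]);
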
 inline void ggml_cuda_op_mul(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6564,6 +6917,34 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_gelu_quick(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_tanh(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_relu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6578,6 +6959,23 @@ inline void ggml_cuda_op_relu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_leaky_relu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_sqr(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6612,6 +7010,71 @@ inline void ggml_cuda_op_norm(
     (void) src1_dd;
 }
 
+
+inline void ggml_cuda_op_group_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int num_groups = dst->op_params[0];
+    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+    group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_concat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+        concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
+    }
+
+    (void) src1;
+    (void) dst;
+}
+
+inline void ggml_cuda_op_upscale(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    const int scale_factor = dst->op_params[0];
+
+    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+
+    (void) src1;
+    (void) dst;
+}
+
+inline void ggml_cuda_op_pad(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    pad_f32_cuda(src0_dd, dst_dd,
+        src0->ne[0], src0->ne[1], src0->ne[2],
+        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+
+    (void) src1;
+    (void) dst;
+}
+
 inline void ggml_cuda_op_rms_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7126,7 +7589,6 @@ inline void ggml_cuda_op_im2col(
 
     const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
 
-    const int64_t N  = src1->ne[is_2D ? 3 : 2];
     const int64_t IC = src1->ne[is_2D ? 2 : 1];
     const int64_t IH = is_2D ? src1->ne[1] : 1;
     const int64_t IW =         src1->ne[0];
@@ -7137,17 +7599,15 @@ inline void ggml_cuda_op_im2col(
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW =         dst->ne[1];
 
-    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
-    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
 
-    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
-        OH, IW, IH, OW, IC, KH, KW, N,
-        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
 
     (void) src0;
     (void) src0_dd;
 }
 
+
 inline void ggml_cuda_op_sum_rows(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7696,6 +8156,10 @@ static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }
 
+static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
+}
+
 static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
@@ -7712,10 +8176,22 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
+static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
+}
+
+static void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh);
+}
+
 static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }
 
+static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
+}
+
 static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
 }
@@ -7724,6 +8200,22 @@ static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
 
+static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm);
+}
+
+static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat);
+}
+
+static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale);
+}
+
+static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
+}
+
 static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
@@ -8234,36 +8726,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
 }
 #endif
 
-static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #if 0
-//#ifdef CUDA_USE_TENSOR_CORES
-//    const bool use_tensor_cores = true;
-//#else
-//    const bool use_tensor_cores = false;
-//#endif
-
     ggml_cuda_mul_mat_id_cublas(dst);
-
     // TODO: mmq/mmv support
-#else
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const int id = dst->op_params[0];
+#endif
 
-    int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
 
-    int32_t a_id;
-    CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    const struct ggml_tensor * ids = src0;
+    const int32_t id = ((int32_t *) dst->op_params)[0];
+    const int32_t n_as = ((int32_t *) dst->op_params)[1];
 
-    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
-    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+    std::vector<char> ids_host(ggml_nbytes(ids));
 
-    ggml_cuda_mul_mat(src0, src1, dst);
-#endif
+    if (ids->backend == GGML_BACKEND_GPU) {
+        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    } else {
+        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+    }
+
+    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
+    ggml_tensor_extra_gpu src1_row_extra;
+    ggml_tensor_extra_gpu dst_row_extra;
+
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row = *dst;
+
+    src1_row.ne[1] = 1;
+    dst_row.ne[1] = 1;
+
+    src1_row.nb[2] = src1_row.nb[1];
+    dst_row.nb[2] = dst_row.nb[1];
+
+    src1_row.nb[3] = src1_row.nb[1];
+    dst_row.nb[3] = dst_row.nb[1];
+
+    src1_row.extra = &src1_row_extra;
+    dst_row.extra = &dst_row_extra;
+
+
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        //int32_t row_id;
+        //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+        const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+        GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-    (void) _src0;
-    (void) _src1;
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        src1_row.data = (char *) src1->data + i01*src1->nb[1];
+
+        dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+        dst_row.data = (char *) dst->data + i01*dst->nb[1];
+
+        ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+    }
 }
 
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8683,6 +9208,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ADD:
             func = ggml_cuda_add;
             break;
+        case GGML_OP_ACC:
+            func = ggml_cuda_acc;
+            break;
         case GGML_OP_MUL:
             func = ggml_cuda_mul;
             break;
@@ -8697,6 +9225,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_SILU:
                     func = ggml_cuda_silu;
                     break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    func = ggml_cuda_gelu_quick;
+                    break;
+                case GGML_UNARY_OP_TANH:
+                    func = ggml_cuda_tanh;
+                    break;
                 case GGML_UNARY_OP_RELU:
                     func = ggml_cuda_relu;
                     break;
@@ -8707,6 +9241,21 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_NORM:
             func = ggml_cuda_norm;
             break;
+        case GGML_OP_GROUP_NORM:
+            func = ggml_cuda_group_norm;
+            break;
+        case GGML_OP_CONCAT:
+            func = ggml_cuda_concat;
+            break;
+        case GGML_OP_UPSCALE:
+            func = ggml_cuda_upscale;
+            break;
+        case GGML_OP_PAD:
+            func = ggml_cuda_pad;
+            break;
+        case GGML_OP_LEAKY_RELU:
+            func = ggml_cuda_leaky_relu;
+            break;
         case GGML_OP_RMS_NORM:
             func = ggml_cuda_rms_norm;
             break;
@@ -8729,9 +9278,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             func = ggml_cuda_sqr;
             break;
         case GGML_OP_CLAMP:
-            if (!any_on_device) {
-                return false;
-            }
             func = ggml_cuda_clamp;
             break;
         case GGML_OP_CPY:
@@ -8740,6 +9286,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_CONT:
             func = ggml_cuda_dup;
             break;
+        case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
@@ -9159,6 +9706,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_TANH:
                     return true;
                 default:
                     return false;
@@ -9181,6 +9730,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                 }
                 return true;
             } break;
+        case GGML_OP_GET_ROWS:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                return false;
+            } break;
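GGML_OP_CPY support is now reported per type pair instead of unconditionally, matching the cpy kernels that actually exist in this file:

    //  src0 \ dst :  F32  F16  Q8_0  Q4_0  Q4_1
    //  F32        :  yes  yes  yes   yes   yes
    //  F16        :  no   yes  no    no    no
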
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -9188,7 +9776,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_REPEAT:
-        case GGML_OP_GET_ROWS:
         case GGML_OP_DUP:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
@@ -9197,7 +9784,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
-        case GGML_OP_CPY:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
@@ -9206,6 +9792,12 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_IM2COL:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_LEAKY_RELU:
             return true;
         default:
             return false;
@@ -9264,7 +9856,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
     UNUSED(params);
 }
 
-extern "C" int ggml_backend_cuda_reg_devices() {
+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
     int device_count = ggml_cuda_get_device_count();
     //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
     for (int i = 0; i < device_count; i++) {
index d7146ea61cfde65f450fc4d0c578efdecd29549c..465679a6bb0c8702f4f46a1ccf81161cad38b626 100644 (file)
@@ -66,9 +66,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(div_row);
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale_4);
-    GGML_METAL_DECL_KERNEL(silu);
+    GGML_METAL_DECL_KERNEL(tanh);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
+    GGML_METAL_DECL_KERNEL(gelu_quick);
+    GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_K);
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(group_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -102,6 +105,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -130,8 +148,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(upscale_f32);
+    GGML_METAL_DECL_KERNEL(pad_f32);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DECL_KERNEL(leaky_relu_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -140,6 +161,7 @@ struct ggml_metal_context {
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f32);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
     GGML_METAL_DECL_KERNEL(sum_rows);
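Each new Metal kernel in this commit (group_norm, upscale_f32, pad_f32, leaky_relu_f32, cpy_f16_f32 and the mul_mv_id family) is threaded through the same three lists: GGML_METAL_DECL_KERNEL in this context struct, GGML_METAL_ADD_KERNEL in ggml_metal_init, and GGML_METAL_DEL_KERNEL in ggml_metal_free. A hypothetical sketch of the token-pasting pattern (not the exact macros, which live in ggml-metal.m):

    // #define GGML_METAL_DECL_KERNEL(name) ... pipeline_##name ...       // struct members
    // #define GGML_METAL_ADD_KERNEL(name)  ... build pipeline_##name ... // init
    // #define GGML_METAL_DEL_KERNEL(name)  ... release pipeline_##name ...
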
@@ -318,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(div_row);
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(scale_4);
-        GGML_METAL_ADD_KERNEL(silu);
+        GGML_METAL_ADD_KERNEL(tanh);
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
+        GGML_METAL_ADD_KERNEL(gelu_quick);
+        GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -338,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q5_K);
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(group_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -354,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
         if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
             GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -384,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(im2col_f16);
+        GGML_METAL_ADD_KERNEL(upscale_f32);
+        GGML_METAL_ADD_KERNEL(pad_f32);
         GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
         GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+        GGML_METAL_ADD_KERNEL(leaky_relu_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -394,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
         //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f32);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
         GGML_METAL_ADD_KERNEL(sum_rows);
@@ -418,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(div_row);
     GGML_METAL_DEL_KERNEL(scale);
     GGML_METAL_DEL_KERNEL(scale_4);
-    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(tanh);
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(gelu_quick);
+    GGML_METAL_DEL_KERNEL(silu);
     GGML_METAL_DEL_KERNEL(soft_max);
     GGML_METAL_DEL_KERNEL(soft_max_4);
     GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -438,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q5_K);
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(group_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -454,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
     if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
         GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -484,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(upscale_f32);
+    GGML_METAL_DEL_KERNEL(pad_f32);
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DEL_KERNEL(leaky_relu_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -494,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f32);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
     GGML_METAL_DEL_KERNEL(sum_rows);
@@ -795,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
                     return true;
                 default:
                     return false;
@@ -809,6 +877,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_SCALE:
@@ -816,21 +885,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_GROUP_NORM:
         case GGML_OP_NORM:
         case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
+        case GGML_OP_LEAKY_RELU:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return true;
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                        switch (op->type) {
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_F32:
+                           case GGML_TYPE_Q8_0:
+                           case GGML_TYPE_Q4_0:
+                           case GGML_TYPE_Q4_1:
+                                return true;
+                           default:
+                                return false;
+                        }
+                    case GGML_TYPE_F16:
+                        switch (op->type) {
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_F32:
+                                return true;
+                           default:
+                                return false;
+                        }
+                    default:
+                        return false;
+                };
+            }
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->ne[0] % 4 == 0;
+                return op->ne[3] == 1;
             }
         default:
             return false;
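
For reference, the CPY/DUP/CONT support check added above is just a small (src type, dst type) table. A minimal C sketch of the same logic, assuming the ggml type enum from ggml.h:

    #include <stdbool.h>
    #include "ggml.h" // for enum ggml_type

    // Mirrors the Metal support matrix above: F32 sources can be copied to
    // F16/F32 or quantized to Q8_0/Q4_0/Q4_1; F16 sources only to F16/F32.
    static bool metal_supports_cpy(enum ggml_type src, enum ggml_type dst) {
        switch (src) {
            case GGML_TYPE_F32:
                return dst == GGML_TYPE_F16  || dst == GGML_TYPE_F32 ||
                       dst == GGML_TYPE_Q8_0 || dst == GGML_TYPE_Q4_0 ||
                       dst == GGML_TYPE_Q4_1;
            case GGML_TYPE_F16:
                return dst == GGML_TYPE_F16 || dst == GGML_TYPE_F32;
            default:
                return false;
        }
    }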
@@ -906,7 +1004,10 @@ void ggml_metal_graph_compute(
                         } break;
                 }
 
-                GGML_ASSERT(ggml_metal_supports_op(dst));
+                if (!ggml_metal_supports_op(dst)) {
+                    GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+                    GGML_ASSERT(!"unsupported op");
+                }
 
                 const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                 const int64_t  ne01 = src0 ? src0->ne[1] : 0;
@@ -1003,34 +1104,39 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                     case GGML_OP_DIV:
                         {
-                            GGML_ASSERT(ggml_is_contiguous(src0));
-                            GGML_ASSERT(ggml_is_contiguous(src1));
+                            const size_t offs = 0;
 
                             bool bcast_row = false;
 
                             int64_t nb = ne00;
 
-                            if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                                GGML_ASSERT(ggml_is_contiguous(src0));
+
                                 // src1 is a row
                                 GGML_ASSERT(ne11 == 1);
 
                                 nb = ne00 / 4;
                                 switch (dst->op) {
-                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
-                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
-                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+                                    case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+                                    case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+                                    case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
                                     default: GGML_ASSERT(false);
                                 }
 
                                 bcast_row = true;
                             } else {
                                 switch (dst->op) {
-                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
-                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
-                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+                                    case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+                                    case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+                                    case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
                                     default: GGML_ASSERT(false);
                                 }
                             }
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1058,18 +1164,99 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
                             [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
                             [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                            [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
+                            [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                            [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
 
                             if (bcast_row) {
                                 const int64_t n = ggml_nelements(dst)/4;
 
                                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } else {
-                                const int nth = MIN(1024, ne0);
+                                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
                         } break;
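
The refactor above also tightens the fast path: the float4 *_row kernels are used only when src1 is one contiguous row and both ne00 and ne10 are multiples of 4, and the pipeline is now selected first so its maxTotalThreadsPerThreadgroup can size the dispatch. A scalar C sketch of the row-broadcast semantics (ne10 == ne00 on this path; names illustrative):

    #include <stdint.h>

    // Row-broadcast add: src1 holds a single row that is added to every row
    // of src0. The Metal kernel does the same with float4 loads, indexing
    // src1 as src1[tpig % nb] with nb = ne00/4.
    static void add_row_ref(const float * src0, const float * src1, float * dst,
                            int64_t n, int64_t ne10) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = src0[i] + src1[i % ne10];
        }
    }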
+                    case GGML_OP_ACC:
+                        {
+                            GGML_ASSERT(src0t == GGML_TYPE_F32);
+                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                            GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+                            GGML_ASSERT(ggml_is_contiguous(src1));
+
+                            const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+                            const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+                            const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+                            const size_t offs = ((int32_t *) dst->op_params)[3];
+
+                            const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                            if (!inplace) {
+                                // run a separate kernel to cpy src->dst
+                                // not sure how to avoid this
+                                // TODO: make a simpler cpy_bytes kernel
+
+                                const int nth = MIN(1024, ne00);
+
+                                [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+                                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                                [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                                [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                                [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                                [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                                [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                                [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                                [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                                [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                                [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                                [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                                [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                                [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                                [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                                [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                            [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+                            [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+                            [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                            [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+                            [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+                            [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+                            [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
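
GGML_OP_ACC accumulates src1 into a strided view of src0: dst first becomes a copy of src0 (hence the extra cpy_f32_f32 pass when not in-place), then kernel_add runs with the view strides pnb1/pnb2/pnb3 and byte offset offs taken from op_params. A single-threaded C sketch of the same semantics (3D case, illustrative parameter names):

    #include <stdint.h>
    #include <string.h>

    // acc: dst = src0, then the view of dst described by byte strides
    // nb1/nb2 and byte offset offs receives += src1 (src1 contiguous).
    static void acc_f32_ref(const float * src0, const float * src1, float * dst,
                            int64_t n,                                 // elements in src0/dst
                            int64_t ne10, int64_t ne11, int64_t ne12,  // src1 dims
                            size_t nb1, size_t nb2, size_t offs) {
        memcpy(dst, src0, n*sizeof(float));
        char * base = (char *) dst + offs;
        for (int64_t i2 = 0; i2 < ne12; ++i2) {
            for (int64_t i1 = 0; i1 < ne11; ++i1) {
                float * row = (float *)(base + i2*nb2 + i1*nb1);
                for (int64_t i0 = 0; i0 < ne10; ++i0) {
                    row[i0] += src1[(i2*ne11 + i1)*ne10 + i0];
                }
            }
        }
    }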
                     case GGML_OP_SCALE:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1093,16 +1280,15 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_UNARY:
                         switch (ggml_get_unary_op(gf->nodes[i])) {
-                            case GGML_UNARY_OP_SILU:
+                            case GGML_UNARY_OP_TANH:
                                 {
-                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setComputePipelineState:ctx->pipeline_tanh];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
                                     const int64_t n = ggml_nelements(dst);
-                                    GGML_ASSERT(n % 4 == 0);
 
-                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             case GGML_UNARY_OP_RELU:
                                 {
@@ -1123,6 +1309,28 @@ void ggml_metal_graph_compute(
                                     const int64_t n = ggml_nelements(dst);
                                     GGML_ASSERT(n % 4 == 0);
 
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_GELU_QUICK:
+                                {
+                                    [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
+
                                     [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             default:
@@ -1197,6 +1405,8 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                             if (id_src1) {
                                 [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                            } else {
+                                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
                             }
                             [encoder setBuffer:id_dst  offset:offs_dst    atIndex:2];
                             [encoder setBytes:&ne00  length:sizeof(ne00)  atIndex:3];
@@ -1448,7 +1658,7 @@ void ggml_metal_graph_compute(
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
-                                    int64_t ny = (ne11 + nrows - 1)/nrows;
+                                    const int64_t ny = (ne11 + nrows - 1)/nrows;
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
@@ -1460,7 +1670,7 @@ void ggml_metal_graph_compute(
 
                             GGML_ASSERT(src0t == GGML_TYPE_I32);
 
-                            const int n_as = ne00;
+                            const int n_as = ((int32_t *) dst->op_params)[1];
 
                             // TODO: make this more general
                             GGML_ASSERT(n_as <= 8);
@@ -1492,14 +1702,22 @@ void ggml_metal_graph_compute(
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
-                            int ne11_mm_min = 0;
+                            int ne11_mm_min = 1;
 
                             const int idx = ((int32_t *) dst->op_params)[0];
 
+                            // batch size
+                            GGML_ASSERT(ne01 == ne11);
+
+                            const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                                ne11 > ne11_mm_min) {
+                            // !!!
+                            // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+                            //       indirect matrix multiplication
+                            // !!!
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
                                 switch (src2->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
@@ -1518,19 +1736,22 @@ void ggml_metal_graph_compute(
                                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:3];
-                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:4];
-                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:5];
-                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:6];
-                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
-                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
-                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
-                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
-                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
-                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
-                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
-                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:15];
+                                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:3];
+                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
+                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:5];
+                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
+                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:7];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
+                                [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:9];
+                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:10];
+                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:11];
+                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:12];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
+                                [encoder setBytes:&_ne1    length:sizeof(_ne1) atIndex:14];
+                                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:15];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:16];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17];
+                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:18];
                                 // TODO: how to make this an array? read Metal docs
                                 for (int j = 0; j < n_as; ++j) {
                                     struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1538,11 +1759,157 @@ void ggml_metal_graph_compute(
                                     size_t offs_src_cur = 0;
                                     id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
 
-                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                                 }
 
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+                                // TODO: processing one row at a time (ne11 -> 1) is not efficient
+                                [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            } else {
+                                int nth0 = 32;
+                                int nth1 = 1;
+                                int nrows = 1;
+                                //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                                // use custom matrix x vector kernel
+                                switch (src2t) {
+                                    case GGML_TYPE_F32:
+                                        {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                                        } break;
+                                    case GGML_TYPE_F16:
+                                        {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                            nth0 = 32;
+                                            nth1 = 1;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_1:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_1:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q8_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q2_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q3_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_K:
+                                        {
+                                            nth0 = 4; //1;
+                                            nth1 = 8; //32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q6_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                                        } break;
+                                    default:
+                                        {
+                                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
+                                            GGML_ASSERT(false && "not implemented");
+                                        }
+                                };
+
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                                [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                                [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                                [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+                                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+                                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+                                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+                                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+                                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
+                                [encoder setBytes:&r2   length:sizeof(r2)   atIndex:20];
+                                [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
+                                [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
+                                // TODO: how to make this an array? read Metal docs
+                                for (int j = 0; j < n_as; ++j) {
+                                    struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                                    size_t offs_src_cur = 0;
+                                    id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+                                }
+
+                                if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+                                    src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                                    src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                                }
+                                else if (src2t == GGML_TYPE_Q5_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q6_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                } else {
+                                    const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
                             }
                         } break;
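
The new mul_mv_id path implements indirect matrix-vector multiplication: an I32 ids tensor selects, per row, which of the n_as expert matrices (bound at buffer indices 19+/23+) to multiply with. A C sketch of the core idea, ignoring the idx column selection from op_params (names illustrative):

    #include <stdint.h>

    // Indirect mat-vec: row r of src1 is multiplied by the expert matrix
    // chosen by ids[r]. Each matrix is [ne01 x ne00], row-major.
    static void mul_mat_id_ref(const float * const * as, const int32_t * ids,
                               const float * src1, float * dst,
                               int64_t nrows, int64_t ne00, int64_t ne01) {
        for (int64_t r = 0; r < nrows; ++r) {
            const float * a = as[ids[r]];
            for (int64_t o = 0; o < ne01; ++o) {
                float sum = 0.0f;
                for (int64_t k = 0; k < ne00; ++k) {
                    sum += a[o*ne00 + k]*src1[r*ne00 + k];
                }
                dst[r*ne01 + o] = sum;
            }
        }
    }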
                     case GGML_OP_GET_ROWS:
@@ -1563,16 +1930,19 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:5];
-
-                            const int64_t n = ggml_nelements(src1);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+                            [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+                            [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+                            [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                         } break;
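
GGML_OP_GET_ROWS gathers rows of src0 by the integer indices in src1; the dispatch change above moves from one thread per index to a 32-thread threadgroup per (index, batch) pair. For F32 rows the operation reduces to the sketch below (quantized variants additionally dequantize while copying):

    #include <stdint.h>
    #include <string.h>

    // get_rows: dst row i = src0 row src1[i], for ne10 indices.
    static void get_rows_f32_ref(const float * src0, const int32_t * src1,
                                 float * dst, int64_t ne10, int64_t ne00) {
        for (int64_t i = 0; i < ne10; ++i) {
            memcpy(dst + i*ne00, src0 + (int64_t) src1[i]*ne00, ne00*sizeof(float));
        }
    }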
                     case GGML_OP_RMS_NORM:
                         {
@@ -1599,6 +1969,38 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
+                    case GGML_OP_GROUP_NORM:
+                        {
+                            GGML_ASSERT(ne00 % 4 == 0);
+
+                            //float eps;
+                            //memcpy(&eps, dst->op_params, sizeof(float));
+
+                            const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+                            const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+                            int nth = 32; // SIMD width
+
+                            //while (nth < ne00/4 && nth < 1024) {
+                            //    nth *= 2;
+                            //}
+
+                            [encoder setComputePipelineState:ctx->pipeline_group_norm];
+                            [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
+                            [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+                            [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_NORM:
                         {
                             float eps;
@@ -1768,6 +2170,65 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                         } break;
+                    case GGML_OP_UPSCALE:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            const int sf = dst->op_params[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                            [encoder setBytes:&sf   length:sizeof(sf)   atIndex:18];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
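
GGML_OP_UPSCALE here is nearest-neighbor upscaling by the integer factor sf from op_params, scaling the first two dimensions (ne0 = ne00*sf, ne1 = ne01*sf). A CPU sketch under that assumption:

    #include <stdint.h>

    // Nearest-neighbor upscale: every output element reads the source at
    // the downscaled coordinate (i0/sf, i1/sf).
    static void upscale_f32_ref(const float * src, float * dst,
                                int64_t ne00, int64_t ne01, int64_t ne02, int sf) {
        const int64_t ne0 = ne00*sf;
        const int64_t ne1 = ne01*sf;
        for (int64_t i2 = 0; i2 < ne02; ++i2) {
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    dst[(i2*ne1 + i1)*ne0 + i0] = src[(i2*ne01 + i1/sf)*ne00 + i0/sf];
                }
            }
        }
    }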
+                    case GGML_OP_PAD:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
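
GGML_OP_PAD grows the tensor to the dst shape and zero-fills the new region; in-bounds elements are copied through. The 2D case, sketched in C:

    #include <stdint.h>

    // Zero padding: copy src where the output index is in range, 0 elsewhere.
    static void pad_f32_ref(const float * src, float * dst,
                            int64_t ne00, int64_t ne01,  // src dims
                            int64_t ne0,  int64_t ne1) { // dst dims (>= src dims)
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                dst[i1*ne0 + i0] = (i0 < ne00 && i1 < ne01) ? src[i1*ne00 + i0] : 0.0f;
            }
        }
    }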
                     case GGML_OP_ARGSORT:
                         {
                             GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -1789,6 +2250,22 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
                         } break;
+                    case GGML_OP_LEAKY_RELU:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            float slope;
+                            memcpy(&slope, dst->op_params, sizeof(float));
+
+                            [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
+                            [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
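
GGML_OP_LEAKY_RELU reads the negative slope from op_params and applies it elementwise, one thread per element:

    #include <stdint.h>

    // leaky_relu: f(x) = x for x > 0, slope*x otherwise.
    static void leaky_relu_f32_ref(const float * src, float * dst,
                                   int64_t n, float slope) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = src[i] > 0.0f ? src[i] : slope*src[i];
        }
    }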
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
@@ -1817,7 +2294,7 @@ void ggml_metal_graph_compute(
                                     {
                                         switch (dstt) {
                                             case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
index 2f8ea22d66226040f9b9b1de4b70466994f6ff94..fe0ada445a2d45c4f295021565a8e9247c9b2dfc 100644 (file)
@@ -79,6 +79,7 @@ kernel void kernel_add(
         constant  int64_t & nb1,
         constant  int64_t & nb2,
         constant  int64_t & nb3,
+        constant  int64_t & offs,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         const int i10 = i0 % ne10;
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb [[buffer(27)]],
+        constant    int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb  [[buffer(27)]],
+        constant    int64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb  [[buffer(27)]],
+        constant    int64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }
 
-kernel void kernel_silu(
-        device const float4 * src0,
-        device       float4 * dst,
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
         uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
+    dst[tpig] = max(0.0f, src0[tpig]);
 }
 
-kernel void kernel_relu(
+kernel void kernel_tanh(
         device const float * src0,
         device       float * dst,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
 }
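
The reordered block above groups the unary kernels: kernel_tanh stays scalar (matching the per-element dispatch on the host side), while gelu, gelu_quick and silu operate on float4, which is why the host asserts n % 4 == 0 and dispatches n/4 threads. Scalar C references for the three float4 activations:

    #include <math.h>

    static float gelu_ref(float x) {
        const float GELU_COEF_A    = 0.044715f;
        const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }

    static float gelu_quick_ref(float x) {
        return x*(1.0f/(1.0f + expf(-1.702f*x))); // x * sigmoid(1.702*x)
    }

    static float silu_ref(float x) {
        return x/(1.0f + expf(-x));
    }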
 
 kernel void kernel_sqr(
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
-constant float GELU_COEF_A    = 0.044715f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
 kernel void kernel_soft_max(
         device const float * src0,
         device const float * src1,
@@ -347,9 +366,9 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 ? src1                                      + i01*ne00 : nullptr;
-    device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 =         src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 != src0 ? src1                               + i01*ne00 : nullptr;
+    device       float * pdst  =         dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
     float lmax = -INFINITY;
@@ -385,7 +404,12 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -428,9 +452,9 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 =        (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
-    device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 =                (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
+    device       float4 * pdst4 =                (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
     float4 lmax4 = -INFINITY;
@@ -468,7 +492,13 @@ kernel void kernel_soft_max_4(
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -639,6 +669,94 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_group_norm(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int32_t & n_groups,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end   = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
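
kernel_group_norm partitions the ne00*ne01*ne02 elements into n_groups contiguous groups of gs = ne00*ne01*ceil(ne02/n_groups), computes mean and variance per group with two passes of simdgroup reductions, then normalizes. A single-threaded C reference of the same math (assuming ne02 % n_groups == 0, so dividing by gs as the kernel does equals dividing by the actual element count):

    #include <math.h>
    #include <stdint.h>

    static void group_norm_ref(const float * src, float * dst,
                               int64_t ne00, int64_t ne01, int64_t ne02,
                               int32_t n_groups, float eps) {
        const int64_t gs = ne00*ne01*(ne02/n_groups);
        for (int32_t g = 0; g < n_groups; ++g) {
            const float * x = src + g*gs;
            float       * y = dst + g*gs;
            float mean = 0.0f;
            for (int64_t j = 0; j < gs; ++j) mean += x[j];
            mean /= gs;
            float var = 0.0f;
            for (int64_t j = 0; j < gs; ++j) {
                const float xi = x[j] - mean;
                y[j] = xi;                  // store centered value, like the kernel
                var += xi*xi;
            }
            var /= gs;
            const float scale = 1.0f/sqrtf(var + eps);
            for (int64_t j = 0; j < gs; ++j) y[j] *= scale;
        }
    }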
+
// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied by the appropriate scale factor
@@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
//      guard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(
+void mul_vec_q_n_f32_impl(
         device const void  * src0,
         device const float * src1,
         device       float * dst,
@@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mv_q8_0_f32(
+void kernel_mul_mv_q8_0_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int nr  = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw  = N_SIMDWIDTH;
@@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
 #define N_F32_F32 4
 
-kernel void kernel_mul_mv_f32_f32(
+void kernel_mul_mv_f32_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 #define N_F16_F16 4
 
 kernel void kernel_mul_mv_f16_f16(
@@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
     }
 }
 
-kernel void kernel_mul_mv_f16_f32_1row(
+void kernel_mul_mv_f16_f32_1row_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         }
     }
+}
 
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 }
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mv_f16_f32(
+void kernel_mul_mv_f16_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 // Assumes row size (ne00) is a multiple of 4
 kernel void kernel_mul_mv_f16_f32_l4(
         device const  char * src0,
@@ -1548,25 +1763,116 @@ kernel void kernel_im2col_f16(
     }
 }
 
-// bitonic sort implementation following the CUDA kernels as reference
-typedef void (argsort_t)(
-        device const float * x,
-        device     int32_t * dst,
-        constant   int64_t & ncols,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]);
-
-template<ggml_sort_order order>
-kernel void kernel_argsort_f32_i32(
-        device const float   * x,
-        device       int32_t * dst,
-        constant     int64_t & ncols,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]]) {
-    // bitonic sort
-    int col = tpitg[0];
-    int row = tgpig[1];
-
+kernel void kernel_upscale_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant   int32_t & sf,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
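+    // nearest-neighbor upscale by integer factor sf: dst element (i0, i1) reads src element (i0/sf, i1/sf)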
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
+
+kernel void kernel_pad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
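+    // zero-padding: rows inside the src extent copy ne00 values and zero-fill the rest;
+    // rows outside the extent are zeroed entirely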
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as a reference
+typedef void (argsort_t)(
+        device const float * x,
+        device     int32_t * dst,
+        constant   int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
     if (col >= ncols) return;
 
     device const float   * x_row   = x   + row * ncols;
@@ -1600,9 +1906,17 @@ kernel void kernel_argsort_f32_i32(
 template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
 
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
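+    // y = x for x > 0, otherwise slope*x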
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
 kernel void kernel_cpy_f16_f16(
-        device const half * src0,
-        device       half * dst,
+        device  const half * src0,
+        device        half * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -1641,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
     }
 }
 
+kernel void kernel_cpy_f16_f32(
+        device  const half * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
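+    // f16 -> f32 copy: the flat source index n is unraveled into dst coordinates,
+    // so non-contiguous destinations are handled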
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device        half * dst,
@@ -1917,9 +2272,9 @@ kernel void kernel_cpy_f32_q4_1(
 }
 
 kernel void kernel_concat(
-    device const char * src0,
-    device const char * src1,
-    device       char * dst,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
@@ -1956,7 +2311,7 @@ kernel void kernel_concat(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
     device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
 
@@ -2064,19 +2419,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mv_q2_K_f32(
+void kernel_mul_mv_q2_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2214,8 +2569,8 @@ kernel void kernel_mul_mv_q2_K_f32(
     }
 }
 
-#if QK_K == 256
-kernel void kernel_mul_mv_q3_K_f32(
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2229,8 +2584,29 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant   uint    & r2  [[buffer(17)]],
         constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
@@ -2373,19 +2749,19 @@ kernel void kernel_mul_mv_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mv_q3_K_f32(
+void kernel_mul_mv_q3_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2450,20 +2826,41 @@ kernel void kernel_mul_mv_q3_K_f32(
 }
 #endif
 
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
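+    // dispatch to whichever q3_K impl was compiled (QK_K == 256 or QK_K == 64)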
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 #if QK_K == 256
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01 [[buffer(4)]],
-        constant   int64_t & ne02 [[buffer(5)]],
-        constant   int64_t & ne10 [[buffer(9)]],
-        constant   int64_t & ne12 [[buffer(11)]],
-        constant   int64_t & ne0  [[buffer(15)]],
-        constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2564,19 +2961,19 @@ kernel void kernel_mul_mv_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2660,7 +3057,8 @@ kernel void kernel_mul_mv_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mv_q5_K_f32(
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2677,6 +3075,26 @@ kernel void kernel_mul_mv_q5_K_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
@@ -2836,10 +3254,10 @@ kernel void kernel_mul_mv_q5_K_f32(
             dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         }
     }
-
 }
 
-kernel void kernel_mul_mv_q6_K_f32(
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2853,18 +3271,38 @@ kernel void kernel_mul_mv_q6_K_f32(
         constant   uint    & r2  [[buffer(17)]],
         constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    const uint8_t kmask1 = 0x03;
-    const uint8_t kmask2 = 0x0C;
-    const uint8_t kmask3 = 0x30;
-    const uint8_t kmask4 = 0xC0;
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    const int nb = ne00/QK_K;
+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
 
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
+void kernel_mul_mv_q6_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
     const int     im = tgpig.z;
 
     const int row = 2 * r0 + sgitg;
@@ -2945,6 +3383,27 @@ kernel void kernel_mul_mv_q6_K_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3062,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
-    const half d = xb->d;
-    const half min = xb->dmin;
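+    // scale/min math in f32 (was f16), presumably to avoid precision loss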
+    const float d = xb->d;
+    const float min = xb->dmin;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
-    half dl, ml;
+    float dl, ml;
     uint8_t sc = xb->scales[il];
 
 #if QK_K == 256
@@ -3135,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     q = q + (il/4) * 32 + 16 * (il&1);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const half d   = il < 2 ? xb->d : xb->d / 16.h;
-    const half min = xb->dmin;
-    const half dl = d * sc[0];
-    const half ml = min * sc[1];
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
@@ -3165,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     uint8_t ul = 1 << (il/2);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const half d = il < 2 ? xb->d : xb->d / 16.h;
-    const half min = xb->dmin;
-    const half dl = d * sc[0];
-    const half ml = min * sc[1];
+    const float d = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    const half qh_val = il<2 ? 16.h : 256.h;
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -3219,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const  void * src0,
-        device const   int * src1,
+        device const  char * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
         constant  uint64_t & nb1,
-        uint                 tgpig[[threadgroup_position_in_grid]],
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
-        uint                 tptg[[threads_per_threadgroup]]) {
-    const int i = tgpig;
-    const int r = ((device int32_t *) src1)[i];
+        uint3                tptg [[threads_per_threadgroup]]) {
+    //const int64_t i = tgpig;
+    //const int64_t r = ((device int32_t *) src1)[i];
+
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
 
-    for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
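+    // the row index now comes from a 2-D i32 tensor, allowing a separate index row per batch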
+
+    const int64_t i02 = i11;
+
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         float4x4 temp;
         dequantize_func(
-            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+kernel void kernel_get_rows_f32(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
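+    // f32 variant: copy the selected src0 row directly, no dequantization needed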
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
     }
 }
 
@@ -3426,19 +3953,22 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const int32_t * ids,
+        device const   uchar * ids,
         device const   uchar * src1,
-        device         float * dst,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant     int64_t & nb01,
         constant     int64_t & nb02,
         constant     int64_t & ne12,
+        constant     int64_t & ne13,
         constant     int64_t & nb10,
         constant     int64_t & nb11,
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
+        constant     int64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -3456,10 +3986,16 @@ kernel void kernel_mul_mm_id(
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
     device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
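+    // id selects which of the 8 src0 matrices this batch row multiplies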
+
     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[ids[idx]],
-        src1,
-        dst,
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
         ne00,
         ne02,
         nb01,
@@ -3484,17 +4020,26 @@ kernel void kernel_mul_mm_id(
 #define QK_NL 4
 #endif
 
+//
+// get rows
+//
+
 typedef void (get_rows_t)(
         device const void * src0,
-        device const  int * src1,
+        device const char * src1,
         device      float * dst,
         constant  int64_t & ne00,
         constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant  int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
         constant uint64_t & nb1,
-        uint, uint, uint);
+        constant uint64_t & nb2,
+        uint3, uint, uint3);
 
-template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
-template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
+//template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3506,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
+//
+// matrix-matrix multiplication
+//
+
 typedef void (mat_mm_t)(
         device const  uchar * src0,
         device const  uchar * src1,
@@ -3538,20 +4087,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
 
+//
+// indirect matrix-matrix multiplication
+//
+
 typedef void (mat_mm_id_t)(
-        device const int32_t * ids,
+        device const   uchar * ids,
         device const   uchar * src1,
-        device         float * dst,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant     int64_t & nb01,
         constant     int64_t & nb02,
         constant     int64_t & ne12,
+        constant     int64_t & ne13,
         constant     int64_t & nb10,
         constant     int64_t & nb11,
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
+        constant     int64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -3578,3 +4134,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
 template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
 template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
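+// each kernel_mul_mv_id_* reads which src0 matrix to use for its batch row from ids,
+// then forwards to the matching _impl with src1/dst offset by the batch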
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f32_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f16_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q8_0_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q2_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q3_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q4_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q5_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q6_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
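The five kernel_mul_mv_id_*_f32 variants above (plus the q2_K one whose tail opens this hunk) differ only in which *_impl they forward to; the routing prologue is identical. Because Metal buffers are bound one at a time, the up-to-eight expert matrices arrive as src00..src07 and are gathered into a local array, tgpig.z is split into a batch row bid plus the remaining grid coordinate, and the expert index is fetched from the ids tensor. A CPU-side C sketch of just that prologue, with all names illustrative rather than taken from ggml:

    #include <stdint.h>

    // Stand-in for Metal's uint3 threadgroup position.
    typedef struct { uint32_t x, y, z; } uint3_t;

    static const char * route_expert(
            const char * src0[8],  // expert weight buffers (src00..src07)
            const char * ids,      // i32 routing tensor
            int64_t      nbi1,     // byte stride between rows of ids
            int          idx,      // which entry of each ids row to read
            int64_t      ne12, int64_t ne13,
            uint3_t    * tgpig) {
        // split the flattened z coordinate into (batch row, grid z)
        const int64_t bid = tgpig->z / (ne12*ne13);
        tgpig->z = (uint32_t)(tgpig->z % (ne12*ne13));

        // pick the expert this batch row was routed to
        const int32_t id = ((const int32_t *)(ids + bid*nbi1))[idx];
        return src0[id]; // forwarded to the regular mul_mv *_impl
    }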
index 7285d5f7fbcc00ce41b5d481b702ecc52c5671f6..0e8163a16b39549671363ac859cad2a7e0aaeefa 100644 (file)
@@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
 
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
-    // These tempory registers are for masking and shift operations
+    // These temporary registers are for masking and shift operations
     vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
     vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
 
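For readers unfamiliar with the RVV intrinsics in this hunk: vid fills a vector with its own lane indices, and the following vsll shifts a splatted 1 by those indices, so the two "temporary registers" hold i and 1 << i per lane, ready for masking and shifting the packed q5_0 high bits. A scalar C rendering of the same idea, purely illustrative:

    // per lane i in [0, vl):
    //   vt_1[i] = i        (__riscv_vid_v_u32m2)
    //   vt_2[i] = 1u << i  (vsll of splat(1) by vt_1)
    for (size_t i = 0; i < vl; ++i) {
        const uint32_t bit = 1u << (uint32_t) i; // this lane's mask bit
        // ... later used to extract the i-th high bit of qh ...
        (void) bit;
    }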
@@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
             vl = 16;
 
-            // retreive lane to multiply with scale
+            // retrieve lane to multiply with scale
             vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
             vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
             vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
diff --git a/ggml.c b/ggml.c
index ca56f063c3a87440353e3efce428c18e003517fa..29e18a24c76a85f7babfcfe1652c32f0746aeadd 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1,4 +1,4 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC
 
 #include "ggml-impl.h"
@@ -33,7 +33,7 @@
 // we should just be careful :)
 #pragma warning(disable: 4244 4267)
 
-// disable POSIX deprecation warnigns
+// disable POSIX deprecation warnings
 // these functions are never going away, anyway
 #pragma warning(disable: 4996)
 #endif
@@ -1395,7 +1395,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
-inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
 
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
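The replacement ggml_vec_leaky_relu_f32 takes the negative slope as a parameter (the old ggml_vec_leaky_f32 hard-coded 0.1f) and writes the op in the branch-split form max(x,0) + ns*min(x,0), which is equivalent to the familiar x > 0 ? x : ns*x. A quick self-contained check of that equivalence:

    #include <assert.h>

    static float leaky_branch(float x, float ns) { return x > 0.f ? x : ns*x; }
    static float leaky_split (float x, float ns) {
        return ((x > 0.f) ? x : 0.f) + ns*((x < 0.f) ? x : 0.f);
    }

    int main(void) {
        const float xs[] = { -2.f, -0.5f, 0.f, 0.5f, 2.f };
        for (int i = 0; i < 5; ++i) {
            assert(leaky_branch(xs[i], 0.1f) == leaky_split(xs[i], 0.1f));
        }
        return 0;
    }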
@@ -1623,7 +1623,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
+    "PAD",
     "ARGSORT",
+    "LEAKY_RELU",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -1650,7 +1652,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1707,7 +1709,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
+    "pad(x)",
     "argsort(x)",
+    "leaky_relu(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -1734,7 +1738,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1750,17 +1754,16 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "GELU",
     "GELU_QUICK",
     "SILU",
-    "LEAKY",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
+static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
 // WARN:
-// Mis-confguration can lead to problem that's hard to reason about:
+// Mis-configuration can lead to problems that are hard to reason about:
 // * At best  it crashes or talks nonsense.
 // * At worst it talks slightly differently in ways that are hard to perceive.
 //
@@ -3830,12 +3833,25 @@ struct ggml_tensor * ggml_relu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }
 
-// ggml_leaky
+// ggml_leaky_relu
 
-struct ggml_tensor * ggml_leaky(
+struct ggml_tensor * ggml_leaky_relu(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
+        struct ggml_tensor  * a, float negative_slope, bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
+
+    result->op   = GGML_OP_LEAKY_RELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
 }
 
 // ggml_gelu
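ggml_leaky_relu is now a first-class op: the slope is stored in op_params (read back later by the compute function) and the entry point gains negative_slope and inplace arguments in place of the removed GGML_UNARY_OP_LEAKY path. A minimal usage sketch; the context size and the graph helpers are the ordinary ggml ones of this period, used here on the assumption they are available:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        // slope 0.1f reproduces what the removed GGML_UNARY_OP_LEAKY hard-coded
        struct ggml_tensor * y = ggml_leaky_relu(ctx, x, 0.1f, /*inplace=*/false);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        // ... fill x->data, then run gf with the backend of your choice ...

        ggml_free(ctx);
        return 0;
    }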
@@ -4022,8 +4038,9 @@ static struct ggml_tensor * ggml_group_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op = GGML_OP_GROUP_NORM;
     result->op_params[0] = n_groups;
+
+    result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = NULL; // TODO: maybe store epsilon here?
@@ -4075,17 +4092,18 @@ struct ggml_tensor * ggml_mul_mat(
 
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * as[],
+        struct ggml_tensor  * const as[],
+        int                   n_as,
         struct ggml_tensor  * ids,
         int                   id,
         struct ggml_tensor  * b) {
 
-    int64_t n_as = ids->ne[0];
-
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_is_vector(ids));
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
+    GGML_ASSERT(ids->ne[1] == b->ne[1]);
+    GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
     GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
-    GGML_ASSERT(id >= 0 && id < n_as);
+    GGML_ASSERT(id >= 0 && id < ids->ne[0]);
 
     bool is_node = false;
 
@@ -4097,13 +4115,14 @@ struct ggml_tensor * ggml_mul_mat_id(
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
 
     ggml_set_op_params_i32(result, 0, id);
+    ggml_set_op_params_i32(result, 1, n_as);
 
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = ids;
     result->src[1] = b;
 
-    for (int64_t i = 0; i < n_as; i++) {
+    for (int i = 0; i < n_as; i++) {
         struct ggml_tensor * a = as[i];
         GGML_ASSERT(ggml_are_same_shape(as[0], a));
         GGML_ASSERT(ggml_can_mul_mat(a, b));
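Taken together, the reworked asserts spell out the new shape contract: ids is a 2-D GGML_TYPE_I32 tensor with one row per column of b (one routing record per token), id picks an entry within each row, and the expert count n_as is passed explicitly instead of being inferred from ids->ne[0]. In schematic form (a reading of the asserts above, not additional API):

    //   as[0..n_as-1] : expert matrices, all of identical shape
    //   b             : activations, ne[1] = n_tokens
    //   ids           : i32, [n_choices, n_tokens] (ne[2] == ne[3] == 1)
    // for token t, entry `id` of ids row t names the expert applied to
    // column t of b; hence 0 <= id < ids->ne[0] and n_as <= GGML_MAX_SRC - 2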
@@ -4731,7 +4750,9 @@ struct ggml_tensor * ggml_get_rows(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
 
     bool is_node = false;
 
@@ -4741,7 +4762,7 @@ struct ggml_tensor * ggml_get_rows(
 
     // TODO: implement non F32 return
     //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
-    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
 
     result->op   = GGML_OP_GET_ROWS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
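ggml_get_rows drops the matrix-and-vector restriction: the index tensor b may now be up to 3-D, gathering per batch slice (a->ne[2] must equal b->ne[1]), and the result is allocated as a 4-D F32 tensor accordingly. A hedged shape sketch, with ctx, a, and b assumed to exist:

    // a: f32/quantized [n_embd, n_rows, n_batch, 1]
    // b: i32           [n_ids,  n_batch, 1, 1]
    struct ggml_tensor * rows = ggml_get_rows(ctx, a, b);
    // rows: f32 [n_embd, n_ids, n_batch, 1];
    // rows[:, j, k] is row b[j, k] taken from batch slice k of a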
@@ -5519,6 +5540,30 @@ static struct ggml_tensor * ggml_upscale_impl(
     return result;
 }
 
+struct ggml_tensor * ggml_pad(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int p0, int p1, int p2, int p3) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] + p0,
+            a->ne[1] + p1,
+            a->ne[2] + p2,
+            a->ne[3] + p3);
+
+    result->op = GGML_OP_PAD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 struct ggml_tensor * ggml_upscale(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
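ggml_pad grows each dimension at its tail only: the result has shape ne[i] + p[i], and per the F32 compute path added further below the new region is zero-filled. A typical use is rounding a dimension up to a convenient multiple; a small sketch assuming an existing ctx and tensor t:

    const int64_t n  = t->ne[0];
    const int64_t np = ((n + 31)/32)*32; // round up to a multiple of 32
    struct ggml_tensor * tp = ggml_pad(ctx, t, (int)(np - n), 0, 0, 0);
    // tp->ne = { np, t->ne[1], t->ne[2], t->ne[3] }, tail zero-padded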
@@ -7520,7 +7565,7 @@ static void ggml_compute_forward_acc_f32(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 
     // view src0 and dst with these strides and data offset in bytes during acc
-    // nb0 is implicitely element_size because src0 and dst are contiguous
+    // nb0 is implicitly element_size because src0 and dst are contiguous
     size_t nb1     = ((int32_t *) dst->op_params)[0];
     size_t nb2     = ((int32_t *) dst->op_params)[1];
     size_t nb3     = ((int32_t *) dst->op_params)[2];
@@ -7714,8 +7759,10 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+// TODO: OpenCL kernel support broadcast
 #ifdef GGML_USE_CLBLAST
     if (src1->backend == GGML_BACKEND_GPU) {
+        GGML_ASSERT(ggml_are_same_shape(src0, src1));
         if (ith == 0) {
             ggml_cl_mul(src0, src1, dst);
         }
@@ -8981,10 +9028,9 @@ static void ggml_compute_forward_silu(
             } break;
     }
 }
+// ggml_compute_forward_leaky_relu
 
-// ggml_compute_forward_leaky
-
-static void ggml_compute_forward_leaky_f32(
+static void ggml_compute_forward_leaky_relu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -8998,24 +9044,27 @@ static void ggml_compute_forward_leaky_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
     assert(dst->nb[0]  == sizeof(float));
     assert(src0->nb[0] == sizeof(float));
 
     for (int i = 0; i < n; i++) {
-        ggml_vec_leaky_f32(nc,
+        ggml_vec_leaky_relu_f32(nc,
                 (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
+                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
     }
 }
 
-static void ggml_compute_forward_leaky(
+static void ggml_compute_forward_leaky_relu(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_leaky_f32(params, src0, dst);
+                ggml_compute_forward_leaky_relu_f32(params, src0, dst);
             } break;
         default:
             {
@@ -9504,8 +9553,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
+    // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
+    //       all the experts for each batch element and the processing would become incredibly slow
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
+    if (dst->op != GGML_OP_MUL_MAT_ID &&
+        ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
       //src0->type == GGML_TYPE_F32 &&
         src1->type == GGML_TYPE_F32 &&
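To make the NOTE concrete: the BLAS branch first dequantizes src0 to F32, so letting MUL_MAT_ID through it would to_float every expert for every batch element. Back-of-envelope, with illustrative (not measured) numbers:

    // 8 experts of [4096 x 14336] quantized weights dequantized to f32:
    //   8 * 4096 * 14336 * 4 bytes ~= 1.9 GB of to_float output
    // per batch element -- versus the vec_dot path, which reads the
    // quantized rows directly. Hence the dst->op != GGML_OP_MUL_MAT_ID guard.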
@@ -9519,11 +9571,16 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
+// off1 = offset in i11 and i1
+// cne1 = ne11 and ne1
+// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
+// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              int64_t off1, int64_t cne1) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
@@ -9591,10 +9648,9 @@ static void ggml_compute_forward_mul_mat(
                 const int64_t i03 = i13/r3;
                 const int64_t i02 = i12/r2;
 
-                const void  * x = (char *)            src0->data + i02*nb02 + i03*nb03;
-                const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-
-                float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+                const void  * x = (char *)            src0->data +             i02*nb02 + i03*nb03;
+                const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
+                      float * d = (float *) ((char *)  dst->data + off1*nb1  + i12*nb2  + i13*nb3);
 
                 if (type != GGML_TYPE_F32) {
                             float * const wdata    = params->wdata;
@@ -9611,10 +9667,10 @@ static void ggml_compute_forward_mul_mat(
                 }
 
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne00,
-                        0.0f,    d, ne01);
+                         cne1, ne01, ne10,
+                         1.0f,    y, ne10,
+                                  x, ne00,
+                         0.0f,    d, ne01);
             }
         }
 
@@ -9630,6 +9686,7 @@ static void ggml_compute_forward_mul_mat(
             const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
             assert(params->wsize >= ne11*ne12*ne13*row_size);
+            assert(src1->type == GGML_TYPE_F32);
 
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -9652,7 +9709,7 @@ static void ggml_compute_forward_mul_mat(
     const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
     const int64_t nr0 = ne01;           // src0 rows
-    const int64_t nr1 = ne11*ne12*ne13; // src1 rows
+    const int64_t nr1 = cne1*ne12*ne13; // src1 rows
 
     //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
 
@@ -9694,9 +9751,9 @@ static void ggml_compute_forward_mul_mat(
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                const int64_t i13 = (ir1/(ne12*ne11));
-                const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
-                const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+                const int64_t i13 = (ir1/(ne12*cne1));
+                const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+                const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
 
                 // broadcast src0 into src1
                 const int64_t i03 = i13/r3;
@@ -9736,20 +9793,28 @@ static void ggml_compute_forward_mul_mat(
 
 static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int id = ggml_get_op_params_i32(dst, 0);
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
+        ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
+        return;
+    }
 
-    const int a_id = ((int32_t *)ids->data)[id];
+    const struct ggml_tensor * ids = src0;
+    const int id   = ggml_get_op_params_i32(dst, 0);
+    const int n_as = ggml_get_op_params_i32(dst, 1);
 
-    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
 
-    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+        GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-    ggml_compute_forward_mul_mat(params, src0, src1, dst);
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+        ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+    }
 }
 
 // ggml_compute_forward_out_prod
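A worked trace of the loop above, with illustrative values: suppose id = 0, n_as = 4, and ids is an i32 [2 x 3] tensor holding {2, 1} for token 0, {0, 3} for token 1, and {2, 0} for token 2. Entry id of each row is consulted, so the forward pass issues:

    // i01 = 0 -> row_id = 2 -> mul_mat(dst->src[4], src1, dst, off1=0, cne1=1)
    // i01 = 1 -> row_id = 0 -> mul_mat(dst->src[2], src1, dst, off1=1, cne1=1)
    // i01 = 2 -> row_id = 2 -> mul_mat(dst->src[4], src1, dst, off1=2, cne1=1)
    // (expert row_id lives at dst->src[row_id + 2], after ids and src1)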
@@ -10161,7 +10226,7 @@ static void ggml_compute_forward_set_f32(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 
     // view src0 and dst with these strides and data offset in bytes during set
-    // nb0 is implicitely element_size because src0 and dst are contiguous
+    // nb0 is implicitly element_size because src0 and dst are contiguous
     size_t nb1     = ((int32_t *) dst->op_params)[0];
     size_t nb2     = ((int32_t *) dst->op_params)[1];
     size_t nb3     = ((int32_t *) dst->op_params)[2];
@@ -10325,21 +10390,30 @@ static void ggml_compute_forward_get_rows_q(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == ggml_type_size(type));
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
+    assert(ggml_nrows(dst) == nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        dequantize_row_q(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
-                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+                dequantize_row_q(
+                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+            }
+        }
     }
 }
 
@@ -10354,19 +10428,26 @@ static void ggml_compute_forward_get_rows_f16(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
+    assert(ggml_nrows(dst) == nr);
 
-        for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                ggml_fp16_to_fp32_row(
+                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+            }
         }
     }
 }
@@ -10382,19 +10463,27 @@ static void ggml_compute_forward_get_rows_f32(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == sizeof(float));
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(dst) == nr);
 
-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *)  dst->data + i*dst->nb[1]),
-                (float *) ((char *) src0->data + r*src0->nb[1]));
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                ggml_vec_cpy_f32(nc,
+                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+            }
+        }
     }
 }
 
@@ -12114,6 +12203,7 @@ static void ggml_compute_forward_upscale_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
+    const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -12121,16 +12211,17 @@ static void ggml_compute_forward_upscale_f32(
 
     // TODO: optimize
 
-    for (int i03 = 0; i03 < ne03; i03++) {
-        for (int i02 = ith; i02 < ne02; i02++) {
-            for (int m = 0; m < dst->ne[1]; m++) {
-                int i01 = m / scale_factor;
-                for (int n = 0; n < dst->ne[0]; n++) {
-                    int i00 = n / scale_factor;
-
-                    const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        const int64_t i03 = i3;
+        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+            const int64_t i02 = i2;
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                const int64_t i01 = i1 / scale_factor;
+                for (int64_t i0 = 0; i0 < ne0; i0++) {
+                    const int64_t i00 = i0 / scale_factor;
 
-                    float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
+                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
 
                     *y = *x;
                 }
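Besides striding the i2 loop across nth workers, the rewrite makes the nearest-neighbor mapping explicit: each output coordinate maps back to the source by integer division. For scale_factor = 2:

    // dst index i0:  0 1 2 3 4 5
    // src index i00: 0 0 1 1 2 2   (i00 = i0 / scale_factor)
    // so every source element is replicated into a 2x2 block of dst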
@@ -12155,6 +12246,64 @@ static void ggml_compute_forward_upscale(
     }
 }
 
+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+          struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    } else {
+                        dst_ptr[dst_idx] = 0;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_pad(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_pad_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_argsort
 
 static void ggml_compute_forward_argsort_f32(
@@ -13362,10 +13511,6 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_silu(params, src0, dst);
             } break;
-        case GGML_UNARY_OP_LEAKY:
-            {
-                ggml_compute_forward_leaky(params, src0, dst);
-            } break;
         default:
             {
                 GGML_ASSERT(false);
@@ -14037,11 +14182,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
-                ggml_compute_forward_mul_mat_id(params, tensor);
+                ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_OUT_PROD:
             {
@@ -14147,10 +14292,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_upscale(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_PAD:
+            {
+                ggml_compute_forward_pad(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_ARGSORT:
             {
                 ggml_compute_forward_argsort(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -14475,7 +14628,7 @@ void ggml_build_backward_gradient_checkpointing(
             // insert new tensors recomputing src, reusing already made replacements,
             // remember replacements: remember new tensors with mapping from corresponding gf nodes
             // recurse for input tensors,
-            // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
+            // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
             node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
         }
         // insert rewritten backward node with replacements made into resulting backward graph gb
@@ -15143,10 +15296,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_PAD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ARGSORT:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -15752,6 +15913,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ARGMAX:
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT_BACK:
+        case GGML_OP_LEAKY_RELU:
             {
                 n_tasks = 1;
             } break;
@@ -15764,7 +15926,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_LEAKY:
                     {
                         n_tasks = 1;
                     } break;
@@ -15883,6 +16044,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_PAD:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_ARGSORT:
             {
                 n_tasks = n_threads;
diff --git a/ggml.h b/ggml.h
index a8f10cbd5c1d8ead7963644753cd3d8bdd776f05..1447646b13c439a94026cf9af1cc4b42aa617fef 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS           4
-#define GGML_MAX_PARAMS         1024
+#define GGML_MAX_PARAMS         2048
 #define GGML_MAX_CONTEXTS       64
-#define GGML_MAX_SRC            6
+#define GGML_MAX_SRC            10
 #define GGML_MAX_NAME           64
 #define GGML_MAX_OP_PARAMS      64
 #define GGML_DEFAULT_N_THREADS  4
@@ -423,7 +423,9 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_PAD,
         GGML_OP_ARGSORT,
+        GGML_OP_LEAKY_RELU,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -463,7 +465,6 @@ extern "C" {
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
-        GGML_UNARY_OP_LEAKY,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -793,6 +794,9 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // dst = a
+    // view(dst, nb1, nb2, nb3, offset) += b
+    // return dst
     GGML_API struct ggml_tensor * ggml_acc(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -957,15 +961,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    GGML_API struct ggml_tensor * ggml_leaky(
+    GGML_API struct ggml_tensor * ggml_leaky_relu(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a, float negative_slope, bool inplace);
 
     GGML_API struct ggml_tensor * ggml_relu_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // TODO: double-check this computation is correct
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -1051,7 +1054,8 @@ extern "C" {
     //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * as[],
+            struct ggml_tensor  * const as[],
+            int                   n_as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);
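Callers of the public API now supply the expert count alongside the array, matching the extra op_param recorded in ggml.c. A hedged sketch, with the expert tensors and the routing tensor assumed to have been built elsewhere:

    // ggml_mul_mat_id(ctx, as, n_as, ids, id, b) ~= ggml_mul_mat(ctx, as[ids[id]], b),
    // routed per token as described in the comment above
    struct ggml_tensor * experts[8] = { w0, w1, w2, w3, w4, w5, w6, w7 };
    struct ggml_tensor * cur = ggml_mul_mat_id(ctx, experts, 8, ids, /*id=*/0, b);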
@@ -1263,6 +1267,7 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // supports 3D: a->ne[2] == b->ne[1]
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1549,6 +1554,15 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   scale_factor);
 
+    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+    GGML_API struct ggml_tensor * ggml_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
     // sort rows
     enum ggml_sort_order {
         GGML_SORT_ASC,