sync : llama (mul_mat_id + get_rows kernels, typos) (#649)
author    Georgi Gerganov <redacted>
Wed, 13 Dec 2023 19:53:20 +0000 (21:53 +0200)
committer GitHub <redacted>
Wed, 13 Dec 2023 19:53:20 +0000 (21:53 +0200)
* sync : llama (mul_mat_id + get_rows kernels, typos)

ggml-ci

* cuda : restore im2col

ggml-ci

* metal : fix accuracy of dequantization kernels

* cuda : restore correct im2col kernel

ggml-ci

* metal : fix moe test by reducing the expert size

ggml-ci

* cuda : fix bin bcast when src1 and dst have different types

---------

Co-authored-by: slaren <redacted>
include/ggml/ggml-alloc.h
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-quants.c
src/ggml.c
tests/test-backend-ops.cpp
tests/test-grad0.cpp
tests/test-quantize-perf.cpp

diff --git a/include/ggml/ggml-alloc.h b/include/ggml/ggml-alloc.h
index ad87cebc8873f4584ad99e2c67e110a603f4a71c..64a412468915b47c58ce100bb2dfcaf4d5502b32 100644
@@ -43,7 +43,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
 // ggml-backend v2 API
 //
 
-// Seperate tensor and graph allocator objects
+// Separate tensor and graph allocator objects
 // This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
 // The original API is kept as a wrapper around the new API
 
diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index c0fd2230f86e31bcab150b56485b187b6d725a2c..1447646b13c439a94026cf9af1cc4b42aa617fef 100644
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS           4
-#define GGML_MAX_PARAMS         1024
+#define GGML_MAX_PARAMS         2048
 #define GGML_MAX_CONTEXTS       64
-#define GGML_MAX_SRC            6
+#define GGML_MAX_SRC            10
 #define GGML_MAX_NAME           64
 #define GGML_MAX_OP_PARAMS      64
 #define GGML_DEFAULT_N_THREADS  4
@@ -1054,7 +1054,8 @@ extern "C" {
     //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * as[],
+            struct ggml_tensor  * const as[],
+            int                   n_as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);
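
The updated ggml_mul_mat_id now takes the expert tensors as a const array together with an explicit count n_as. A minimal usage sketch of the new signature (the wrapper name and expert tensors are hypothetical, not part of this commit):

    #include "ggml/ggml.h"

    // Sketch: build a mul_mat_id node with the new signature. Roughly equivalent
    // to ggml_mul_mat(as[ids[id]], b), with the expert chosen from `ids` at
    // graph-evaluation time.
    static struct ggml_tensor * build_expert_mul(struct ggml_context * ctx,
                                                 struct ggml_tensor * const as[], int n_as,
                                                 struct ggml_tensor * ids, int id,
                                                 struct ggml_tensor * b) {
        return ggml_mul_mat_id(ctx, as, n_as, ids, id, b);
    }
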
@@ -1266,6 +1267,7 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // supports 3D: a->ne[2] == b->ne[1]
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
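
The 3D case gathers one row of a per index in b, broadcasting over the extra dimensions, which is what the new CUDA and Metal kernels below implement. A plain-C reference of the semantics (a sketch assuming contiguous layout; not library code):

    #include <stdint.h>

    // dst[i00, i10, i11, i12] = a[i00, ids[i10, i11, i12], i11, i12], with ne02 == ne11
    static void get_rows_ref(const float * a, const int32_t * ids, float * dst,
                             int64_t ne00, int64_t ne01, int64_t ne02,
                             int64_t ne10, int64_t ne11, int64_t ne12) {
        for (int64_t i12 = 0; i12 < ne12; ++i12)
        for (int64_t i11 = 0; i11 < ne11; ++i11)
        for (int64_t i10 = 0; i10 < ne10; ++i10) {
            const int64_t i01 = ids[i10 + ne10*(i11 + ne11*i12)]; // row of a to gather
            for (int64_t i00 = 0; i00 < ne00; ++i00) {
                dst[i00 + ne00*(i10 + ne10*(i11 + ne11*i12))] =
                a  [i00 + ne00*(i01 + ne01*(i11 + ne02*i12))];
            }
        }
    }
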
diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
index e8a8783d5c43703d38f0b2531b62423c27c11517..019648bddc4d9e23d363d289b432f10dbc481853 100644
@@ -1,13 +1,15 @@
 #include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
 #include <cstddef>
 #include <cstdint>
-#include <cinttypes>
 #include <float.h>
 #include <limits>
 #include <stdint.h>
 #include <stdio.h>
-#include <atomic>
-#include <assert.h>
+#include <vector>
+
 
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -1859,31 +1861,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
 }
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
-    const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-
-    if (col >= ncols) {
+static __global__ void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
         return;
     }
 
-    const int r = y[row];
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    // copy x[r*ncols + col] to dst[row*ncols + col]
-    const int xi = r*ncols + col;
-    const int di = row*ncols + col;
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    const int ib = xi/qk; // block index
-    const int iqs = (xi%qk)/qr; // quant index
-    const int iybs = di - di%qk; // y block start index
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
-    dequantize_kernel(x, ib, iqs, v);
+    dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst[iybs + iqs + 0]        = v.x;
-    dst[iybs + iqs + y_offset] = v.y;
+    dst_row[iybs + iqs + 0]        = v.x;
+    dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+            const src0_t * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+    dst_row[i00] = src0_row[i00];
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -5239,11 +5275,69 @@ static  __global__ void im2col_f32_f16(
 }
 
 template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(block_num_x, nrows, 1);
-    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    (void) dst;
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    (void) dst;
 }
 
 template<float (*bin_op)(const float, const float)>
@@ -5255,7 +5349,6 @@ struct bin_bcast_cuda {
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-
         int nr0 = ne10/ne0;
         int nr1 = ne11/ne1;
         int nr2 = ne12/ne2;
@@ -5303,26 +5396,28 @@ struct bin_bcast_cuda {
             int64_t ne12 = cne1[2];
             int64_t ne13 = cne1[3];
 
-            //size_t nb0 = cnb0[0];
+            size_t nb0 = cnb0[0];
             size_t nb1 = cnb0[1];
             size_t nb2 = cnb0[2];
             size_t nb3 = cnb0[3];
 
-            //size_t nb10 = cnb1[0];
+            size_t nb10 = cnb1[0];
             size_t nb11 = cnb1[1];
             size_t nb12 = cnb1[2];
             size_t nb13 = cnb1[3];
 
-            //size_t s0 = nb0 / sizeof(src1_t);
-            size_t s1 = nb1 / sizeof(src1_t);
-            size_t s2 = nb2 / sizeof(src1_t);
-            size_t s3 = nb3 / sizeof(src1_t);
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
 
-            //size_t s10 = nb10 / sizeof(src1_t);
+            size_t s10 = nb10 / sizeof(src1_t);
             size_t s11 = nb11 / sizeof(src1_t);
             size_t s12 = nb12 / sizeof(src1_t);
             size_t s13 = nb13 / sizeof(src1_t);
 
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
 
             const int block_size = 128;
 
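
The change above is the commit's bin_bcast fix: dst byte strides are now divided by the dst element size rather than src1's. A tiny standalone example of why that matters when the two types differ (values here are hypothetical):

    #include <stddef.h>
    #include <stdio.h>

    int main(void) {
        // an f16 dst row of 8 elements occupies 16 bytes
        const size_t dst_nb1 = 8 * sizeof(unsigned short);
        const size_t wrong   = dst_nb1 / sizeof(float);          // old: stride as if f32 -> 4
        const size_t right   = dst_nb1 / sizeof(unsigned short); // new: correct element stride -> 8
        printf("wrong = %zu, right = %zu\n", wrong, right);
        return 0;
    }
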
@@ -6688,36 +6783,34 @@ static void ggml_cuda_op_get_rows(
 
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
-    const int ncols = src0->ne[0];
-    const int nrows = ggml_nelements(src1);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
 
     const int32_t * src1_i32 = (const int32_t *) src1_d;
 
     switch (src0->type) {
         case GGML_TYPE_F16:
-            get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         default:
             // TODO: k-quants
@@ -7514,6 +7607,7 @@ inline void ggml_cuda_op_im2col(
     (void) src0_dd;
 }
 
+
 inline void ggml_cuda_op_sum_rows(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -8632,36 +8726,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
 }
 #endif
 
-static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #if 0
-//#ifdef CUDA_USE_TENSOR_CORES
-//    const bool use_tensor_cores = true;
-//#else
-//    const bool use_tensor_cores = false;
-//#endif
-
     ggml_cuda_mul_mat_id_cublas(dst);
-
     // TODO: mmq/mmv support
-#else
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const int id = dst->op_params[0];
+#endif
 
-    int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
 
-    int32_t a_id;
-    CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    const struct ggml_tensor * ids = src0;
+    const int32_t id = ((int32_t *) dst->op_params)[0];
+    const int32_t n_as = ((int32_t *) dst->op_params)[1];
 
-    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
-    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+    std::vector<char> ids_host(ggml_nbytes(ids));
+
+    if (ids->backend == GGML_BACKEND_GPU) {
+        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    } else {
+        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+    }
+
+    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
+    ggml_tensor_extra_gpu src1_row_extra;
+    ggml_tensor_extra_gpu dst_row_extra;
+
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row = *dst;
+
+    src1_row.ne[1] = 1;
+    dst_row.ne[1] = 1;
+
+    src1_row.nb[2] = src1_row.nb[1];
+    dst_row.nb[2] = dst_row.nb[1];
+
+    src1_row.nb[3] = src1_row.nb[1];
+    dst_row.nb[3] = dst_row.nb[1];
+
+    src1_row.extra = &src1_row_extra;
+    dst_row.extra = &dst_row_extra;
 
-    ggml_cuda_mul_mat(src0, src1, dst);
-#endif
 
-    (void) _src0;
-    (void) _src1;
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        //int32_t row_id;
+        //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+        const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+        GGML_ASSERT(row_id >= 0 && row_id < n_as);
+
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        src1_row.data = (char *) src1->data + i01*src1->nb[1];
+
+        dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+        dst_row.data = (char *) dst->data + i01*dst->nb[1];
+
+        ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+    }
 }
 
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
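
In summary, the rewritten ggml_cuda_mul_mat_id copies the ids tensor to the host and then runs one ordinary matrix multiplication per row, each against the expert selected for that row. A naive CPU reference of that behaviour (a sketch with assumed contiguous row-major layout; names and shapes are illustrative only):

    #include <assert.h>
    #include <stdint.h>

    // as[e]:  n_as expert matrices, each with m output rows of length k
    // ids:    n_tokens x n_ids int32 matrix; column `id` picks the expert per token
    // b:      n_tokens x k activations, out: n_tokens x m results
    static void mul_mat_id_ref(const float * const * as, int n_as,
                               const int32_t * ids, int n_ids, int id,
                               const float * b, float * out,
                               int n_tokens, int k, int m) {
        for (int t = 0; t < n_tokens; ++t) {
            const int e = ids[t*n_ids + id];
            assert(e >= 0 && e < n_as); // mirrors the GGML_ASSERT above
            for (int j = 0; j < m; ++j) {
                float sum = 0.0f;
                for (int i = 0; i < k; ++i) {
                    sum += as[e][j*k + i] * b[t*k + i];
                }
                out[t*m + j] = sum;
            }
        }
    }
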
@@ -9603,6 +9730,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                 }
                 return true;
             } break;
+        case GGML_OP_GET_ROWS:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                return false;
+            } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -9610,7 +9776,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_REPEAT:
-        case GGML_OP_GET_ROWS:
         case GGML_OP_DUP:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
@@ -9619,7 +9784,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
-        case GGML_OP_CPY:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
@@ -9692,7 +9856,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
     UNUSED(params);
 }
 
-extern "C" int ggml_backend_cuda_reg_devices() {
+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
     int device_count = ggml_cuda_get_device_count();
     //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
     for (int i = 0; i < device_count; i++) {
diff --git a/src/ggml-metal.m b/src/ggml-metal.m
index 6b538fd39cf973d26e30bf934a93ae86d1fd456f..465679a6bb0c8702f4f46a1ccf81161cad38b626 100644
@@ -105,6 +105,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -146,6 +161,7 @@ struct ggml_metal_context {
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f32);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
     GGML_METAL_DECL_KERNEL(sum_rows);
@@ -363,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
         if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
             GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -406,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
         //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f32);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
         GGML_METAL_ADD_KERNEL(sum_rows);
@@ -469,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
     if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
         GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -512,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f32);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
     GGML_METAL_DEL_KERNEL(sum_rows);
@@ -846,17 +894,42 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return true;
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                        switch (op->type) {
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_F32:
+                           case GGML_TYPE_Q8_0:
+                           case GGML_TYPE_Q4_0:
+                           case GGML_TYPE_Q4_1:
+                                return true;
+                           default:
+                                return false;
+                        }
+                    case GGML_TYPE_F16:
+                        switch (op->type) {
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_F32:
+                                return true;
+                           default:
+                                return false;
+                        }
+                    default:
+                        return false;
+                };
+            }
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
                 return op->ne[3] == 1;
-            } break;
+            }
         default:
             return false;
     }
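
With cpy_f16_f32 now registered and allowed by ggml_metal_supports_op, a graph can cast F16 tensors to F32 through ggml_cpy on the Metal backend. A minimal sketch of building such a node (the helper name is hypothetical):

    #include "ggml/ggml.h"

    static struct ggml_tensor * cast_f16_to_f32(struct ggml_context * ctx,
                                                struct ggml_tensor * src_f16) {
        // allocate an F32 tensor with the same shape and copy/convert into it
        struct ggml_tensor * dst_f32 = ggml_new_tensor(ctx, GGML_TYPE_F32,
                                                       GGML_MAX_DIMS, src_f16->ne);
        return ggml_cpy(ctx, src_f16, dst_f32);
    }
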
@@ -1031,36 +1104,39 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                     case GGML_OP_DIV:
                         {
-                            GGML_ASSERT(ggml_is_contiguous(src0));
-                            GGML_ASSERT(ggml_is_contiguous(src1));
-
                             const size_t offs = 0;
 
                             bool bcast_row = false;
 
                             int64_t nb = ne00;
 
-                            if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                                GGML_ASSERT(ggml_is_contiguous(src0));
+
                                 // src1 is a row
                                 GGML_ASSERT(ne11 == 1);
 
                                 nb = ne00 / 4;
                                 switch (dst->op) {
-                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
-                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
-                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+                                    case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+                                    case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+                                    case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
                                     default: GGML_ASSERT(false);
                                 }
 
                                 bcast_row = true;
                             } else {
                                 switch (dst->op) {
-                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
-                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
-                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+                                    case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+                                    case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+                                    case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
                                     default: GGML_ASSERT(false);
                                 }
                             }
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1096,7 +1172,7 @@ void ggml_metal_graph_compute(
 
                                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } else {
-                                const int nth = MIN(1024, ne0);
+                                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
@@ -1582,7 +1658,7 @@ void ggml_metal_graph_compute(
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
-                                    int64_t ny = (ne11 + nrows - 1)/nrows;
+                                    const int64_t ny = (ne11 + nrows - 1)/nrows;
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
@@ -1594,7 +1670,7 @@ void ggml_metal_graph_compute(
 
                             GGML_ASSERT(src0t == GGML_TYPE_I32);
 
-                            const int n_as = ne00;
+                            const int n_as = ((int32_t *) dst->op_params)[1];
 
                             // TODO: make this more general
                             GGML_ASSERT(n_as <= 8);
@@ -1626,14 +1702,22 @@ void ggml_metal_graph_compute(
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
-                            int ne11_mm_min = 0;
+                            int ne11_mm_min = 1;
 
                             const int idx = ((int32_t *) dst->op_params)[0];
 
+                            // batch size
+                            GGML_ASSERT(ne01 == ne11);
+
+                            const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                                ne11 > ne11_mm_min) {
+                            // !!!
+                            // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+                            //       indirect matrix multiplication
+                            // !!!
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
                                 switch (src2->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
@@ -1652,19 +1736,22 @@ void ggml_metal_graph_compute(
                                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:3];
-                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:4];
-                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:5];
-                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:6];
-                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
-                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
-                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
-                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
-                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
-                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
-                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
-                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:15];
+                                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:3];
+                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
+                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:5];
+                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
+                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:7];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
+                                [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:9];
+                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:10];
+                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:11];
+                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:12];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
+                                [encoder setBytes:&_ne1    length:sizeof(_ne1) atIndex:14];
+                                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:15];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:16];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17];
+                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:18];
                                 // TODO: how to make this an array? read Metal docs
                                 for (int j = 0; j < n_as; ++j) {
                                     struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1672,11 +1759,157 @@ void ggml_metal_graph_compute(
                                     size_t offs_src_cur = 0;
                                     id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
 
-                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                                 }
 
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+                                // TODO: processing one row at a time (ne11 -> 1) is not efficient
+                                [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            } else {
+                                int nth0 = 32;
+                                int nth1 = 1;
+                                int nrows = 1;
+                                //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                                // use custom matrix x vector kernel
+                                switch (src2t) {
+                                    case GGML_TYPE_F32:
+                                        {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                                        } break;
+                                    case GGML_TYPE_F16:
+                                        {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                            nth0 = 32;
+                                            nth1 = 1;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_1:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_1:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q8_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q2_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q3_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_K:
+                                        {
+                                            nth0 = 4; //1;
+                                            nth1 = 8; //32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q6_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                                        } break;
+                                    default:
+                                        {
+                                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                            GGML_ASSERT(false && "not implemented");
+                                        }
+                                };
+
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                                [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                                [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                                [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+                                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+                                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+                                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+                                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+                                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
+                                [encoder setBytes:&r2   length:sizeof(r2)   atIndex:20];
+                                [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
+                                [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
+                                // TODO: how to make this an array? read Metal docs
+                                for (int j = 0; j < n_as; ++j) {
+                                    struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                                    size_t offs_src_cur = 0;
+                                    id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+                                }
+
+                                if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+                                    src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                                    src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                                }
+                                else if (src2t == GGML_TYPE_Q5_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q6_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                } else {
+                                    const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
                             }
                         } break;
                     case GGML_OP_GET_ROWS:
@@ -1697,16 +1930,19 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:5];
-
-                            const int64_t n = ggml_nelements(src1);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+                            [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+                            [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+                            [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                         } break;
                     case GGML_OP_RMS_NORM:
                         {
@@ -2058,7 +2294,7 @@ void ggml_metal_graph_compute(
                                     {
                                         switch (dstt) {
                                             case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal
index 54b3b8a16200c0f30768987b8a59b8e833cabdde..fe0ada445a2d45c4f295021565a8e9247c9b2dfc 100644
@@ -366,9 +366,9 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 != src0 ? src1                              + i01*ne00 : nullptr;
-    device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 =         src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 != src0 ? src1                               + i01*ne00 : nullptr;
+    device       float * pdst  =         dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
     float lmax = -INFINITY;
@@ -404,8 +404,12 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -488,8 +492,13 @@ kernel void kernel_soft_max_4(
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -840,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 //      giard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(
+void mul_vec_q_n_f32_impl(
         device const void  * src0,
         device const float * src1,
         device       float * dst,
@@ -922,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -941,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -960,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -979,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mv_q8_0_f32(
+void kernel_mul_mv_q8_0_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int nr  = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw  = N_SIMDWIDTH;
@@ -1054,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
 #define N_F32_F32 4
 
-kernel void kernel_mul_mv_f32_f32(
+void kernel_mul_mv_f32_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1074,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1134,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 #define N_F16_F16 4
 
 kernel void kernel_mul_mv_f16_f16(
@@ -1214,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16(
     }
 }
 
-kernel void kernel_mul_mv_f16_f32_1row(
+void kernel_mul_mv_f16_f32_1row_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1232,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1270,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         }
     }
+}
 
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 }
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mv_f16_f32(
+void kernel_mul_mv_f16_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1293,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1353,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 // Assumes row size (ne00) is a multiple of 4
 kernel void kernel_mul_mv_f16_f32_l4(
         device const  char * src0,
@@ -1809,8 +1915,8 @@ kernel void kernel_leaky_relu_f32(
 }
 
 kernel void kernel_cpy_f16_f16(
-        device const half * src0,
-        device       half * dst,
+        device  const half * src0,
+        device        half * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -1849,6 +1955,47 @@ kernel void kernel_cpy_f16_f16(
     }
 }
 
+kernel void kernel_cpy_f16_f32(
+        device  const half * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device        half * dst,
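
Note (editorial, not part of the diff): the new kernel_cpy_f16_f32 above uses the same indexing scheme as the other cpy kernels, where each threadgroup handles one source row (i01, i02, i03), computes the flat element offset n, and decomposes it into destination coordinates i0..i3 from the destination extents. The standalone C++ sketch below reproduces that decomposition on the host; all extents are made-up example values.

    // Standalone sketch (not from the commit): the index math of kernel_cpy_f16_f32 on the CPU.
    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical source extents (ne00..ne03) and destination extents (ne0..ne3)
        const int64_t ne00 = 8, ne01 = 3, ne02 = 2, ne03 = 1;
        const int64_t ne0  = 4, ne1  = 6, ne2  = 2, ne3  = 1;
        (void) ne03; (void) ne3;

        // pick one source row, as a threadgroup would
        const int64_t i03 = 0, i02 = 1, i01 = 2;

        // flat element offset of the first element of that row
        const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

        // decompose the flat offset into destination coordinates
        const int64_t i3 = n / (ne2*ne1*ne0);
        const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
        const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
        const int64_t i0 =  n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0;

        printf("n=%lld -> dst (i0,i1,i2,i3) = (%lld,%lld,%lld,%lld)\n",
               (long long) n, (long long) i0, (long long) i1, (long long) i2, (long long) i3);
        return 0;
    }
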
@@ -2272,19 +2419,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mv_q2_K_f32(
+void kernel_mul_mv_q2_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2422,8 +2569,8 @@ kernel void kernel_mul_mv_q2_K_f32(
     }
 }
 
-#if QK_K == 256
-kernel void kernel_mul_mv_q3_K_f32(
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2437,8 +2584,29 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant   uint    & r2  [[buffer(17)]],
         constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
@@ -2581,19 +2749,19 @@ kernel void kernel_mul_mv_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mv_q3_K_f32(
+void kernel_mul_mv_q3_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2658,20 +2826,41 @@ kernel void kernel_mul_mv_q3_K_f32(
 }
 #endif
 
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 #if QK_K == 256
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01 [[buffer(4)]],
-        constant   int64_t & ne02 [[buffer(5)]],
-        constant   int64_t & ne10 [[buffer(9)]],
-        constant   int64_t & ne12 [[buffer(11)]],
-        constant   int64_t & ne0  [[buffer(15)]],
-        constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2772,19 +2961,19 @@ kernel void kernel_mul_mv_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2868,7 +3057,8 @@ kernel void kernel_mul_mv_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mv_q5_K_f32(
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2885,6 +3075,26 @@ kernel void kernel_mul_mv_q5_K_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
@@ -3044,10 +3254,10 @@ kernel void kernel_mul_mv_q5_K_f32(
             dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         }
     }
-
 }
 
-kernel void kernel_mul_mv_q6_K_f32(
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -3061,8 +3271,28 @@ kernel void kernel_mul_mv_q6_K_f32(
         constant   uint    & r2  [[buffer(17)]],
         constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -3153,6 +3383,27 @@ kernel void kernel_mul_mv_q6_K_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3270,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
-    const half d = xb->d;
-    const half min = xb->dmin;
+    const float d = xb->d;
+    const float min = xb->dmin;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
-    half dl, ml;
+    float dl, ml;
     uint8_t sc = xb->scales[il];
 
 #if QK_K == 256
@@ -3343,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     q = q + (il/4) * 32 + 16 * (il&1);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const half d   = il < 2 ? xb->d : xb->d / 16.h;
-    const half min = xb->dmin;
-    const half dl = d * sc[0];
-    const half ml = min * sc[1];
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
@@ -3373,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     uint8_t ul = 1 << (il/2);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const half d = il < 2 ? xb->d : xb->d / 16.h;
-    const half min = xb->dmin;
-    const half dl = d * sc[0];
-    const half ml = min * sc[1];
+    const float d = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    const half qh_val = il<2 ? 16.h : 256.h;
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
@@ -3427,22 +3678,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const  void * src0,
-        device const   int * src1,
+        device const  char * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
         constant  uint64_t & nb1,
-        uint                 tgpig[[threadgroup_position_in_grid]],
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
-        uint                 tptg[[threads_per_threadgroup]]) {
-    const int i = tgpig;
-    const int r = ((device int32_t *) src1)[i];
+        uint3                tptg [[threads_per_threadgroup]]) {
+    //const int64_t i = tgpig;
+    //const int64_t r = ((device int32_t *) src1)[i];
+
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
 
-    for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         float4x4 temp;
         dequantize_func(
-            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+kernel void kernel_get_rows_f32(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
     }
 }
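
Note (editorial, not part of the diff): the rewritten get_rows kernels read a two-dimensional index tensor. tgpig.x selects the index position i10, tgpig.y selects the batch i11, the row id r is read from src1 at (i10, i11), and i02 = i11 picks the matching src0 slice, which is what makes the 3D case (a->ne[2] == b->ne[1], per the new header comment) work. The C++ sketch below mirrors that addressing on the host; shapes and data are made-up example values.

    // Standalone sketch (not from the commit): CPU version of the updated get_rows addressing.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const int64_t ne00  = 4;   // row length
        const int64_t nrows = 3;   // rows per src0 slice
        const int64_t ne10  = 2;   // indices per batch
        const int64_t ne11  = 2;   // batches (must match src0's third dimension)

        // src0: [ne00, nrows, ne11], contiguous floats
        std::vector<float> src0(ne00*nrows*ne11);
        for (size_t i = 0; i < src0.size(); ++i) src0[i] = (float) i;

        // src1: [ne10, ne11] int32 row ids
        std::vector<int32_t> src1 = { 2, 0,   // batch 0 picks rows 2 and 0
                                      1, 1 }; // batch 1 picks row 1 twice

        // dst: [ne00, ne10, ne11]
        std::vector<float> dst(ne00*ne10*ne11);

        for (int64_t i11 = 0; i11 < ne11; ++i11) {       // plays the role of tgpig.y
            for (int64_t i10 = 0; i10 < ne10; ++i10) {   // plays the role of tgpig.x
                const int64_t r   = src1[i11*ne10 + i10];
                const int64_t i02 = i11;                  // batch selects the src0 slice
                for (int64_t i00 = 0; i00 < ne00; ++i00) {
                    dst[(i11*ne10 + i10)*ne00 + i00] = src0[(i02*nrows + r)*ne00 + i00];
                }
            }
        }

        printf("batch 1, index 0 -> %g %g %g %g\n",
               dst[(1*ne10 + 0)*ne00 + 0], dst[(1*ne10 + 0)*ne00 + 1],
               dst[(1*ne10 + 0)*ne00 + 2], dst[(1*ne10 + 0)*ne00 + 3]);
        return 0;
    }
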
 
@@ -3634,19 +3953,22 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const int32_t * ids,
+        device const   uchar * ids,
         device const   uchar * src1,
-        device         float * dst,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant     int64_t & nb01,
         constant     int64_t & nb02,
         constant     int64_t & ne12,
+        constant     int64_t & ne13,
         constant     int64_t & nb10,
         constant     int64_t & nb11,
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
+        constant     int64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -3664,10 +3986,16 @@ kernel void kernel_mul_mm_id(
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
     device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[ids[idx]],
-        src1,
-        dst,
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
         ne00,
         ne02,
         nb01,
@@ -3692,17 +4020,26 @@ kernel void kernel_mul_mm_id(
 #define QK_NL 4
 #endif
 
+//
+// get rows
+//
+
 typedef void (get_rows_t)(
         device const void * src0,
-        device const  int * src1,
+        device const char * src1,
         device      float * dst,
         constant  int64_t & ne00,
         constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant  int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
         constant uint64_t & nb1,
-        uint, uint, uint);
+        constant uint64_t & nb2,
+        uint3, uint, uint3);
 
-template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
-template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
+//template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3714,6 +4051,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
+//
+// matrix-matrix multiplication
+//
+
 typedef void (mat_mm_t)(
         device const  uchar * src0,
         device const  uchar * src1,
@@ -3746,20 +4087,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
 
+//
+// indirect matrix-matrix multiplication
+//
+
 typedef void (mat_mm_id_t)(
-        device const int32_t * ids,
+        device const   uchar * ids,
         device const   uchar * src1,
-        device         float * dst,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant     int64_t & nb01,
         constant     int64_t & nb02,
         constant     int64_t & ne12,
+        constant     int64_t & ne13,
         constant     int64_t & nb10,
         constant     int64_t & nb11,
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
+        constant     int64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -3786,3 +4134,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
 template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
 template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f32_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f16_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q8_0_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q2_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q3_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q4_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q5_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q6_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
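
Note (editorial, not part of the diff): the _impl/wrapper split earlier in this file exists so these kernel_mul_mv_id_* entry points can call the shared kernel bodies directly while the [[host_name(...)]] wrappers keep the original entry-point names. Each _id kernel receives up to eight expert matrices as separate buffer arguments, splits tgpig.z into a batch index bid and the remaining broadcast position, and reads the expert id for that batch from ids before forwarding per-batch offsets into src1 and dst. The C++ sketch below reproduces that selection arithmetic only; all sizes are made-up, and nbi1 is assumed to be the byte stride between rows of the ids tensor.

    // Standalone sketch (not from the commit): batch/expert selection as done at the
    // top of every kernel_mul_mv_id_* kernel, written as plain C++.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne12 = 1, ne13 = 1;           // broadcast dims of src1
        const int64_t n_ids = 2;                    // candidate experts per token (ids->ne[0])
        const int64_t nbi1  = n_ids * sizeof(int32_t);
        const int     idx   = 0;                    // which candidate this op uses

        // two tokens, each selecting two experts: token 0 -> {3, 1}, token 1 -> {0, 2}
        const int32_t ids_data[] = { 3, 1, 0, 2 };
        const char * ids = (const char *) ids_data;

        // the grid's z dimension runs over tokens times the broadcast dims
        for (int64_t z = 0; z < 2*ne12*ne13; ++z) {
            int64_t tgpig_z   = z;
            const int64_t bid = tgpig_z / (ne12*ne13);   // which token/batch
            tgpig_z           = tgpig_z % (ne12*ne13);   // remaining broadcast position

            const int32_t id = ((const int32_t *) (ids + bid*nbi1))[idx];
            printf("z=%lld -> batch %lld uses expert %d\n", (long long) z, (long long) bid, id);
        }
        return 0;
    }
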
index 7285d5f7fbcc00ce41b5d481b702ecc52c5671f6..0e8163a16b39549671363ac859cad2a7e0aaeefa 100644 (file)
@@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
 
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
-    // These tempory registers are for masking and shift operations
+    // These temporary registers are for masking and shift operations
     vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
     vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
 
@@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
             vl = 16;
 
-            // retreive lane to multiply with scale
+            // retrieve lane to multiply with scale
             vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
             vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
             vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
index 8652446c739352c8d2c373f0e0ff328c65a706a2..29e18a24c76a85f7babfcfe1652c32f0746aeadd 100644 (file)
@@ -1,4 +1,4 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC
 
 #include "ggml-impl.h"
@@ -33,7 +33,7 @@
 // we should just be careful :)
 #pragma warning(disable: 4244 4267)
 
-// disable POSIX deprecation warnigns
+// disable POSIX deprecation warnings
 // these functions are never going away, anyway
 #pragma warning(disable: 4996)
 #endif
@@ -1763,7 +1763,7 @@ static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
 // WARN:
-// Mis-confguration can lead to problem that's hard to reason about:
+// Mis-configuration can lead to problems that are hard to reason about:
 // * At best  it crashes or talks nonsense.
 // * At worst it talks slightly differently, in ways that are hard to perceive.
 //
@@ -4092,17 +4092,18 @@ struct ggml_tensor * ggml_mul_mat(
 
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * as[],
+        struct ggml_tensor  * const as[],
+        int                   n_as,
         struct ggml_tensor  * ids,
         int                   id,
         struct ggml_tensor  * b) {
 
-    int64_t n_as = ids->ne[0];
-
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_is_vector(ids));
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
+    GGML_ASSERT(ids->ne[1] == b->ne[1]);
+    GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
     GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
-    GGML_ASSERT(id >= 0 && id < n_as);
+    GGML_ASSERT(id >= 0 && id < ids->ne[0]);
 
     bool is_node = false;
 
@@ -4114,13 +4115,14 @@ struct ggml_tensor * ggml_mul_mat_id(
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
 
     ggml_set_op_params_i32(result, 0, id);
+    ggml_set_op_params_i32(result, 1, n_as);
 
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = ids;
     result->src[1] = b;
 
-    for (int64_t i = 0; i < n_as; i++) {
+    for (int i = 0; i < n_as; i++) {
         struct ggml_tensor * a = as[i];
         GGML_ASSERT(ggml_are_same_shape(as[0], a));
         GGML_ASSERT(ggml_can_mul_mat(a, b));
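
(With this change the public API takes the expert count explicitly and accepts a 2D ids tensor: ids->ne[0] holds the candidate expert slots per token and ids->ne[1] must match the number of src1 columns, so each token can be routed independently. A minimal sketch of the updated call — tensor sizes are illustrative, and ids must be filled with values in [0, n_as) before the graph is computed:)

    // assumes ctx is a valid ggml_context and all 8 expert matrices share the same shape
    static struct ggml_tensor * moe_matmul_example(struct ggml_context * ctx,
                                                   struct ggml_tensor  * const experts[8],
                                                   struct ggml_tensor  * cur /* [n_embd, n_tokens] F32 */) {
        // one row of candidate expert indices per token; 2 candidates per token here
        struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, cur->ne[1]);
        // use the first candidate slot (id = 0); must satisfy id < ids->ne[0]
        return ggml_mul_mat_id(ctx, experts, 8, ids, 0, cur);
    }
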
@@ -4748,7 +4750,9 @@ struct ggml_tensor * ggml_get_rows(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
 
     bool is_node = false;
 
@@ -4758,7 +4762,7 @@ struct ggml_tensor * ggml_get_rows(
 
     // TODO: implement non F32 return
     //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
-    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
 
     result->op   = GGML_OP_GET_ROWS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
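
(ggml_get_rows is relaxed from strictly 2D inputs to the batched form the MoE routing needs: src0 may carry one matrix per batch in dim 2, src1 carries one row of I32 indices per batch, and the F32 result picks up src1's batch dims. A shape sketch with illustrative sizes:)

    static struct ggml_tensor * get_rows_batched_example(struct ggml_context * ctx) {
        // a: 4 batches of a 10x64 row matrix; b: 3 row indices per batch (b->ne[1] == a->ne[2])
        struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 10, 4);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_I32,  3,  4);
        // result: F32 tensor of shape [64, 3, 4, 1]
        return ggml_get_rows(ctx, a, b);
    }
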
@@ -7561,7 +7565,7 @@ static void ggml_compute_forward_acc_f32(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 
     // view src0 and dst with these strides and data offset inbytes during acc
-    // nb0 is implicitely element_size because src0 and dst are contiguous
+    // nb0 is implicitly element_size because src0 and dst are contiguous
     size_t nb1     = ((int32_t *) dst->op_params)[0];
     size_t nb2     = ((int32_t *) dst->op_params)[1];
     size_t nb3     = ((int32_t *) dst->op_params)[2];
@@ -9549,8 +9553,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
+    // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
+    //       all the experts for each batch element and the processing would become incredibly slow
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
+    if (dst->op != GGML_OP_MUL_MAT_ID &&
+        ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
       //src0->type == GGML_TYPE_F32 &&
         src1->type == GGML_TYPE_F32 &&
@@ -9564,11 +9571,16 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
+// off1 = offset in i11 and i1
+// cne1 = ne11 and ne1
+// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
+// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              int64_t off1, int64_t cne1) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
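
(The off1/cne1 pair turns the existing kernel into a windowed one without duplicating it: the regular MUL_MAT path keeps the full range, while the MUL_MAT_ID path further down requests a single src1/dst column per routed token. A sketch of the two call shapes — this helper is illustrative only and would have to live inside ggml.c, where the static function is visible:)

    static void mul_mat_window_example(const struct ggml_compute_params * params,
                                       const struct ggml_tensor * src0,
                                       const struct ggml_tensor * src1,
                                       struct ggml_tensor * dst) {
        // plain matrix multiplication: process every dst column
        ggml_compute_forward_mul_mat(params, src0, src1, dst, 0, dst->ne[1]);

        // MoE-style call: process only column i1 = 3 (off1 = 3, cne1 = 1)
        ggml_compute_forward_mul_mat(params, src0, src1, dst, 3, 1);
    }
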
@@ -9636,10 +9648,9 @@ static void ggml_compute_forward_mul_mat(
                 const int64_t i03 = i13/r3;
                 const int64_t i02 = i12/r2;
 
-                const void  * x = (char *)            src0->data + i02*nb02 + i03*nb03;
-                const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-
-                float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+                const void  * x = (char *)            src0->data +             i02*nb02 + i03*nb03;
+                const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
+                      float * d = (float *) ((char *)  dst->data + off1*nb1  + i12*nb2  + i13*nb3);
 
                 if (type != GGML_TYPE_F32) {
                             float * const wdata    = params->wdata;
@@ -9656,10 +9667,10 @@ static void ggml_compute_forward_mul_mat(
                 }
 
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne00,
-                        0.0f,    d, ne01);
+                         cne1, ne01, ne10,
+                         1.0f,    y, ne10,
+                                  x, ne00,
+                         0.0f,    d, ne01);
             }
         }
 
@@ -9675,6 +9686,7 @@ static void ggml_compute_forward_mul_mat(
             const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
             assert(params->wsize >= ne11*ne12*ne13*row_size);
+            assert(src1->type == GGML_TYPE_F32);
 
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -9697,7 +9709,7 @@ static void ggml_compute_forward_mul_mat(
     const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
     const int64_t nr0 = ne01;           // src0 rows
-    const int64_t nr1 = ne11*ne12*ne13; // src1 rows
+    const int64_t nr1 = cne1*ne12*ne13; // src1 rows
 
     //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
 
@@ -9739,9 +9751,9 @@ static void ggml_compute_forward_mul_mat(
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                const int64_t i13 = (ir1/(ne12*ne11));
-                const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
-                const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+                const int64_t i13 = (ir1/(ne12*cne1));
+                const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+                const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
 
                 // broadcast src0 into src1
                 const int64_t i03 = i13/r3;
@@ -9781,20 +9793,28 @@ static void ggml_compute_forward_mul_mat(
 
 static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int id = ggml_get_op_params_i32(dst, 0);
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
+        ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
+        return;
+    }
 
-    const int a_id = ((int32_t *)ids->data)[id];
+    const struct ggml_tensor * ids = src0;
+    const int id   = ggml_get_op_params_i32(dst, 0);
+    const int n_as = ggml_get_op_params_i32(dst, 1);
 
-    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
 
-    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+        GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-    ggml_compute_forward_mul_mat(params, src0, src1, dst);
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+        ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+    }
 }
 
 // ggml_compute_forward_out_prod
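
(The dispatch loop added in ggml_compute_forward_mul_mat_id above reads one routed expert per token: the slot selected by the op param is looked up in the i01-th row of ids, and the matching expert tensor, stored at dst->src[row_id + 2], is multiplied against just that token via the cne1 = 1 window. The strided read, isolated for clarity as a hypothetical helper:)

    static int32_t read_routed_expert(const struct ggml_tensor * ids, int64_t i01 /* token */, int id /* slot */) {
        return *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
    }
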
@@ -10206,7 +10226,7 @@ static void ggml_compute_forward_set_f32(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 
     // view src0 and dst with these strides and data offset inbytes during set
-    // nb0 is implicitely element_size because src0 and dst are contiguous
+    // nb0 is implicitly element_size because src0 and dst are contiguous
     size_t nb1     = ((int32_t *) dst->op_params)[0];
     size_t nb2     = ((int32_t *) dst->op_params)[1];
     size_t nb3     = ((int32_t *) dst->op_params)[2];
@@ -10370,21 +10390,30 @@ static void ggml_compute_forward_get_rows_q(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == ggml_type_size(type));
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
+    assert(ggml_nrows(dst) == nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        dequantize_row_q(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
-                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+                dequantize_row_q(
+                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+            }
+        }
     }
 }
 
@@ -10399,19 +10428,26 @@ static void ggml_compute_forward_get_rows_f16(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
+    assert(ggml_nrows(dst) == nr);
 
-        for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                ggml_fp16_to_fp32_row(
+                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+            }
         }
     }
 }
@@ -10427,19 +10463,27 @@ static void ggml_compute_forward_get_rows_f32(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == sizeof(float));
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(dst) == nr);
 
-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *)  dst->data + i*dst->nb[1]),
-                (float *) ((char *) src0->data + r*src0->nb[1]));
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                ggml_vec_cpy_f32(nc,
+                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+            }
+        }
     }
 }
 
@@ -14138,11 +14182,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
-                ggml_compute_forward_mul_mat_id(params, tensor);
+                ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_OUT_PROD:
             {
@@ -14584,7 +14628,7 @@ void ggml_build_backward_gradient_checkpointing(
             // insert new tensors recomputing src, reusing already made replacements,
             // remember replacements: remember new tensors with mapping from corresponding gf nodes
             // recurse for input tensors,
-            // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
+            // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
             node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
         }
         // insert rewritten backward node with replacements made into resulting backward graph gb
index 937521ef9b2caf44812ffd2d9000939646307a83..afca85143fb30e1f1f85698a815c75b99b24bf42 100644 (file)
@@ -20,8 +20,6 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
     size_t size = ggml_nelements(tensor);
     std::vector<float> data(size);
 
-    std::random_device rd;
-
 #if 0
     std::default_random_engine generator(rd());
     std::uniform_real_distribution<float> distribution(min, max);
@@ -31,6 +29,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
     }
 #endif
     auto init_thread = [&](size_t start, size_t end) {
+        std::random_device rd;
         std::default_random_engine generator(rd());
         std::uniform_real_distribution<float> distribution(min, max);
 
@@ -51,7 +50,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         t.join();
     }
 
-    if (tensor->type == GGML_TYPE_F32) {
+    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
     } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
         GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
@@ -71,23 +70,28 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
     std::vector<uint8_t> buf(ggml_nbytes(t));
     ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
 
+    ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
+    size_t bs = ggml_blck_size(t->type);
+
     // access elements by index to avoid gaps in views
     for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
         for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
             for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
-                for (int64_t i0 = 0; i0 < t->ne[0]; i0++) {
-                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0];
-                    float v;
+                for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
+                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
                     if (t->type == GGML_TYPE_F16) {
-                        v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]);
+                        tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
                     } else if (t->type == GGML_TYPE_F32) {
-                        v = *(float *) &buf[i];
+                        tv.push_back(*(float *) &buf[i]);
                     } else if (t->type == GGML_TYPE_I32) {
-                        v = *(int32_t *) &buf[i];
+                        tv.push_back((float)*(int32_t *) &buf[i]);
+                    } else if (ggml_is_quantized(t->type)) {
+                        std::vector<float> vq(ggml_blck_size(t->type));
+                        tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
+                        tv.insert(tv.end(), vq.begin(), vq.end());
                     } else {
                         GGML_ASSERT(false);
                     }
-                    tv.push_back(v);
                 }
             }
         }
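
(tensor_to_float can now flatten quantized tensors as well: it walks whole blocks and dequantizes them through the type traits, so quantized outputs — e.g. from the new CPY tests below — can be compared element-wise against the reference backend. A stand-alone sketch of the same per-block conversion; the block pointer is assumed to hold one valid block of the given type and `out` must have room for ggml_blck_size(type) floats:)

    static void dequantize_block_example(enum ggml_type type, const void * block, float * out) {
        ggml_type_traits_t tt = ggml_internal_get_type_traits(type);
        tt.to_float(block, out, ggml_blck_size(type));
    }
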
@@ -238,6 +242,10 @@ enum test_mode {
 struct test_case {
     virtual ~test_case() {}
 
+    virtual std::string op_desc(ggml_tensor * t) {
+        return ggml_op_desc(t);
+    }
+
     virtual std::string vars() {
         return "";
     }
@@ -245,7 +253,7 @@ struct test_case {
     virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
 
     virtual double max_nmse_err() {
-        return 1e-6;
+        return 1e-7;
     }
 
     virtual void initialize_tensors(ggml_context * ctx) {
@@ -331,13 +339,13 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
-            //printf("  %s: skipping\n", ggml_op_desc(out));
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
         }
 
-        printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
         // check if backends support op
@@ -400,7 +408,7 @@ struct test_case {
             for (size_t i = 0; i < f1.size(); i++) {
                 // check for nans
                 if (std::isnan(f1[i]) || std::isnan(f2[i])) {
-                    printf("NaN at index %zu ", i);
+                    printf("[%s] NaN at index %zu (%f %f) ", ggml_op_desc(t1), i, f1[i], f2[i]);
                     ud->ok = false;
                     return true;
                 }
@@ -408,12 +416,12 @@ struct test_case {
                 if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
                     if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                         if (std::signbit(f1[i]) != std::signbit(f2[i])) {
-                            printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
+                            printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
                             ud->ok = false;
                             return true;
                         }
                     } else {
-                        printf("inf mismatch: %f %f ", f1[i], f2[i]);
+                        printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
                         ud->ok = false;
                         return true;
                     }
@@ -422,10 +430,17 @@ struct test_case {
 
             double err = nmse(f1.data(), f2.data(), f1.size());
             if (err > ud->max_err) {
-                printf("NMSE = %f ", err);
+                printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+                //for (int i = 0; i < f1.size(); i++) {
+                //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
+                //}
+                //printf("\n");
+                //exit(1);
                 ud->ok = false;
             }
             return true;
+
+            GGML_UNUSED(index);
         };
 
         ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
@@ -457,13 +472,13 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
-            //printf("  %s: skipping\n", ggml_op_desc(out));
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
         }
 
-        int len = printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        int len = printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
         // check if backends support op
@@ -515,8 +530,9 @@ struct test_case {
             return size;
         };
         for (int i = 0; i < gf->n_nodes; i++) {
-            if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out)
+            if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) {
                 continue;
+            }
             mem += tensor_op_size(gf->nodes[i]);
         }
 
@@ -571,17 +587,22 @@ struct test_get_rows : public test_case {
     const int n; // cols
     const int m; // rows
     const int r; // rows to get
+    const int b; // batch size
+    const bool v; // view (non-contiguous src1)
 
     std::string vars() override {
-        return VARS_TO_STR4(type, n, m, r);
+        return VARS_TO_STR6(type, n, m, r, b, v);
     }
 
-    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
-        : type(type), n(n), m(m), r(r) {}
+    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
+        : type(type), n(n), m(m), r(r), b(b), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
-        ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
+        ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
+        ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+        if (v) {
+            rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+        }
         ggml_tensor * out = ggml_get_rows(ctx, in, rows);
         return out;
     }
@@ -589,12 +610,13 @@ struct test_get_rows : public test_case {
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
                 // rows
-                std::vector<int> data(r);
-                for (int i = 0; i < r; i++) {
+                std::vector<int> data(r*b);
+                for (int i = 0; i < r*b; i++) {
                     data[i] = rand() % m;
                 }
-                ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
+                ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
             } else {
                 init_tensor_uniform(t);
             }
@@ -855,11 +877,10 @@ struct test_mul_mat_id : public test_case {
     const int64_t m;
     const int64_t n;
     const int64_t k;
-    const std::array<int64_t, 2> bs; // dims 3 and 4
-    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+    const bool v; // view (non-contiguous ids)
 
     std::string vars() override {
-        return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
+        return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
     }
 
     double max_nmse_err() override {
@@ -867,7 +888,7 @@ struct test_mul_mat_id : public test_case {
     }
 
     size_t op_size(ggml_tensor * t) override {
-        size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1];
+        size_t a = ggml_nbytes(t->src[2]) * n;
         size_t b = ggml_nbytes(t->src[1]) * m;
         size_t c  = ggml_nbytes(t);
         return a + b + c;
@@ -877,35 +898,41 @@ struct test_mul_mat_id : public test_case {
 
     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
             int n_mats = 2, int id = 0,
-            int64_t m = 32, int64_t n = 32, int64_t k = 32,
-            std::array<int64_t, 2> bs = {10, 10},
-            std::array<int64_t, 2> nr = {2, 2})
+            int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
         : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
-            m(m), n(n), k(k), bs(bs), nr(nr) {}
+            m(m), n(n), k(k), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         std::vector<ggml_tensor *> mats;
         for (int i = 0; i < n_mats; i++) {
-            ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m);
             mats.push_back(a);
         }
-        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
+        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+        if (v) {
+            ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
+        }
+        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
+        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b);
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
                 // ids
-                std::vector<int> data(n_mats);
-                for (int i = 0; i < n_mats; i++) {
-                    data[i] = i;
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i % n_mats;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
                 }
-                std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
-                ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
             } else {
                 init_tensor_uniform(t);
             }
@@ -1306,17 +1333,115 @@ struct test_leaky_relu : public test_case {
     }
 };
 
+// Mixtral MOE
+struct test_moe : public test_case {
+    const int n_experts;
+    const int n_experts_per_tok;
+    const int n_tokens;
+    const int n_embd;
+    const int n_ff;
+
+    std::string op_desc(ggml_tensor * t) override {
+        return "MOE";
+
+        GGML_UNUSED(t);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
+    }
+
+    test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
+        : n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
+
+        std::vector<ggml_tensor *> ffn_up_exp(n_experts);
+        std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
+        std::vector<ggml_tensor *> ffn_down_exp(n_experts);
+
+        for (int i = 0; i < n_experts; ++i) {
+            ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+            ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+            ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
+        }
+
+        ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+
+        ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
+        ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd));
+
+        // select experts
+        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
+
+        ggml_tensor * weights = ggml_get_rows(ctx,
+                ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
+
+        weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens);
+
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
+
+        weights = ggml_div(ctx, weights, weights_sum);
+
+        // compute expert outputs
+        ggml_tensor * moe_out = nullptr;
+
+        for (int i = 0; i < n_experts_per_tok; ++i) {
+            ggml_tensor * cur_expert;
+
+            ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
+
+            ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
+
+            cur_gate = ggml_silu(ctx, cur_gate);
+
+            cur_expert = ggml_mul(ctx, cur_up, cur_gate);
+
+            cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert);
+
+            cur_expert = ggml_mul(ctx, cur_expert,
+                    ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+
+            if (i == 0) {
+                moe_out = cur_expert;
+            } else {
+                moe_out = ggml_add(ctx, moe_out, cur_expert);
+            }
+        }
+
+        cur = moe_out;
+
+        return cur;
+    }
+};
+
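
(The test_moe graph above mirrors the Mixtral routing end to end: expert logits are soft-maxed with a 1/sqrt(n_embd) scale, the top n_experts_per_tok probabilities are gathered with the batched ggml_get_rows, renormalized so the selected weights sum to one, and each selected expert's FFN output is blended with its weight. A tiny scalar reference of that routing math, assuming 8 experts and top-2 selection — the helper and values are hypothetical, not part of the test:)

    #include <math.h>

    // returns the two selected expert indices and their normalized mixing weights
    static void moe_routing_reference(const float logits[8], float scale, int top2[2], float w[2]) {
        float p[8], sum = 0.0f;
        for (int i = 0; i < 8; ++i) { p[i] = expf(scale*logits[i]); sum += p[i]; }  // soft_max_ext
        for (int i = 0; i < 8; ++i) { p[i] /= sum; }
        top2[0] = 0; top2[1] = 1;
        for (int i = 1; i < 8; ++i) {                                               // top_k, k = 2
            if (p[i] > p[top2[0]]) { top2[1] = top2[0]; top2[0] = i; }
            else if (p[i] > p[top2[1]]) { top2[1] = i; }
        }
        const float s = p[top2[0]] + p[top2[1]];
        w[0] = p[top2[0]]/s;                                                        // ggml_div(weights,
        w[1] = p[top2[1]]/s;                                                        //          weights_sum)
    }
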
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
 
+    const ggml_type all_types[] = {
+        GGML_TYPE_F32, GGML_TYPE_F16,
+        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+        GGML_TYPE_Q8_0,
+        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+        GGML_TYPE_Q6_K
+    };
+
     // unary ops
     for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
         test_cases.emplace_back(new test_unary((ggml_unary_op) op));
     }
 
-    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
-        test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
+    test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
+    for (ggml_type type : all_types) {
+        for (int b : {1, 7}) {
+            for (bool v : {false, true}) {
+                test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
+            }
+        }
     }
 
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
@@ -1326,7 +1451,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
 
     test_cases.emplace_back(new test_dup());
-    test_cases.emplace_back(new test_cpy());
+
+    for (ggml_type type : all_types) {
+       test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1}));
+    }
+
     test_cases.emplace_back(new test_cont());
 
     auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
@@ -1336,6 +1465,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     };
 
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
@@ -1362,8 +1492,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
 
     test_cases.emplace_back(new test_scale());
 
@@ -1372,16 +1502,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
-    const ggml_type all_types[] = {
-        GGML_TYPE_F32, GGML_TYPE_F16,
-        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
-        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
-        GGML_TYPE_Q8_0,
-        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
-        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K
-    };
-
     for (ggml_type type_a : all_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             // FIXME: CPU crashes on f16xf16
@@ -1405,9 +1525,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     for (ggml_type type_a : all_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
-            for (int n_mats : {1, 2, 4}) {
+            for (int n_mats : {2, 4, 8}) {
                 for (int id = 0; id < n_mats; id++) {
-                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
+                    for (bool v : {false, true}) {
+                        test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
+                    }
                 }
             }
         }
@@ -1439,6 +1561,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_concat());
 
     for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
     }
 
@@ -1449,6 +1572,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());
 
+#if !defined(__SANITIZE_THREAD__)
+    // FIXME: these tests use too much memory with thread sanitizer
+    test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
+    //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
+#endif
+
     // run tests
     if (mode == MODE_TEST) {
         ggml_backend_t backend_cpu = ggml_backend_cpu_init();
@@ -1464,14 +1593,17 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         ggml_backend_free(backend_cpu);
 
         return n_ok == test_cases.size();
-    } else if (mode == MODE_PERF) {
+    }
+
+    if (mode == MODE_PERF) {
         for (auto & test : test_cases) {
             test->eval_perf(backend, op_name);
         }
         return true;
-    } else {
-        GGML_ASSERT(false);
     }
+
+    GGML_ASSERT(false);
+    return false;
 }
 
 static void usage(char ** argv) {
@@ -1544,11 +1676,12 @@ int main(int argc, char ** argv) {
     }
 
     printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
+
     if (n_ok != ggml_backend_reg_get_count()) {
         printf("\033[1;31mFAIL\033[0m\n");
         return 1;
-    } else {
-        printf("\033[1;32mOK\033[0m\n");
-        return 0;
     }
+
+    printf("\033[1;32mOK\033[0m\n");
+    return 0;
 }
index 7fe9154ddbb1663106a4845dcf8d515070873e13..81c20a89cb586b14cf5b264170f96e0d84f39d56 100644 (file)
@@ -1,4 +1,4 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #include "ggml.h"
 
 #include <cmath>
index 88fac0e23106bc4f11c4899da9c12479a29c352a..62d0190f9066c010d565e92c94442c628179525f 100644 (file)
@@ -117,7 +117,7 @@ static void usage(char * argv[]) {
     printf("  --size SIZE           set test size, divisible by 32 (L1_SIZE:%d)\n", L1_SIZE);
     printf("  -3                    use size as L1, L2, L3 sizes (L1:%d L2:%d L3:%d)\n", L1_SIZE, L2_SIZE, L3_SIZE);
     printf("  -4                    use size as L1, L2, L3, MEM sizes (L1:%d L2:%d L3:%d MEM:%d)\n", L1_SIZE, L2_SIZE, L3_SIZE, MEM_SIZE);
-    printf("  --op OP               set test opration as quantize_row_q_reference, quantize_row_q, dequantize_row_q,\n");
+    printf("  --op OP               set test operation as quantize_row_q_reference, quantize_row_q, dequantize_row_q,\n");
     printf("                        quantize_row_q_dot, vec_dot_q (all)\n");
     printf("  --type TYPE           set test type as");
     for (int i = 0; i < GGML_TYPE_COUNT; i++) {
@@ -202,7 +202,7 @@ int main(int argc, char * argv[]) {
             }
             int alignment = std::stoi(argv[i]);
             if (alignment < 0 || alignment > MAX_ALIGNMENT) {
-            fprintf(stderr, "error: aligment-offset must be less than %d\n", MAX_ALIGNMENT);
+            fprintf(stderr, "error: alignment-offset must be less than %d\n", MAX_ALIGNMENT);
                 invalid_param = true;
                 break;
             }