]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : mul_mat_id use the same tensor for all the experts (llama/6387)
authorslaren <redacted>
Wed, 3 Apr 2024 13:07:05 +0000 (15:07 +0200)
committerGeorgi Gerganov <redacted>
Sun, 7 Apr 2024 13:15:57 +0000 (16:15 +0300)
* ggml : update mul_mat_id to use the same tensor for all the experts

* update cuda

* minor

* update metal

* update test-backend-ops

* fix cuda

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <redacted>
* update convert.py

* update convert-hf-to-gguf.py

* update convert.py for mixtral hf models

* Update convert-hf-to-gguf.py

Co-authored-by: Georgi Gerganov <redacted>
* cuda : support non-pow-2 number of experts

* allow quantize to work for split and merged experts models in the same way

* cleanup + disable mmap automatically with split tensors models

* update imatrix

* test-backend-ops : test qwen argsort

* update grok model loading

* llama : add merged experts tensors to the grok tensor map

* minor

* gguf : bump version

* fix quantizing of merged experts

* convert-hf-to-gguf.py : update grok (untested)

* make linter happy

* cuda/argsort : use shared memory instead of pool memory

* convert : fix grok tensor names

* metal : add support for non-pow-2 argsort

* llama : more loader cleanup, better error checking

* cuda : fix warning

* llama : still use mmap for loading old models, but copy the data to a host buffer

* add review note

* llama : remove ffn tensor counting + add sanity check

ggml-ci

* convert : fix handling of n_experts == None

ggml-ci

* imatrix : fix ncall counters

* llama : produce error if imatrix size does not match

* quantize : terminate on errors + trace logs

ggml-ci

* metal : pad shared memory to 16 bytes

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml-cuda.cu
ggml-cuda/argsort.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h

index 5607386cee9b5be72b233961c1f2ba0f3d1c3b5c..ce28cb55d01b2f70d47dcefc88c624ec5ba5b598 100644 (file)
@@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t
 GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+    if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
         return;
     }
 
@@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }
 
-#if 0
-template<typename ... Srcs>
-static __global__ void k_compute_batched_ptrs_id(
-        const void ** ptrs_src, void ** ptrs_dst,
-        int ne12, int ne13,
-        int ne23,
-        int nb02, int nb03,
-        int nb12, int nb13,
-        int nb2, int nb3,
-        int r2, int r3,
-        ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
-        const half * src1_f16, half * dst_f16,
-        const int32_t * ids, const int id,
-        Srcs... src0s) {
-
-    int i = ids[id];
-
-    half * src0_f16;
-    const void * srcs_ar[] = { (const half *) src0s... };
-    if (src0_type == GGML_TYPE_F16) {
-        src0_f16 = (half *) srcs_ar[i];
-    } else {
-        src0_f16 = src0_as_f16;
-        if (threadIdx.x == 0 && threadIdx.y == 0) {
-            const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
-            to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
-        }
-    }
-
-    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
-
-    if (i13 >= ne13 || i12 >= ne12) {
-        return;
-    }
-
-    int i03 = i13 / r3;
-    int i02 = i12 / r2;
-
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02   + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)  dst_f16 + i12* nb2/2 + i13* nb3/2;
-}
-
-static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src00 = dst->src[2];
-
-    const int id = dst->op_params[0];
-
-    GGML_ASSERT(!ggml_is_transposed(src00));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-
-    GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
-    const int64_t ne01 = src00->ne[1];
-    const int64_t ne02 = src00->ne[2];
-    const int64_t ne03 = src00->ne[3];
-
-    //const int64_t nb01 = src00->nb[1];
-    const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
-    const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    //const int64_t nb11 = src1->nb[1];
-    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
-    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
-
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
-
-    ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
-
-    //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    //void * src0_ddq = src0_extra->data_device[g_main_device];
-    //half * src0_as_f16 = (half *) src0_ddq;
-
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
-
-    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-
-    // convert src1 to fp16
-    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-    GGML_ASSERT(to_fp16_cuda != nullptr);
-
-    size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
-    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
-
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
-    // use cublasGemmBatchedEx
-    const int ne23 = ne12*ne13;
-
-    const void ** ptrs_src = nullptr;
-          void ** ptrs_dst = nullptr;
-
-    size_t ptrs_src_s = 0;
-    size_t ptrs_dst_s = 0;
-
-    ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
-    ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
-
-    int64_t src0_ne = ggml_nelements(src00);
-    half * src0_as_f16 = nullptr;
-    size_t src0_as = 0;
-    if (src00->type != GGML_TYPE_F16) {
-        src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
-    }
-
-    static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
-    dim3 block_dims(ne13, ne12);
-    k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
-            ptrs_src, ptrs_dst,
-            ne12, ne13,
-            ne23,
-            ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
-            nb12, nb13,
-            dst->nb[2], dst->nb[3],
-            r2, r3,
-            src00->type, src0_as_f16, src0_ne,
-            src1_as_f16, dst_f16,
-            (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
-            dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    CUBLAS_CHECK(
-    cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-            ne01, ne11, ne10,
-            &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
-                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
-            &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
-            ne23,
-            CUBLAS_COMPUTE_16F,
-            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-    if (src0_as != 0) {
-        ggml_cuda_pool_free(src0_as_f16, src0_as);
-    }
-    if (ptrs_src_s != 0) {
-        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-    }
-    if (ptrs_dst_s != 0) {
-        ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
-    }
-
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
-}
-#endif
-
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-#if 0
-    ggml_cuda_mul_mat_id_cublas(dst);
-    // TODO: mmq/mmv support
-#endif
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
+
+    GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
 
     const size_t nb11 = src1->nb[1];
     const size_t nb1  =  dst->nb[1];
 
-    const struct ggml_tensor * ids = src0;
     const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = ((int32_t *) dst->op_params)[1];
+    const int32_t n_as = src0->ne[2];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
     CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));
 
+    ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
+    char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original  = (char *)  dst->data;
 
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = src0->nb[2];
+
     if (src1->ne[1] == 1) {
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
             GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
+            src0_row.data = src0_original + row_id*src0->nb[2];
             src1_row.data = src1_original + i01*src1->nb[1];
             dst_row.data  =  dst_original + i01*dst->nb[1];
 
-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         dst_row.data  =  dst_contiguous.get();
 
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
             int64_t num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
                 const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 continue;
             }
 
+            src0_row.data = src0_original + row_id*src0->nb[2];
+
             src1_row.ne[1] = num_src1_rows;
             dst_row.ne[1] = num_src1_rows;
 
@@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
             num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -2389,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
-        GGML_ASSERT(false);
+        CUDA_CHECK(err);
     }
 
     return true;
index 1333287e42e45a955fa796d2c763d00537520426..1641440617779e9da3de304e678f475bd569675a 100644 (file)
@@ -8,32 +8,41 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 }
 
 template<ggml_sort_order order>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;
 
-    if (col >= ncols) return;
+    if (col >= ncols_pad) {
+        return;
+    }
 
     const float * x_row = x + row * ncols;
-    int * dst_row = dst + row * ncols;
+    extern __shared__ int dst_row[];
 
     // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
     __syncthreads();
 
-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 }
@@ -41,18 +50,35 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
             __syncthreads();
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
 }
 
 static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
-    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+    const int ncols_pad = next_power_of_2(ncols);
 
-    const dim3 block_dims(ncols, 1, 1);
+    const dim3 block_dims(ncols_pad, 1, 1);
     const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
         GGML_ASSERT(false);
     }
index a08abbc2918028cc178d997a67946b8030295fd6..419d8b9e56878f7638c984f129225dc6f3474e1d 100644 (file)
@@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
                     {
                         //GGML_ASSERT(ne00 == ne10);
                         //GGML_ASSERT(ne03 == ne13);
-
-                        GGML_ASSERT(src0t == GGML_TYPE_I32);
-
-                        const int n_as = ((int32_t *) dst->op_params)[1];
-
-                        // TODO: make this more general
-                        GGML_ASSERT(n_as <= 8);
+                        const int n_as = src0->ne[2];
 
                         // max size of the src1ids array in the kernel shared buffer
                         GGML_ASSERT(ne11 <= 4096);
 
-                        const int64_t  ne20 = src2 ? src2->ne[0] : 0;
-                        const int64_t  ne21 = src2 ? src2->ne[1] : 0;
-                        const int64_t  ne22 = src2 ? src2->ne[2] : 0;
-                        const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+                        // src2 = ids
+                        const int64_t  ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+                        const int64_t  ne21 = src2->ne[1];
+                        const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
+                        const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
+
+                        const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
+                        const uint64_t nb21 = src2->nb[1];
+                        const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
+                        const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
 
-                        const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-                        const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-                        const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-                        const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+                        const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
 
-                        const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+                        GGML_ASSERT(src2t == GGML_TYPE_I32);
 
-                        GGML_ASSERT(!ggml_is_transposed(src2));
+                        GGML_ASSERT(!ggml_is_transposed(src0));
                         GGML_ASSERT(!ggml_is_transposed(src1));
 
                         GGML_ASSERT(src1t == GGML_TYPE_F32);
 
-                        const uint r2 = ne12/ne22;
-                        const uint r3 = ne13/ne23;
-
                         // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                         // to the matrix-vector kernel
                         int ne11_mm_min = n_as;
@@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
                         const int idx = ((int32_t *) dst->op_params)[0];
 
                         // batch size
-                        GGML_ASSERT(ne01 == ne11);
+                        GGML_ASSERT(ne21 == ne11); // ?
+                        GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
+                        const uint r2 = 1;
+                        const uint r3 = 1;
 
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1732,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         //       indirect matrix multiplication
                         // !!!
                         if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                            ne20 % 32 == 0 && ne20 >= 64 &&
+                            ne00 % 32 == 0 && ne00 >= 64 &&
                             ne11 > ne11_mm_min) {
 
                             // some Metal matrix data types require aligned pointers
@@ -1745,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                             id<MTLComputePipelineState> pipeline = nil;
 
-                            switch (src2->type) {
+                            switch (src0->type) {
                                 case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
                                 case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
                                 case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
@@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:3];
-                            [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
-                            [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:5];
-                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
-                            [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:7];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
-                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:9];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:10];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:11];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:12];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:15];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:16];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17];
-                            [encoder setBytes:&idx     length:sizeof(idx)  atIndex:18];
-                            // TODO: how to make this an array? read Metal docs
-                            for (int j = 0; j < 8; ++j) {
-                                // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                                struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                                size_t offs_src_cur = 0;
-                                id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                                [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
-                            }
+                            [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
+                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:4];
+                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:5];
+                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:9];
+                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:10];
+                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
+                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
+                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:13];
+                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:14];
+                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:15];
+                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:16];
+                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:17];
+                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:18];
+                            [encoder setBytes:&idx     length:sizeof(idx)  atIndex:19];
 
                             [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
                             int nth0 = 32;
                             int nth1 = 1;
@@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             id<MTLComputePipelineState> pipeline = nil;
 
                             // use custom matrix x vector kernel
-                            switch (src2t) {
+                            switch (src0t) {
                                 case GGML_TYPE_F32:
                                     {
                                         GGML_ASSERT(src1t == GGML_TYPE_F32);
@@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
                                     }
                             };
 
-                            if (ggml_is_quantized(src2t)) {
-                                GGML_ASSERT(ne20 >= nth0*nth1);
+                            if (ggml_is_quantized(src0t)) {
+                                GGML_ASSERT(ne00 >= nth0*nth1);
                             }
 
                             const int64_t _ne1 = 1; // kernels needs a reference in constant memory
@@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
-                            [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                            [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
-                            [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
-                            [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
-                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
-                            [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
-                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
-                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:20];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
-                            [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
-                            // TODO: how to make this an array? read Metal docs
-                            for (int j = 0; j < 8; ++j) {
-                                // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                                struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                                size_t offs_src_cur = 0;
-                                id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                                [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
-                            }
+                            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:20];
+                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:21];
+                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:22];
+                            [encoder setBytes:&idx  length:sizeof(idx)  atIndex:23];
 
-                            if (src2t == GGML_TYPE_Q4_0  || src2t == GGML_TYPE_Q4_1  || src2t == GGML_TYPE_Q5_0 ||
-                                src2t == GGML_TYPE_Q5_1  || src2t == GGML_TYPE_Q8_0  || src2t == GGML_TYPE_Q2_K ||
-                                src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
+                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
-                                const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+                            else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                                const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
-                                const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+                            else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                                const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
+                            else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                                 const int mem_size = 32*sizeof(float);
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            else if (src0t == GGML_TYPE_Q4_K) {
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_Q3_K) {
+                            else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                             }
-                            else if (src2t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            else if (src0t == GGML_TYPE_Q5_K) {
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            else if (src0t == GGML_TYPE_Q6_K) {
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             } else {
                                 const int64_t ny = (_ne1 + nrows - 1)/nrows;
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                         }
                     } break;
@@ -2432,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
+                        // bitonic sort requires the number of elements to be power of 2
+                        int64_t ne00_padded = 1;
+                        while (ne00_padded < ne00) {
+                            ne00_padded *= 2;
+                        }
+
+                        // Metal kernels require the buffer size to be multiple of 16 bytes
+                        // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                        const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
                         id<MTLComputePipelineState> pipeline = nil;
 
                         switch (order) {
@@ -2441,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
                         };
 
                         [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                        [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                        [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
+                        [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
                     } break;
                 case GGML_OP_LEAKY_RELU:
                     {
index 744b2a8b4ce42c3e5982aceb478617ad7e1258a5..9a29f57a38c6b7cf501e09e034c6f66631074cdf 100644 (file)
@@ -13,8 +13,8 @@ using namespace metal;
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 
 enum ggml_sort_order {
-    GGML_SORT_ASC,
-    GGML_SORT_DESC,
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
 };
 
 // general-purpose kernel for addition, multiplication and division of two tensors
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
 
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
-        device const float * x,
-        device     int32_t * dst,
-        constant   int64_t & ncols,
+        device const float  * x,
+        device     int32_t  * dst,
+        constant   int64_t  & ncols,
+        constant   int64_t  & ncols_pad,
+        threadgroup int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);
 
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
         device const float   * x,
         device       int32_t * dst,
         constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup int32_t  * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
     // bitonic sort
     int col = tpitg[0];
     int row = tgpig[1];
 
-    if (col >= ncols) return;
+    if (col >= ncols_pad) return;
 
-    device const float   * x_row   = x   + row * ncols;
-    device       int32_t * dst_row = dst + row * ncols;
+    device const float   * x_row   = x + row * ncols;
+    threadgroup int32_t  * dst_row = shared_values;
 
     // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 }
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
             threadgroup_barrier(mem_flags::mem_threadgroup);
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
 }
 
-template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
-template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
 
 kernel void kernel_leaky_relu_f32(
         device const float * src0,
@@ -5785,9 +5801,10 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const   uchar * ids,
+        device const   uchar * src0s,
         device const   uchar * src1,
         device         float * dst,
+        device const   uchar * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -5804,22 +5821,14 @@ kernel void kernel_mul_mm_id(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const   uchar * src00,
-        device const   uchar * src01,
-        device const   uchar * src02,
-        device const   uchar * src03,
-        device const   uchar * src04,
-        device const   uchar * src05,
-        device const   uchar * src06,
-        device const   uchar * src07,
         threadgroup    uchar * shared_memory [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
     // expert id
     const int32_t id = tgpig.z/(ne12*ne13);
+    device const uchar * src0 = src0s + id*nb02;
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
@@ -5834,7 +5843,7 @@ kernel void kernel_mul_mm_id(
     }
 
     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
-        src0s[id],
+        src0,
         src1,
         src1ids,
         dst,
@@ -5960,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_m
 //
 
 typedef void (mat_mm_id_t)(
-        device const   uchar * ids,
+        device const   uchar * src0s,
         device const   uchar * src1,
         device         float * dst,
+        device const   uchar * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -5979,14 +5989,6 @@ typedef void (mat_mm_id_t)(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const   uchar * src00,
-        device const   uchar * src01,
-        device const   uchar * src02,
-        device const   uchar * src03,
-        device const   uchar * src04,
-        device const   uchar * src05,
-        device const   uchar * src06,
-        device const   uchar * src07,
         threadgroup    uchar *,
         uint3, uint, uint);
 
@@ -6022,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel
 
 [[host_name("kernel_mul_mv_id_f32_f32")]]
 kernel void kernel_mul_mv_id_f32_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6045,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_f32_f32_impl(
-        src0[id],
+        src0,
         src1 + bid*nb11,
         dst  + bid*ne0,
         ne00,
@@ -6091,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
 
 [[host_name("kernel_mul_mv_id_f16_f32")]]
 kernel void kernel_mul_mv_id_f16_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6114,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_f16_f32_impl(
-        src0[id],
+        src0,
         src1 + bid*nb11,
         dst  + bid*ne0,
         ne00,
@@ -6160,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
 
 [[host_name("kernel_mul_mv_id_q8_0_f32")]]
 kernel void kernel_mul_mv_id_q8_0_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6183,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q8_0_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6223,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
 
 [[host_name("kernel_mul_mv_id_q4_0_f32")]]
 kernel void kernel_mul_mv_id_q4_0_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6246,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6286,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
 
 [[host_name("kernel_mul_mv_id_q4_1_f32")]]
 kernel void kernel_mul_mv_id_q4_1_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6309,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6349,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
 
 [[host_name("kernel_mul_mv_id_q5_0_f32")]]
 kernel void kernel_mul_mv_id_q5_0_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6372,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6412,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
 
 [[host_name("kernel_mul_mv_id_q5_1_f32")]]
 kernel void kernel_mul_mv_id_q5_1_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6435,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6475,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
 
 [[host_name("kernel_mul_mv_id_q2_K_f32")]]
 kernel void kernel_mul_mv_id_q2_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6498,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q2_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6538,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
 
 [[host_name("kernel_mul_mv_id_q3_K_f32")]]
 kernel void kernel_mul_mv_id_q3_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6561,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q3_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6601,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
 
 [[host_name("kernel_mul_mv_id_q4_K_f32")]]
 kernel void kernel_mul_mv_id_q4_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6624,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q4_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6664,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
 
 [[host_name("kernel_mul_mv_id_q5_K_f32")]]
 kernel void kernel_mul_mv_id_q5_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6687,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q5_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6727,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
 
 [[host_name("kernel_mul_mv_id_q6_K_f32")]]
 kernel void kernel_mul_mv_id_q6_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6750,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q6_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6790,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
 
 [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
 kernel void kernel_mul_mv_id_iq2_xxs_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6813,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_xxs_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6855,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
 
 [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
 kernel void kernel_mul_mv_id_iq2_xs_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6878,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_xs_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6920,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
 
 [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
 kernel void kernel_mul_mv_id_iq3_xxs_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6943,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq3_xxs_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6985,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
 
 [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
 kernel void kernel_mul_mv_id_iq3_s_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7008,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq3_s_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7050,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
 
 [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
 kernel void kernel_mul_mv_id_iq2_s_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7073,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_s_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7115,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
 
 [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
 kernel void kernel_mul_mv_id_iq1_s_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7138,28 +7005,19 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq1_s_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7178,9 +7036,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
 
 [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
 kernel void kernel_mul_mv_id_iq1_m_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7201,28 +7060,19 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq1_m_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7241,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
 
 [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
 kernel void kernel_mul_mv_id_iq4_nl_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7264,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup float    * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq4_nl_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7306,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
 
 [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
 kernel void kernel_mul_mv_id_iq4_xs_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7329,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup float    * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
 #if QK_K == 64
     kernel_mul_mv_iq4_nl_f32_impl(
 #else
     kernel_mul_mv_iq4_xs_f32_impl(
 #endif
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
diff --git a/ggml.c b/ggml.c
index 7471e792606c1524715d02d6f135c1fd4e846564..c9b0a6a0ef776af3a453d21c0575df97a8cc807a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec(
 
 // ggml_mul_mat_id
 
+// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
+//       this will allow computing all the used experts in a single matrix multiplication
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * const as[],
-        int                   n_as,
+        struct ggml_tensor  * as,
         struct ggml_tensor  * ids,
         int                   id,
         struct ggml_tensor  * b) {
 
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
-    GGML_ASSERT(ids->ne[1] == b->ne[1]);
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+    GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
     GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
-    GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
-    GGML_ASSERT(id >= 0 && id < ids->ne[0]);
+    GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+    GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
 
     bool is_node = false;
 
-    if (as[0]->grad || b->grad) {
+    if (as->grad || b->grad) {
         is_node = true;
     }
 
-    const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     ggml_set_op_params_i32(result, 0, id);
-    ggml_set_op_params_i32(result, 1, n_as);
 
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = ids;
+    result->src[0] = as;
     result->src[1] = b;
-
-    for (int i = 0; i < n_as; i++) {
-        struct ggml_tensor * a = as[i];
-        GGML_ASSERT(ggml_are_same_shape(as[0], a));
-        GGML_ASSERT(ggml_can_mul_mat(a, b));
-        GGML_ASSERT(!ggml_is_transposed(a));
-        result->src[i + 2] = a;
-    }
+    result->src[2] = ids;
 
     return result;
 }
@@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
-
-    const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
+    const struct ggml_tensor * ids = dst->src[2];
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    // broadcast is not supported with mmid
+    assert(ne12 == 1);
+    assert(ne13 == 1);
 
     // row groups
     const int id   = ggml_get_op_params_i32(dst, 0);
-    const int n_as = ggml_get_op_params_i32(dst, 1);
+    const int n_as = src0->ne[2];
 
     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :
@@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
+        size_t src0_offset = cur_a*src0->nb[2];
 
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        assert(ne12 % ne02 == 0);
-        assert(ne13 % ne03 == 0);
-
         // block-tiling attempt
         const int64_t blck_0 = 16;
         const int64_t blck_1 = 16;
@@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id(
                     const int64_t  i11 = MMID_MATRIX_ROW(cur_a, _i11);
 
                     // broadcast src0 into src1
-                    const int64_t i03 = i13/r3;
-                    const int64_t i02 = i12/r2;
+                    //const int64_t i03 = i13/r3;
+                    //const int64_t i02 = i12/r2;
 
                     const int64_t i1 = i11;
                     const int64_t i2 = i12;
                     const int64_t i3 = i13;
 
-                    const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
+                    const char * src0_row = (const char *) src0->data + src0_offset;
 
                     // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                     //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
             case GGML_OP_MUL_MAT_ID:
                 {
                     cur = 0;
-                    const struct ggml_tensor * src0 = node->src[2];
+                    const struct ggml_tensor * src0 = node->src[0];
                     const struct ggml_tensor * src1 = node->src[1];
                     const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                     if (src1->type != vec_dot_type) {
                         cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
                     }
-                    const int n_as = ggml_get_op_params_i32(node, 1);
+                    const int n_as = src0->ne[2];
                     cur += GGML_PAD(cur, sizeof(int64_t));       // align
                     cur += n_as * sizeof(int64_t);               // matrix_row_counts
                     cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
diff --git a/ggml.h b/ggml.h
index 5d4a4ceb65c7e106bf2008ba82089fe8d4d7a83f..5cef45c0ba4ad13fefe8450f874918e6e74ed354 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1164,8 +1164,7 @@ extern "C" {
     //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * const as[],
-            int                   n_as,
+            struct ggml_tensor  * as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);