]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : group all experts in a single ggml_mul_mat_id (llama/6505)
authorslaren <redacted>
Thu, 18 Apr 2024 13:18:48 +0000 (15:18 +0200)
committerGeorgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
* ggml : group all experts in a single ggml_mul_mat_id
cuda : improve mmid row copy

* cuda : fix bin bcast with non-cont src0

* test-backend-ops : only run all mul mat tests for base types

* llama : disable moe offloading with SYCL

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml-cuda.cu
ggml-cuda/binbcast.cu
ggml-cuda/convert.cu
ggml-metal.m
ggml-metal.metal
ggml-sycl.cpp
ggml.c
ggml.h

index a3bbb920e68de22b292a7fd2d9a0909b6ddd54d7..07534370c34ff30f45fde5d58837ef9d5f3385be 100644 (file)
@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
 
-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -1960,20 +1960,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }
 
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+                                                 const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+                                                 int64_t ne11, int64_t ne10,
+                                                 size_t nb11, size_t nb12) {
+    int32_t iid1 = blockIdx.x;
+    int32_t id = blockIdx.y;
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+                                                  const mmid_row_mapping * __restrict__ row_mapping,
+                                                  int64_t ne0,
+                                                  size_t nb1, size_t nb2) {
+    int32_t i = blockIdx.x;
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids  = dst->src[2];
 
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
 
-    const size_t nb11 = src1->nb[1];
-    const size_t nb1  =  dst->nb[1];
-
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
@@ -1982,7 +2035,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+    ggml_tensor dst_row  = *dst;
 
     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
@@ -1990,19 +2043,39 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
+    src0_row.nb[3] = nb02;
 
-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
 
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
 
-            src0_row.data = src0_original + row_id*src0->nb[2];
-            src1_row.data = src1_original + i01*src1->nb[1];
-            dst_row.data  =  dst_original + i01*dst->nb[1];
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2011,54 +2084,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src1_row.data = src1_contiguous.get();
         dst_row.data  =  dst_contiguous.get();
 
-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+                    if (row_id_i != i02) {
+                        continue;
+                    }
+
+                    num_src1_rows++;
+                }
             }
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            src0_row.data = src0_original + row_id*src0->nb[2];
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
 
-            src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
+            {
+                dim3 block_dims(std::min((unsigned int)ne10, 768u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(),
+                        dev_cur_src1_row.get(), dev_row_mapping.get(),
+                        ids_dev, i02, ids->nb[1], ids->nb[0],
+                        ne11, ne10,
+                        nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            src0_row.data = src0_original + i02*nb02;
 
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;
 
+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
             ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
-            num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i != row_id) {
-                    continue;
-                }
-
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+            {
+                dim3 block_dims(std::min((unsigned int)ne0, 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        dst_original, dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        ne0,
+                        nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
             }
         }
     }
@@ -2491,7 +2579,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
 
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
 
     GGML_UNUSED(backend);
 }
index 959eaed95c136a500c002e17468cee7f268bc68a..19b08b74fb0af54c9519b550b18c470c308fac45 100644 (file)
@@ -22,6 +22,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13) {
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
@@ -36,9 +37,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;
 
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i_src0;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
 
     const src0_t * src0_row = src0 + i_src0;
     const src1_t * src1_row = src1 + i_src1;
@@ -55,6 +56,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13) {
 
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -72,9 +74,9 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;
 
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i_src0;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
 
     const src0_t * src0_row = src0 + i_src0;
     const src1_t * src1_row = src1 + i_src1;
@@ -101,10 +103,14 @@ struct bin_bcast_cuda {
         int nr[4] = { nr0, nr1, nr2, nr3 };
 
         // collapse dimensions until first broadcast dimension
-        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne[] = {ne0, ne1, ne2, ne3};
+        int64_t cne0[] = {ne00, ne01, ne02, ne03};
         int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+
+        size_t cnb[] = {nb0, nb1, nb2, nb3};
+        size_t cnb0[] = {nb00, nb01, nb02, nb03};
         size_t cnb1[] = {nb10, nb11, nb12, nb13};
+
         auto collapse = [](int64_t cne[]) {
             cne[0] *= cne[1];
             cne[1] = cne[2];
@@ -118,32 +124,47 @@ struct bin_bcast_cuda {
             cnb[3] *= cne[3];
         };
 
-        for (int i = 0; i < 4; i++) {
-            if (nr[i] != 1) {
-                break;
-            }
-            if (i > 0) {
-                collapse_nb(cnb0, cne0);
-                collapse_nb(cnb1, cne1);
-                collapse(cne0);
-                collapse(cne1);
+        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+            for (int i = 0; i < 4; i++) {
+                if (nr[i] != 1) {
+                    break;
+                }
+                if (i > 0) {
+                    collapse_nb(cnb, cne);
+                    collapse_nb(cnb0, cne0);
+                    collapse_nb(cnb1, cne1);
+                    collapse(cne);
+                    collapse(cne0);
+                    collapse(cne1);
+                }
             }
         }
+
         {
-            int64_t ne0 = cne0[0];
-            int64_t ne1 = cne0[1];
-            int64_t ne2 = cne0[2];
-            int64_t ne3 = cne0[3];
+            int64_t ne0 = cne[0];
+            int64_t ne1 = cne[1];
+            int64_t ne2 = cne[2];
+            int64_t ne3 = cne[3];
+
+            //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+            //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+            //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+            //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
 
             int64_t ne10 = cne1[0];
             int64_t ne11 = cne1[1];
             int64_t ne12 = cne1[2];
             int64_t ne13 = cne1[3];
 
-            size_t nb0 = cnb0[0];
-            size_t nb1 = cnb0[1];
-            size_t nb2 = cnb0[2];
-            size_t nb3 = cnb0[3];
+            size_t nb0 = cnb[0];
+            size_t nb1 = cnb[1];
+            size_t nb2 = cnb[2];
+            size_t nb3 = cnb[3];
+
+            size_t nb00 = cnb0[0];
+            size_t nb01 = cnb0[1];
+            size_t nb02 = cnb0[2];
+            size_t nb03 = cnb0[3];
 
             size_t nb10 = cnb1[0];
             size_t nb11 = cnb1[1];
@@ -160,7 +181,28 @@ struct bin_bcast_cuda {
             size_t s12 = nb12 / sizeof(src1_t);
             size_t s13 = nb13 / sizeof(src1_t);
 
+            size_t s00 = nb00 / sizeof(src0_t);
+            size_t s01 = nb01 / sizeof(src0_t);
+            size_t s02 = nb02 / sizeof(src0_t);
+            size_t s03 = nb03 / sizeof(src0_t);
+
+            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
             GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s00 == 1);
             GGML_ASSERT(s10 == 1);
 
             const int block_size = 128;
@@ -179,13 +221,14 @@ struct bin_bcast_cuda {
             );
 
             if (block_nums.z > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
                 int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
                 k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
                     src0_dd, src1_dd, dst_dd,
                     ne0, ne1, ne2, ne3,
                     ne10, ne11, ne12, ne13,
                     /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
                     /* s10, */ s11, s12, s13);
             } else {
                 k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
@@ -193,6 +236,7 @@ struct bin_bcast_cuda {
                     ne0, ne1, ne2, ne3,
                     ne10, ne11, ne12, ne13,
                     /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
                     /* s10, */ s11, s12, s13);
             }
         }
index ed4fa2748972b3fc3064bdfd8afc32510b5b01d1..b15e3578267b3837354d37352d01c700c31be12a 100644 (file)
@@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
         vals[ix] = x0[ix];
     }
 
+    __syncthreads();
+
 #pragma unroll
     for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
         if (need_check && i0 + iy + 2*threadIdx.x >= k) {
index 0ec47febbd20c4909444485e7bf835970bc41515..fdba0de85bcdbb5d686ce97362901f55135725f0 100644 (file)
@@ -1747,15 +1747,10 @@ static enum ggml_status ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_MUL_MAT_ID:
                     {
-                        //GGML_ASSERT(ne00 == ne10);
-                        //GGML_ASSERT(ne03 == ne13);
                         const int n_as = src0->ne[2];
 
-                        // max size of the src1ids array in the kernel shared buffer
-                        GGML_ASSERT(ne11 <= 4096);
-
                         // src2 = ids
-                        const int64_t  ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+                        const int64_t  ne20 = src2->ne[0];
                         const int64_t  ne21 = src2->ne[1];
                         const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
                         const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1776,15 +1771,13 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                         // to the matrix-vector kernel
-                        int ne11_mm_min = n_as;
-
-                        const int idx = ((int32_t *) dst->op_params)[0];
+                        // ne20 = n_used_experts
+                        // ne21 = n_rows
+                        const int dst_rows = ne20*ne21;
+                        const int dst_rows_min = n_as;
 
-                        // batch size
-                        GGML_ASSERT(ne21 == ne11); // ?
-                        GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
-                        const uint r2 = 1;
-                        const uint r3 = 1;
+                        // max size of the rowids array in the kernel shared buffer
+                        GGML_ASSERT(dst_rows <= 2048);
 
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1794,7 +1787,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         // !!!
                         if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                             ne00 % 32 == 0 && ne00 >= 64 &&
-                            ne11 > ne11_mm_min) {
+                            dst_rows > dst_rows_min) {
 
                             // some Metal matrix data types require aligned pointers
                             // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1836,26 +1829,26 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
                             [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
-                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:4];
-                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:5];
-                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:6];
-                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
-                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:9];
-                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:10];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:13];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:14];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:15];
-                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:16];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:17];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:18];
-                            [encoder setBytes:&idx     length:sizeof(idx)  atIndex:19];
-
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
-
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
+                            [encoder setBytes:&ne21    length:sizeof(ne21) atIndex:5];
+                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
+                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:7];
+                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:8];
+                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:9];
+                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:10];
+                            [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:11];
+                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:12];
+                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:13];
+                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:14];
+                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:15];
+                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:16];
+                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:17];
+                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:18];
+                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:19];
+
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
                             int nth0 = 32;
                             int nth1 = 1;
@@ -2008,72 +2001,72 @@ static enum ggml_status ggml_metal_graph_compute(
                                 GGML_ASSERT(ne00 >= nth0*nth1);
                             }
 
-                            const int64_t _ne1 = 1; // kernels needs a reference in constant memory
-
                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                             [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
-                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
-                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
-                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:20];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:21];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:22];
-                            [encoder setBytes:&idx  length:sizeof(idx)  atIndex:23];
+                            [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                            [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:20];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:21];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:22];
+
+                            const int64_t _ne1 = 1;
+                            const int tgz = dst_rows;
 
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
                                 src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
                                 const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
                                 const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                                 const int mem_size = 32*sizeof(float);
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                             }
                             else if (src0t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             } else {
-                                const int64_t ny = (_ne1 + nrows - 1)/nrows;
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                         }
                     } break;
index d7ae37206f4fd94cd916f55b0f72577aefe108aa..7f37c17d668a82a614e3af4fb213c12b4a9d41b6 100644 (file)
@@ -899,16 +899,16 @@ void mul_vec_q_n_f32_impl(
         device const void  * src0,
         device const float * src1,
         device       float * dst,
-                   constant int64_t &   ne00,
-                   constant int64_t &   ne01,
-                   constant int64_t &   ne02,
-                   constant int64_t &   ne10,
-                   constant int64_t &   ne12,
-                   constant int64_t &   ne0,
-                   constant int64_t &   ne1,
-                   constant uint &      r2,
-                   constant uint &      r3,
-                   threadgroup int8_t * shared_values,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
                    uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
 
@@ -1073,19 +1073,19 @@ void kernel_mul_mv_q8_0_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
     const int nr  = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw  = N_SIMDWIDTH;
@@ -1172,24 +1172,24 @@ void kernel_mul_mv_f32_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                  uint64_t   nb00,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                   int64_t   ne10,
+                   int64_t   ne11,
+                   int64_t   ne12,
+                  uint64_t   nb10,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                     uint    r2,
+                     uint    r3,
+                     uint3   tgpig,
+                     uint    tiisg) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
@@ -1442,24 +1442,24 @@ void kernel_mul_mv_f16_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                  uint64_t   nb00,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                   int64_t   ne10,
+                   int64_t   ne11,
+                   int64_t   ne12,
+                  uint64_t   nb10,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+                   uint3     tgpig,
+                   uint      tiisg) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F16_F32;
@@ -2744,19 +2744,19 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -2924,19 +2924,19 @@ void kernel_mul_mv_q3_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
 
@@ -3190,19 +3190,19 @@ void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
@@ -3429,19 +3429,19 @@ void kernel_mul_mv_q5_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
 
@@ -3636,19 +3636,19 @@ void kernel_mul_mv_q6_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -3773,19 +3773,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -3902,19 +3902,19 @@ void kernel_mul_mv_iq2_xs_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -4041,19 +4041,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -4173,19 +4173,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -4305,19 +4305,19 @@ void kernel_mul_mv_iq2_s_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -4438,19 +4438,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_value,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -4528,19 +4528,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_value,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -4637,19 +4637,19 @@ void kernel_mul_mv_iq4_nl_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values_i8,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
 
     threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
     const int nb = ne00/QK4_NL;
@@ -4732,19 +4732,20 @@ void kernel_mul_mv_iq4_xs_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t  * shared_values_i8 [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values_i8,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg) {
+
     threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ -5686,25 +5687,25 @@ void kernel_mul_mm_impl(device const  uchar * src0,
     }
 }
 
-// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 void kernel_mul_mm_id_impl(
         device const  uchar * src0,
         device const  uchar * src1,
-        threadgroup   short * src1ids,
+        threadgroup ushort2 * rowids,
         device        float * dst,
         constant    int64_t & ne00,
         constant    int64_t & ne02,
         constant   uint64_t & nb01,
         constant   uint64_t & nb02,
+        constant    int64_t & ne11,
         constant    int64_t & ne12,
         constant   uint64_t & nb10,
         constant   uint64_t & nb11,
         constant   uint64_t & nb12,
         constant    int64_t & ne0,
                     int64_t   ne1,
-        constant       uint & r2,
-        constant       uint & r3,
+                    int64_t   ne0ne1,
         threadgroup   uchar * shared_memory,
         uint3                 tgpig[[threadgroup_position_in_grid]],
         uint                  tiitg[[thread_index_in_threadgroup]],
@@ -5715,7 +5716,6 @@ void kernel_mul_mm_id_impl(
 
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
-    const uint im = tgpig.z;
 
     if (r1 * BLOCK_SIZE_N >= ne1) return;
 
@@ -5733,19 +5733,16 @@ void kernel_mul_mm_id_impl(
     for (int i = 0; i < 8; i++){
         c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
-
     short il = (tiitg % THREAD_PER_ROW);
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
     ushort offset1 = il/nl;
 
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
     device const float   * y = (device const float   *)(src1
-        + nb12 * im
-        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+        + nb12 * id[1]
+        + nb11 * (id[0] % ne11)
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
@@ -5774,11 +5771,11 @@ void kernel_mul_mm_id_impl(
 
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             for (int i = 0; i < 4; i++) {
-                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }
             simdgroup_barrier(mem_flags::mem_none);
             for (int i = 0; i < 2; i++) {
-                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
@@ -5800,11 +5797,13 @@ void kernel_mul_mm_id_impl(
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0);
         if (sgitg == 0) {
-            for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+                int joff =  jid[0] * ne0 + jid[1] * ne0ne1;
+                for (int i = 0; i < n_rows; i++) {
+                    *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
                 }
             }
         }
@@ -5859,11 +5858,14 @@ kernel void kernel_mul_mm_id(
         device const   uchar * src1,
         device         float * dst,
         device const   uchar * ids,
+        constant     int64_t & nei0,
+        constant     int64_t & nei1,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant    uint64_t & nb01,
         constant    uint64_t & nb02,
+        constant     int64_t & ne11,
         constant     int64_t & ne12,
         constant     int64_t & ne13,
         constant    uint64_t & nb10,
@@ -5872,47 +5874,52 @@ kernel void kernel_mul_mm_id(
         constant     int64_t & ne0,
         constant     int64_t & ne1,
         constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
         threadgroup    uchar * shared_memory [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    // expert id
-    const int32_t id = tgpig.z/(ne12*ne13);
-    device const uchar * src0 = src0s + id*nb02;
+    const int32_t i02 = tgpig.z;
+    tgpig.z = 0;
 
-    tgpig.z = tgpig.z%(ne12*ne13);
+    device const uchar * src0 = src0s + i02*nb02;
 
-    // row indices of src1 for expert id
-    threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
+    // row indices
+    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
 
+    // TODO: parallelize this loop
     int64_t _ne1 = 0;
-    for (int64_t i1 = 0; i1 < ne1; i1++) {
-        if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
-            src1ids[_ne1++] = i1;
+    for (ushort ii1 = 0; ii1 < nei1; ii1++) {
+        for (ushort ii0 = 0; ii0 < nei0; ii0++) {
+            int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+            if (id == i02) {
+                //if (tiitg == 0) {
+                    rowids[_ne1] = ushort2(ii0, ii1);
+                //}
+                _ne1++;
+            }
         }
     }
 
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
         src0,
         src1,
-        src1ids,
+        rowids,
         dst,
         ne00,
         ne02,
         nb01,
         nb02,
+        ne11,
         ne12,
         nb10,
         nb11,
         nb12,
         ne0,
         _ne1,
-        r2,
-        r3,
+        ne0*ne1,
         shared_memory,
         tgpig,
         tiitg,
@@ -5973,24 +5980,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_r
 // matrix-matrix multiplication
 //
 
-typedef void (mat_mm_t)(
-        device const  uchar * src0,
-        device const  uchar * src1,
-        device        float * dst,
-        constant    int64_t & ne00,
-        constant    int64_t & ne02,
-        constant   uint64_t & nb01,
-        constant   uint64_t & nb02,
-        constant    int64_t & ne12,
-        constant   uint64_t & nb10,
-        constant   uint64_t & nb11,
-        constant   uint64_t & nb12,
-        constant    int64_t & ne0,
-        constant    int64_t & ne1,
-        constant       uint & r2,
-        constant       uint & r3,
-        threadgroup   uchar *,
-        uint3, uint, uint);
+typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
 
 template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<float4x4,      1,     dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half4x4,       1,     dequantize_f16>;
@@ -6022,29 +6012,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_m
 // indirect matrix-matrix multiplication
 //
 
-typedef void (mat_mm_id_t)(
-        device const   uchar * src0s,
-        device const   uchar * src1,
-        device         float * dst,
-        device const   uchar * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup    uchar *,
-        uint3, uint, uint);
+typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
 
 template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<float4x4,      1,     dequantize_f32>;
 template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<half4x4,       1,     dequantize_f16>;
@@ -6080,71 +6048,71 @@ typedef void (kernel_mul_mv_impl_t)(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]);
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                  uint64_t   nb00,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                   int64_t   ne10,
+                   int64_t   ne11,
+                   int64_t   ne12,
+                  uint64_t   nb10,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+                   uint3     tgpig,
+                   uint      tiisg);
 
 typedef void (kernel_mul_mv2_impl_t)(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]);
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+        threadgroup int8_t * shared_values,
+                   uint3     tgpig,
+                   uint      tiisg,
+                   uint      sgitg);
 
 template<kernel_mul_mv_impl_t impl_fn>
 void mmv_fn(
         device const    char * src0,
         device const    char * src1,
         device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+                     int64_t   ne00,
+                     int64_t   ne01,
+                     int64_t   ne02,
+                    uint64_t   nb00,
+                    uint64_t   nb01,
+                    uint64_t   nb02,
+                     int64_t   ne10,
+                     int64_t   ne11,
+                     int64_t   ne12,
+                     int64_t   ne13,
+                    uint64_t   nb10,
+                    uint64_t   nb11,
+                    uint64_t   nb12,
+                     int64_t   ne0,
+                     int64_t   ne1,
+                    uint64_t   nb1,
+                        uint   r2,
+                        uint   r3,
+        threadgroup int8_t   * shared_values,
+        uint3                  tgpig,
+        uint                   tiitg,
+        uint                   tiisg,
+        uint                   sgitg) {
     impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
 }
 
@@ -6153,59 +6121,33 @@ void mmv_fn(
         device const    char * src0,
         device const    char * src1,
         device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+                     int64_t   ne00,
+                     int64_t   ne01,
+                     int64_t   ne02,
+                    uint64_t   nb00,
+                    uint64_t   nb01,
+                    uint64_t   nb02,
+                     int64_t   ne10,
+                     int64_t   ne11,
+                     int64_t   ne12,
+                     int64_t   ne13,
+                    uint64_t   nb10,
+                    uint64_t   nb11,
+                    uint64_t   nb12,
+                     int64_t   ne0,
+                     int64_t   ne1,
+                    uint64_t   nb1,
+                        uint   r2,
+                        uint   r3,
+        threadgroup int8_t   * shared_values,
+        uint3                  tgpig,
+        uint                   tiitg,
+        uint                   tiisg,
+        uint                   sgitg) {
     impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
 }
 
-typedef void (mul_mv_impl_fn_t)(
-        device const    char * src0,
-        device const    char * src1,
-        device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]);
+typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
 
 template<mul_mv_impl_fn_t impl_fn>
 kernel void kernel_mul_mv_id(
@@ -6213,6 +6155,8 @@ kernel void kernel_mul_mv_id(
         device const    char * src1,
         device         float * dst,
         device const    char * ids,
+        constant     int64_t & nei0,
+        constant     int64_t & nei1,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6230,43 +6174,50 @@ kernel void kernel_mul_mv_id(
         constant     int64_t & ne0,
         constant     int64_t & ne1,
         constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
+    const int iid1 = tgpig.z/nei0;
+    const int idx = tgpig.z%nei0;
+
+    tgpig.z = 0;
 
-    tgpig.z = tgpig.z%(ne12*ne13);
+    const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
 
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
+    const int64_t i11 = idx % ne11;
+    const int64_t i12 = iid1;
+
+    const int64_t i1 = idx;
+    const int64_t i2 = i12;
+
+    device const char * src0_cur = src0s + i02*nb02;
+    device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
+    device      float * dst_cur  = dst + i1*ne0 + i2*ne1*ne0;
 
     impl_fn(
-        src0,
-        src1 + bid*nb11,
-        dst  + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        nb00,
-        nb01,
-        nb02,
-        ne10,
-        ne11,
-        ne12,
-        ne13,
-        nb10,
-        nb11,
-        nb12,
-        ne0,
-        ne1,
-        nb1,
-        r2,
-        r3,
+        /* src0 */ src0_cur,
+        /* src1 */ src1_cur,
+        /* dst  */ dst_cur,
+        /* ne00 */ ne00,
+        /* ne01 */ ne01,
+        /* ne02 */ 1,//ne02,
+        /* nb00 */ nb00,
+        /* nb01 */ nb01,
+        /* nb02 */ nb02,
+        /* ne10 */ ne10,
+        /* ne11 */ 1,//ne11,
+        /* ne12 */ 1,//ne12,
+        /* ne13 */ 1,//ne13,
+        /* nb10 */ nb10,
+        /* nb11 */ nb11,
+        /* nb12 */ nb12,
+        /* ne0  */ ne0,
+        /* ne1  */ 1,//ne1,
+        /* nb1  */ nb1,
+        /* r2   */ 1,
+        /* r3   */ 1,
         shared_values,
         tgpig,
         tiitg,
@@ -6274,36 +6225,7 @@ kernel void kernel_mul_mv_id(
         sgitg);
 }
 
-typedef void (kernel_mul_mv_id_t)(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]);
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
 
 template [[host_name("kernel_mul_mv_id_f32_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_f16_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
index f5bb7da86988cd1f7f31dbf9ee1f334da3303406..a9b310243f04f7d4357eedf1c6e0833bd3fd345a 100644 (file)
@@ -17752,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
 
 GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
     GGML_UNUSED(backend);
 }
 
diff --git a/ggml.c b/ggml.c
index 707a1fe4140c300c559be6dca02fd42d87a5327d..a745104c655cf9b8ab9439a310b6bc74a4077ff2 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4594,21 +4594,32 @@ void ggml_mul_mat_set_prec(
 
 // ggml_mul_mat_id
 
-// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
-//       this will allow computing all the used experts in a single matrix multiplication
+/*
+    c = ggml_mul_mat_id(ctx, as, b, ids);
+
+    as  -> [cols, rows, n_expert]
+    ids -> [n_experts_used, n_tokens] (i32)
+    b   -> [cols, n_expert_used, n_tokens]
+    c   -> [cols, n_expert_used, n_tokens]
+
+    in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+
+    c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
+*/
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
         struct ggml_tensor  * as,
-        struct ggml_tensor  * ids,
-        int                   id,
-        struct ggml_tensor  * b) {
-
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * ids) {
+    GGML_ASSERT(!ggml_is_transposed(as));
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
+    GGML_ASSERT(b->ne[3] == 1); // b is 3d
     GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
-    GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
-    GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
-    GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+    GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
     GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
+    GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
 
     bool is_node = false;
 
@@ -4616,11 +4627,9 @@ struct ggml_tensor * ggml_mul_mat_id(
         is_node = true;
     }
 
-    const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    ggml_set_op_params_i32(result, 0, id);
-
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = as;
@@ -11071,11 +11080,6 @@ static void ggml_compute_forward_mul_mat_id(
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@@ -11086,22 +11090,21 @@ static void ggml_compute_forward_mul_mat_id(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // broadcast is not supported with mmid
-    assert(ne12 == 1);
-    assert(ne13 == 1);
-
     // row groups
-    const int id   = ggml_get_op_params_i32(dst, 0);
-    const int n_as = src0->ne[2];
+    const int n_ids = ids->ne[0]; // n_expert_used
+    const int n_as  = ne02;       // n_expert
 
     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :
             (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
 
-    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
-    int64_t * matrix_rows       = matrix_row_counts + n_as;     // [n_as][ne11]
+    struct mmid_row_mapping {
+        int32_t i1;
+        int32_t i2;
+    };
 
-    #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
+    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
 
    if (params->type == GGML_TASK_TYPE_INIT) {
         if (ith != 0) {
@@ -11127,13 +11130,18 @@ static void ggml_compute_forward_mul_mat_id(
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
         // group rows by src0 matrix
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+            for (int id = 0; id < n_ids; ++id) {
+                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                assert(i02 >= 0 && i02 < n_as);
 
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
-            MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
-            matrix_row_counts[row_id] += 1;
+                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+                matrix_row_counts[i02] += 1;
+            }
         }
 
         return;
@@ -11151,15 +11159,13 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        size_t src0_offset = cur_a*src0->nb[2];
+        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
 
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-        const int64_t nr0 = ne01;           // src0 rows
-        const int64_t nr1 = cne1*ne12*ne13; // src1 rows
-
-        //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
+        const int64_t nr0 = ne01; // src0 rows
+        const int64_t nr1 = cne1; // src1 rows
 
         // distribute the thread work across the inner or outer loop based on which one is larger
 
@@ -11178,13 +11184,11 @@ static void ggml_compute_forward_mul_mat_id(
         const int64_t ir110 = dr1*ith1;
         const int64_t ir111 = MIN(ir110 + dr1, nr1);
 
-        //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
-
         // threads with no work simply yield (not sure if it helps)
-        if (ir010 >= ir011 || ir110 >= ir111) {
-            sched_yield();
-            continue;
-        }
+        //if (ir010 >= ir011 || ir110 >= ir111) {
+        //    sched_yield();
+        //    continue;
+        //}
 
         // block-tiling attempt
         const int64_t blck_0 = 16;
@@ -11196,20 +11200,16 @@ static void ggml_compute_forward_mul_mat_id(
         for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
             for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
                 for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t  i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
-                    const int64_t  i12 = (ir1 - i13*ne12*cne1)/cne1;
-                    const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
-                    const int64_t  i11 = MMID_MATRIX_ROW(cur_a, _i11);
+                    const int64_t _i12 = ir1; // logical row index for this expert
 
-                    // broadcast src0 into src1
-                    //const int64_t i03 = i13/r3;
-                    //const int64_t i02 = i12/r2;
+                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                    const int id       = row_mapping.i1; // selected expert index
 
-                    const int64_t i1 = i11;
-                    const int64_t i2 = i12;
-                    const int64_t i3 = i13;
+                    const int64_t  i11 = id % ne11;
+                    const int64_t  i12 = row_mapping.i2; // row index in src1
 
-                    const char * src0_row = (const char *) src0->data + src0_offset;
+                    const int64_t  i1 = id;  // selected expert index
+                    const int64_t  i2 = i12; // row
 
                     // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                     //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -11217,25 +11217,26 @@ static void ggml_compute_forward_mul_mat_id(
                     // TODO: this is a bit of a hack, we should probably have a better way to handle this
                     const char * src1_col = (const char *) wdata +
                         (src1_cont || src1->type != vec_dot_type
-                        ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
-                        : (i11*nb11 + i12*nb12 + i13*nb13));
+                        ? (i11      + i12*ne11)*row_size
+                        : (i11*nb11 + i12*nb12));
 
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
 
                     //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                     //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
                     //}
 
                     for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
+                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
                     }
+
                     memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
                 }
             }
         }
     }
 
-    #undef MMID_MATRIX_ROW
+#undef MMID_MATRIX_ROW
 }
 
 // ggml_compute_forward_out_prod
@@ -18583,7 +18584,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     const int n_as = src0->ne[2];
                     cur += GGML_PAD(cur, sizeof(int64_t));       // align
                     cur += n_as * sizeof(int64_t);               // matrix_row_counts
-                    cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
+                    cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
                 } break;
             case GGML_OP_OUT_PROD:
                 {
@@ -21009,12 +21010,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
             ok = ok && cur != NULL;
 
-            ggml_set_name(cur, ctx->infos[i].name.data);
-
             if (!ok) {
                 break;
             }
 
+            ggml_set_name(cur, ctx->infos[i].name.data);
+
             // point the data member to the appropriate location in the binary blob using the tensor infos
             if (!params.no_alloc) {
               //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
diff --git a/ggml.h b/ggml.h
index 1a776ca83e4b0c601c7e1d5802aef1cd435026b7..6d2c8c566ec205e73c8fbf54e0d0994e9615c78e 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1170,13 +1170,11 @@ extern "C" {
             enum ggml_prec       prec);
 
     // indirect matrix multiplication
-    //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
             struct ggml_tensor  * as,
-            struct ggml_tensor  * ids,
-            int                   id,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
 
     // A: m columns, n rows,
     // B: p columns, n rows,