ggml : mul_mat_id use the same tensor for all the experts (#6387)
author slaren <redacted>
Wed, 3 Apr 2024 13:07:05 +0000 (15:07 +0200)
committer GitHub <redacted>
Wed, 3 Apr 2024 13:07:05 +0000 (16:07 +0300)
* ggml : update mul_mat_id to use the same tensor for all the experts

* update cuda

* minor

* update metal

* update test-backend-ops

* fix cuda

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <redacted>
* update convert.py

* update convert-hf-to-gguf.py

* update convert.py for mixtral hf models

* Update convert-hf-to-gguf.py

Co-authored-by: Georgi Gerganov <redacted>
* cuda : support non-pow-2 number of experts

* allow quantize to work for split and merged experts models in the same way

* cleanup + disable mmap automatically with split tensors models

* update imatrix

* test-backend-ops : test qwen argsort

* update grok model loading

* llama : add merged experts tensors to the grok tensor map

* minor

* gguf : bump version

* fix quantizing of merged experts

* convert-hf-to-gguf.py : update grok (untested)

* make linter happy

* cuda/argsort : use shared memory instead of pool memory

* convert : fix grok tensor names

* metal : add support for non-pow-2 argsort

* llama : more loader cleanup, better error checking

* cuda : fix warning

* llama : still use mmap for loading old models, but copy the data to a host buffer

* add review note

* llama : remove ffn tensor counting + add sanity check

ggml-ci

* convert : fix handling of n_experts == None

ggml-ci

* imatrix : fix ncall counters

* llama : produce error if imatrix size does not match

* quantize : terminate on errors + trace logs

ggml-ci

* metal : pad shared memory to 16 bytes

---------

Co-authored-by: Georgi Gerganov <redacted>
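
The gist of the change: instead of passing each expert's 2D weight matrix as a separate source tensor to mul_mat_id, all experts for a given projection are merged into one 3D tensor whose first dimension is the expert index. A minimal numpy sketch of the layout change (toy shapes, not real model dimensions):

    import numpy as np

    n_experts, n_ff, n_embd = 8, 6, 4
    # before: one 2D weight tensor per expert
    experts_old = [np.random.rand(n_ff, n_embd).astype(np.float32) for _ in range(n_experts)]
    # after: a single 3D tensor; expert e is the slice along axis 0
    experts_new = np.stack(experts_old, axis=0)  # shape (n_experts, n_ff, n_embd)
    assert np.array_equal(experts_new[3], experts_old[3])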
15 files changed:
convert-hf-to-gguf.py
convert.py
examples/imatrix/imatrix.cpp
examples/quantize/quantize.cpp
ggml-cuda.cu
ggml-cuda/argsort.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
gguf-py/pyproject.toml
llama.cpp
tests/test-backend-ops.cpp

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 18337839ab72df4600db6158476dc1fc2c8d6e4f..afa034a86900496fe566a219aecf75dd8e9426b5 100755 (executable)
@@ -1216,6 +1216,8 @@ class LlamaModel(Model):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
         n_head = self.hparams.get("num_attention_heads")
         n_kv_head = self.hparams.get("num_key_value_heads")
+        n_experts = self.hparams.get("num_local_experts")
+        experts = dict()
         for name, data_torch in self.get_tensors():
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
@@ -1236,6 +1238,49 @@ class LlamaModel(Model):
 
             data = data.squeeze()
 
+            # process the experts separately
+            if name.find("block_sparse_moe.experts") != -1:
+                experts[name] = data
+                if len(experts) >= n_experts:
+                    # merge the experts into a single 3d tensor
+                    for bid in range(block_count):
+                        for wid in range(1, 4):
+                            full = True
+                            for xid in range(n_experts):
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
+                                if ename not in experts:
+                                    full = False
+                                    break
+                            if not full:
+                                continue
+
+                            datas = []
+                            for xid in range(n_experts):
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
+                                datas.append(experts[ename])
+                                del experts[ename]
+
+                            data = np.stack(datas, axis=0)
+                            data_dtype = data.dtype
+
+                            if self.ftype == 0 and data_dtype == np.float16:
+                                data = data.astype(np.float32)
+
+                            if self.ftype == 1 and data_dtype == np.float32:
+                                data = data.astype(np.float16)
+
+                            merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight"
+
+                            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
+                            if new_name is None:
+                                print(f"Can not map tensor {name!r}")
+                                sys.exit()
+
+                            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+
+                            self.gguf_writer.add_tensor(new_name, data)
+                continue
+
             # map tensor names
             new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
             if new_name is None:
@@ -1249,7 +1294,7 @@ class LlamaModel(Model):
             if self.ftype == 0 and data_dtype == np.float16:
                 data = data.astype(np.float32)
 
-            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            # 1d tensors need to be converted to float32
             if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                 data = data.astype(np.float32)
 
@@ -1261,6 +1306,9 @@ class LlamaModel(Model):
 
             self.gguf_writer.add_tensor(new_name, data)
 
+        if len(experts) > 0:
+            raise ValueError(f"Unprocessed experts: {experts.keys()}")
+
 
 @Model.register("GrokForCausalLM")
 class GrokModel(Model):
@@ -1276,6 +1324,92 @@ class GrokModel(Model):
         super().set_gguf_parameters()
         self.gguf_writer.add_name("Grok")
 
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        n_experts = self.hparams.get("num_local_experts")
+        experts = dict()
+        for name, data_torch in self.get_tensors():
+            # we don't need these
+            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # process the experts separately
+            if name.find(".moe.") != -1:
+                experts[name] = data
+                if len(experts) >= n_experts:
+                    # merge the experts into a single 3d tensor
+                    for bid in range(block_count):
+                        for wid in ["linear", "linear_1", "linear_v"]:
+                            full = True
+                            for xid in range(n_experts):
+                                ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
+                                if ename not in experts:
+                                    full = False
+                                    break
+                            if not full:
+                                continue
+
+                            datas = []
+                            for xid in range(n_experts):
+                                ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
+                                datas.append(experts[ename])
+                                del experts[ename]
+
+                            data = np.stack(datas, axis=0)
+                            data_dtype = data.dtype
+
+                            if self.ftype == 0 and data_dtype == np.float16:
+                                data = data.astype(np.float32)
+
+                            if self.ftype == 1 and data_dtype == np.float32:
+                                data = data.astype(np.float16)
+
+                            merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
+
+                            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
+                            if new_name is None:
+                                print(f"Can not map tensor {name!r}")
+                                sys.exit()
+
+                            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
+
+                            self.gguf_writer.add_tensor(new_name, data)
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
 
 @Model.register("MiniCPMForCausalLM")
 class MiniCPMModel(Model):
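
The converter change above streams tensors one by one, so per-expert weights are buffered in a dict and a merged 3D tensor is emitted only once every expert for a given (layer, projection) pair has arrived. A condensed sketch of that buffering strategy (key format taken from the Mixtral naming above):

    import numpy as np

    def try_merge(experts: dict, n_experts: int, bid: int, wid: int):
        names = [f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
                 for xid in range(n_experts)]
        if not all(n in experts for n in names):
            return None                    # some experts are still missing
        datas = [experts.pop(n) for n in names]
        return np.stack(datas, axis=0)     # (n_experts, rows, cols)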
diff --git a/convert.py b/convert.py
index d3a9ccaf21e616ec770b55a2d85706a26b72a017..244eb75822fd895ac04aa595f9f610223701d040 100755 (executable)
@@ -828,6 +828,15 @@ def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
     return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
 
 
+def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor:
+    def load() -> Tensor:
+        tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors]
+        return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors]))
+    s = lazy_tensors[0].shape.copy()
+    s.insert(0, len(lazy_tensors))
+    return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors))
+
+
 # Functionality that simulates `torch.load` but where individual tensors are
 # only loaded into memory on demand, not all at once.
 # PyTorch can't do this natively as of time of writing:
@@ -1246,6 +1255,22 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) ->
 
     tmp = model
 
+    # merge experts into one tensor
+    if params.n_experts and params.n_experts > 0:
+        for i_l in range(params.n_layer):
+            for w in range(1, 4):
+                experts = []
+                for e in range(params.n_experts):
+                    if f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" in model:
+                        experts.append(model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"])
+                        del tmp[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"]
+                    elif f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" in model:
+                        experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"])
+                        del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]
+                    else:
+                        raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight")
+                tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts)
+
     # HF models permute or pack some of the tensors, so we need to undo that
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" in model:
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index 12d34462b78ec023618156651bc433189718605a..d8cb0a6420456547229acf8aaca0398ecd1b4282 100644 (file)
@@ -98,35 +98,38 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
 
     const float * data = is_host ? (const float *) src1->data : m_src1_data.data();
 
+    // this has been adapted to the new format of storing merged experts in a single 3d tensor
+    // ref: https://github.com/ggerganov/llama.cpp/pull/6387
     if (t->op == GGML_OP_MUL_MAT_ID) {
         const int idx  = ((int32_t *) t->op_params)[0];
-        const int n_as = ((int32_t *) t->op_params)[1];
+        const ggml_tensor * ids = t->src[2];
+        const int n_as = src0->ne[2];
 
-        // the top-k selected expert ids are stored in the src0 tensor
-        // for simplicity, always copy src0 to host, because it is small
-        // take into account that src0 is not contiguous!
-        GGML_ASSERT(src0->ne[1] == src1->ne[1]);
-        GGML_ASSERT(n_as*ggml_nrows(src0)*sizeof(int) == GGML_PAD(ggml_nbytes(src0), n_as*sizeof(int)));
-        m_ids.resize(ggml_nbytes(src0)/sizeof(int));
-        ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0));
+        // the top-k selected expert ids are stored in the ids tensor
+        // for simplicity, always copy ids to host, because it is small
+        // take into account that ids is not contiguous!
+        GGML_ASSERT(ids->ne[1] == src1->ne[1]);
+        GGML_ASSERT(n_as*ggml_nrows(ids)*sizeof(int) == GGML_PAD(ggml_nbytes(ids), n_as*sizeof(int)));
+        m_ids.resize(ggml_nbytes(ids)/sizeof(int));
+        ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));
+
+        auto & e = m_stats[wname];
+
+        ++e.ncall;
+        // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
+        //       using the following line, we can correct for that if needed by replacing the line above with:
+        //if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
 
         // loop over all possible experts, regardless if they are used or not in the batch
-        // this is necessary to guarantee equal number of "ncall" for each tensor
         for (int ex = 0; ex < n_as; ++ex) {
-            src0 = t->src[2 + ex];
-            wname = filter_tensor_name(src0->name);
-            auto& e = m_stats[wname];
+            size_t e_start = ex*src1->ne[0];
             if (e.values.empty()) {
-                e.values.resize(src1->ne[0], 0);
+                e.values.resize(src1->ne[0]*n_as, 0);
             }
-            else if (e.values.size() != (size_t)src1->ne[0]) {
-                fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
+            else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
+                fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
                 exit(1); //GGML_ASSERT(false);
             }
-            // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
-            //       using the following line, we can correct for that if needed
-            //if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
-            ++e.ncall;
             if (m_params.verbosity > 1) {
                 printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
             }
@@ -136,7 +139,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
                 if (excur != ex) continue;
                 const float * x = data + row * src1->ne[0];
                 for (int j = 0; j < (int)src1->ne[0]; ++j) {
-                    e.values[j] += x[j]*x[j];
+                    e.values[e_start + j] += x[j]*x[j];
                 }
             }
             if (e.ncall > m_last_call) {
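
With the merged format, imatrix keeps a single statistics vector of size n_as*ne10 per MoE tensor instead of one vector per expert tensor; expert ex owns the slice starting at e_start = ex*ne10. A numpy sketch of the flattened accumulation (toy sizes):

    import numpy as np

    n_as, n_cols = 8, 4                    # number of experts, activation width
    values = np.zeros(n_as * n_cols)       # one flat array for all experts
    ex = 3                                 # expert selected for this row
    x = np.random.rand(n_cols)             # activation row
    e_start = ex * n_cols
    values[e_start:e_start + n_cols] += x * x  # accumulate sums of squares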
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 80c493f1f7175556c789af271876760107f668d9..64cb6db19d0040ea6171d0c1039fc0731927370a 100644 (file)
@@ -116,13 +116,13 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<st
     std::ifstream in(imatrix_file.c_str(), std::ios::binary);
     if (!in) {
         printf("%s: failed to open %s\n",__func__, imatrix_file.c_str());
-        return;
+        exit(1);
     }
     int n_entries;
     in.read((char *)&n_entries, sizeof(n_entries));
     if (in.fail() || n_entries < 1) {
         printf("%s: no data in file %s\n", __func__, imatrix_file.c_str());
-        return;
+        exit(1);
     }
     for (int i = 0; i < n_entries; ++i) {
         int len; in.read((char *)&len, sizeof(len));
@@ -130,11 +130,11 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<st
         in.read((char *)name_as_vec.data(), len);
         if (in.fail()) {
             printf("%s: failed reading name for entry %d from %s\n", __func__, i+1, imatrix_file.c_str());
-            return;
+            exit(1);
         }
         name_as_vec[len] = 0;
         std::string name{name_as_vec.data()};
-        auto & e = imatrix_data[std::move(name)];
+        auto & e = imatrix_data[name];
         int ncall;
         in.read((char *)&ncall, sizeof(ncall));
         int nval;
@@ -142,18 +142,22 @@ static void load_imatrix(const std::string & imatrix_file, std::unordered_map<st
         if (in.fail() || nval < 1) {
             printf("%s: failed reading number of values for entry %d\n", __func__, i);
             imatrix_data = {};
-            return;
+            exit(1);
         }
         e.resize(nval);
         in.read((char *)e.data(), nval*sizeof(float));
         if (in.fail()) {
             printf("%s: failed reading data for entry %d\n", __func__, i);
             imatrix_data = {};
-            return;
+            exit(1);
         }
         if (ncall > 0) {
             for (auto& v : e) v /= ncall;
         }
+
+        if (getenv("LLAMA_TRACE")) {
+            printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), ncall, name.c_str());
+        }
     }
     printf("%s: loaded %d importance matrix entries from %s\n", __func__, int(imatrix_data.size()), imatrix_file.c_str());
 }
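
For reference, the reader above implies this on-disk imatrix layout: an int32 entry count, then per entry an int32 name length, the raw name bytes, an int32 ncall, an int32 nval, and nval float32 values. A standalone parser sketch under those assumptions (little-endian byte order assumed):

    import struct

    def read_imatrix(path: str) -> dict:
        data = {}
        with open(path, "rb") as f:
            (n_entries,) = struct.unpack("<i", f.read(4))
            for _ in range(n_entries):
                (name_len,) = struct.unpack("<i", f.read(4))
                name = f.read(name_len).decode("utf-8")
                ncall, nval = struct.unpack("<ii", f.read(8))
                vals = list(struct.unpack(f"<{nval}f", f.read(4 * nval)))
                if ncall > 0:
                    vals = [v / ncall for v in vals]  # average over calls
                data[name] = vals
        return data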
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index be8e33a56c40f45dac8807b419925379c72f3c0f..f51b2042df3599963a6e76b0e02b4b78531ec7c8 100644 (file)
@@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t
 GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+    if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
         return;
     }
 
@@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }
 
-#if 0
-template<typename ... Srcs>
-static __global__ void k_compute_batched_ptrs_id(
-        const void ** ptrs_src, void ** ptrs_dst,
-        int ne12, int ne13,
-        int ne23,
-        int nb02, int nb03,
-        int nb12, int nb13,
-        int nb2, int nb3,
-        int r2, int r3,
-        ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
-        const half * src1_f16, half * dst_f16,
-        const int32_t * ids, const int id,
-        Srcs... src0s) {
-
-    int i = ids[id];
-
-    half * src0_f16;
-    const void * srcs_ar[] = { (const half *) src0s... };
-    if (src0_type == GGML_TYPE_F16) {
-        src0_f16 = (half *) srcs_ar[i];
-    } else {
-        src0_f16 = src0_as_f16;
-        if (threadIdx.x == 0 && threadIdx.y == 0) {
-            const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
-            to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
-        }
-    }
-
-    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
-
-    if (i13 >= ne13 || i12 >= ne12) {
-        return;
-    }
-
-    int i03 = i13 / r3;
-    int i02 = i12 / r2;
-
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02   + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)  dst_f16 + i12* nb2/2 + i13* nb3/2;
-}
-
-static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src00 = dst->src[2];
-
-    const int id = dst->op_params[0];
-
-    GGML_ASSERT(!ggml_is_transposed(src00));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-
-    GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
-    const int64_t ne01 = src00->ne[1];
-    const int64_t ne02 = src00->ne[2];
-    const int64_t ne03 = src00->ne[3];
-
-    //const int64_t nb01 = src00->nb[1];
-    const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
-    const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    //const int64_t nb11 = src1->nb[1];
-    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
-    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
-
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
-
-    ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
-
-    //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    //void * src0_ddq = src0_extra->data_device[g_main_device];
-    //half * src0_as_f16 = (half *) src0_ddq;
-
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
-
-    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-
-    // convert src1 to fp16
-    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-    GGML_ASSERT(to_fp16_cuda != nullptr);
-
-    size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
-    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
-
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
-    // use cublasGemmBatchedEx
-    const int ne23 = ne12*ne13;
-
-    const void ** ptrs_src = nullptr;
-          void ** ptrs_dst = nullptr;
-
-    size_t ptrs_src_s = 0;
-    size_t ptrs_dst_s = 0;
-
-    ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
-    ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
-
-    int64_t src0_ne = ggml_nelements(src00);
-    half * src0_as_f16 = nullptr;
-    size_t src0_as = 0;
-    if (src00->type != GGML_TYPE_F16) {
-        src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
-    }
-
-    static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
-    dim3 block_dims(ne13, ne12);
-    k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
-            ptrs_src, ptrs_dst,
-            ne12, ne13,
-            ne23,
-            ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
-            nb12, nb13,
-            dst->nb[2], dst->nb[3],
-            r2, r3,
-            src00->type, src0_as_f16, src0_ne,
-            src1_as_f16, dst_f16,
-            (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
-            dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    CUBLAS_CHECK(
-    cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-            ne01, ne11, ne10,
-            &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
-                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
-            &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
-            ne23,
-            CUBLAS_COMPUTE_16F,
-            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-    if (src0_as != 0) {
-        ggml_cuda_pool_free(src0_as_f16, src0_as);
-    }
-    if (ptrs_src_s != 0) {
-        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-    }
-    if (ptrs_dst_s != 0) {
-        ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
-    }
-
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
-}
-#endif
-
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-#if 0
-    ggml_cuda_mul_mat_id_cublas(dst);
-    // TODO: mmq/mmv support
-#endif
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
+
+    GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
 
     const size_t nb11 = src1->nb[1];
     const size_t nb1  =  dst->nb[1];
 
-    const struct ggml_tensor * ids = src0;
     const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = ((int32_t *) dst->op_params)[1];
+    const int32_t n_as = src0->ne[2];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
     CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));
 
+    ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
+    char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original  = (char *)  dst->data;
 
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = src0->nb[2];
+
     if (src1->ne[1] == 1) {
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
             GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
+            src0_row.data = src0_original + row_id*src0->nb[2];
             src1_row.data = src1_original + i01*src1->nb[1];
             dst_row.data  =  dst_original + i01*dst->nb[1];
 
-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         dst_row.data  =  dst_contiguous.get();
 
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
             int64_t num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
                 const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 continue;
             }
 
+            src0_row.data = src0_original + row_id*src0->nb[2];
+
             src1_row.ne[1] = num_src1_rows;
             dst_row.ne[1] = num_src1_rows;
 
@@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
             num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -2389,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
-        GGML_ASSERT(false);
+        CUDA_CHECK(err);
     }
 
     return true;
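
Note how the CUDA path no longer dereferences dst->src[row_id + 2]: it builds a 2D view into the merged tensor by offsetting the data pointer by row_id*src0->nb[2]. In numpy terms that pointer arithmetic is just a slice along the expert axis (illustrative shapes):

    import numpy as np

    src0 = np.arange(8 * 4 * 6, dtype=np.float32).reshape(8, 4, 6)  # (n_as, ne1, ne0)
    row_id = 5
    # CUDA: src0_row.data = src0_original + row_id*src0->nb[2]
    expert = src0[row_id]                  # 2D view of one expert, no copy
    assert np.shares_memory(expert, src0)
    assert expert.shape == (4, 6)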
diff --git a/ggml-cuda/argsort.cu b/ggml-cuda/argsort.cu
index 1333287e42e45a955fa796d2c763d00537520426..1641440617779e9da3de304e678f475bd569675a 100644 (file)
@@ -8,32 +8,41 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 }
 
 template<ggml_sort_order order>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;
 
-    if (col >= ncols) return;
+    if (col >= ncols_pad) {
+        return;
+    }
 
     const float * x_row = x + row * ncols;
-    int * dst_row = dst + row * ncols;
+    extern __shared__ int dst_row[];
 
     // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
     __syncthreads();
 
-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 }
@@ -41,18 +50,35 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
             __syncthreads();
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
 }
 
 static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
-    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+    const int ncols_pad = next_power_of_2(ncols);
 
-    const dim3 block_dims(ncols, 1, 1);
+    const dim3 block_dims(ncols_pad, 1, 1);
     const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
         GGML_ASSERT(false);
     }
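
To support a non-power-of-2 number of columns, the kernel now sorts a padded index range in shared memory, treating out-of-range indices as larger than any real key so they sink to the end, and writes back only the first ncols results. A Python stand-in for that padding logic (the real kernel runs a bitonic network, not list.sort):

    def next_power_of_2(x: int) -> int:
        n = 1
        while n < x:
            n *= 2
        return n

    def argsort_padded(xs):
        ncols = len(xs)
        ncols_pad = next_power_of_2(ncols)
        idx = list(range(ncols_pad))       # indices >= ncols are padding
        # padding entries compare as +inf, so they end up past the real data
        idx.sort(key=lambda i: xs[i] if i < ncols else float("inf"))
        return idx[:ncols]                 # drop the padding

    print(argsort_padded([3.0, 1.0, 2.0]))  # [1, 2, 0]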
diff --git a/ggml-metal.m b/ggml-metal.m
index a08abbc2918028cc178d997a67946b8030295fd6..419d8b9e56878f7638c984f129225dc6f3474e1d 100644 (file)
@@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
                     {
                         //GGML_ASSERT(ne00 == ne10);
                         //GGML_ASSERT(ne03 == ne13);
-
-                        GGML_ASSERT(src0t == GGML_TYPE_I32);
-
-                        const int n_as = ((int32_t *) dst->op_params)[1];
-
-                        // TODO: make this more general
-                        GGML_ASSERT(n_as <= 8);
+                        const int n_as = src0->ne[2];
 
                         // max size of the src1ids array in the kernel shared buffer
                         GGML_ASSERT(ne11 <= 4096);
 
-                        const int64_t  ne20 = src2 ? src2->ne[0] : 0;
-                        const int64_t  ne21 = src2 ? src2->ne[1] : 0;
-                        const int64_t  ne22 = src2 ? src2->ne[2] : 0;
-                        const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+                        // src2 = ids
+                        const int64_t  ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+                        const int64_t  ne21 = src2->ne[1];
+                        const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
+                        const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
+
+                        const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
+                        const uint64_t nb21 = src2->nb[1];
+                        const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
+                        const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
 
-                        const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-                        const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-                        const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-                        const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+                        const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
 
-                        const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+                        GGML_ASSERT(src2t == GGML_TYPE_I32);
 
-                        GGML_ASSERT(!ggml_is_transposed(src2));
+                        GGML_ASSERT(!ggml_is_transposed(src0));
                         GGML_ASSERT(!ggml_is_transposed(src1));
 
                         GGML_ASSERT(src1t == GGML_TYPE_F32);
 
-                        const uint r2 = ne12/ne22;
-                        const uint r3 = ne13/ne23;
-
                         // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                         // to the matrix-vector kernel
                         int ne11_mm_min = n_as;
@@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
                         const int idx = ((int32_t *) dst->op_params)[0];
 
                         // batch size
-                        GGML_ASSERT(ne01 == ne11);
+                        GGML_ASSERT(ne21 == ne11); // ?
+                        GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
+                        const uint r2 = 1;
+                        const uint r3 = 1;
 
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1732,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         //       indirect matrix multiplication
                         // !!!
                         if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                            ne20 % 32 == 0 && ne20 >= 64 &&
+                            ne00 % 32 == 0 && ne00 >= 64 &&
                             ne11 > ne11_mm_min) {
 
                             // some Metal matrix data types require aligned pointers
@@ -1745,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                             id<MTLComputePipelineState> pipeline = nil;
 
-                            switch (src2->type) {
+                            switch (src0->type) {
                                 case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
                                 case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
                                 case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
@@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:3];
-                            [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
-                            [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:5];
-                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
-                            [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:7];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
-                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:9];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:10];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:11];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:12];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:15];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:16];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17];
-                            [encoder setBytes:&idx     length:sizeof(idx)  atIndex:18];
-                            // TODO: how to make this an array? read Metal docs
-                            for (int j = 0; j < 8; ++j) {
-                                // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                                struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                                size_t offs_src_cur = 0;
-                                id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                                [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
-                            }
+                            [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
+                            [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:4];
+                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:5];
+                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:9];
+                            [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:10];
+                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
+                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
+                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:13];
+                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:14];
+                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:15];
+                            [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:16];
+                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:17];
+                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:18];
+                            [encoder setBytes:&idx     length:sizeof(idx)  atIndex:19];
 
                             [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
 
-                            [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
                             int nth0 = 32;
                             int nth1 = 1;
@@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             id<MTLComputePipelineState> pipeline = nil;
 
                             // use custom matrix x vector kernel
-                            switch (src2t) {
+                            switch (src0t) {
                                 case GGML_TYPE_F32:
                                     {
                                         GGML_ASSERT(src1t == GGML_TYPE_F32);
@@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
                                     }
                             };
 
-                            if (ggml_is_quantized(src2t)) {
-                                GGML_ASSERT(ne20 >= nth0*nth1);
+                            if (ggml_is_quantized(src0t)) {
+                                GGML_ASSERT(ne00 >= nth0*nth1);
                             }
 
                             const int64_t _ne1 = 1; // kernels needs a reference in constant memory
@@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
-                            [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                            [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
-                            [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
-                            [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
-                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
-                            [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
-                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
-                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:20];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
-                            [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
-                            // TODO: how to make this an array? read Metal docs
-                            for (int j = 0; j < 8; ++j) {
-                                // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                                struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                                size_t offs_src_cur = 0;
-                                id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                                [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
-                            }
+                            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                            [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                            [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:20];
+                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:21];
+                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:22];
+                            [encoder setBytes:&idx  length:sizeof(idx)  atIndex:23];
 
-                            if (src2t == GGML_TYPE_Q4_0  || src2t == GGML_TYPE_Q4_1  || src2t == GGML_TYPE_Q5_0 ||
-                                src2t == GGML_TYPE_Q5_1  || src2t == GGML_TYPE_Q8_0  || src2t == GGML_TYPE_Q2_K ||
-                                src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
+                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
-                                const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+                            else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                                const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
-                                const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+                            else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                                const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
+                            else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                                 const int mem_size = 32*sizeof(float);
                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            else if (src0t == GGML_TYPE_Q4_K) {
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_Q3_K) {
+                            else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                             }
-                            else if (src2t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            else if (src0t == GGML_TYPE_Q5_K) {
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src2t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            else if (src0t == GGML_TYPE_Q6_K) {
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             } else {
                                 const int64_t ny = (_ne1 + nrows - 1)/nrows;
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                         }
                     } break;
@@ -2432,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
+                        // bitonic sort requires the number of elements to be power of 2
+                        int64_t ne00_padded = 1;
+                        while (ne00_padded < ne00) {
+                            ne00_padded *= 2;
+                        }
+
+                        // Metal kernels require the buffer size to be multiple of 16 bytes
+                        // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                        const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
                         id<MTLComputePipelineState> pipeline = nil;
 
                         switch (order) {
@@ -2441,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
                         };
 
                         [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                        [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                        [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
+                        [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
                     } break;
                 case GGML_OP_LEAKY_RELU:
                     {
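
One Metal detail from the argsort change above: setThreadgroupMemoryLength requires a multiple of 16 bytes, hence GGML_PAD(ne00_padded*sizeof(int32_t), 16). The macro rounds up to the next multiple of n; a Python equivalent:

    def ggml_pad(x: int, n: int) -> int:
        return (x + n - 1) // n * n

    print(ggml_pad(10 * 4, 16))  # 48: 40 bytes rounded up to a 16-byte multiple
    print(ggml_pad(16 * 4, 16))  # 64: already a multiple of 16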
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 744b2a8b4ce42c3e5982aceb478617ad7e1258a5..9a29f57a38c6b7cf501e09e034c6f66631074cdf 100644 (file)
@@ -13,8 +13,8 @@ using namespace metal;
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 
 enum ggml_sort_order {
-    GGML_SORT_ASC,
-    GGML_SORT_DESC,
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
 };
 
 // general-purpose kernel for addition, multiplication and division of two tensors
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
 
 // bitonic sort implementation following the CUDA kernels as a reference
 typedef void (argsort_t)(
-        device const float * x,
-        device     int32_t * dst,
-        constant   int64_t & ncols,
+        device const float  * x,
+        device     int32_t  * dst,
+        constant   int64_t  & ncols,
+        constant   int64_t  & ncols_pad,
+        threadgroup int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);
 
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
         device const float   * x,
         device       int32_t * dst,
         constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup int32_t  * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
     // bitonic sort
     int col = tpitg[0];
     int row = tgpig[1];
 
-    if (col >= ncols) return;
+    if (col >= ncols_pad) return;
 
-    device const float   * x_row   = x   + row * ncols;
-    device       int32_t * dst_row = dst + row * ncols;
+    device const float   * x_row   = x + row * ncols;
+    threadgroup int32_t  * dst_row = shared_values;
 
     // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 }
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
             threadgroup_barrier(mem_flags::mem_threadgroup);
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
 }
 
-template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
-template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
 
 kernel void kernel_leaky_relu_f32(
         device const float * src0,
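The extra >= ncols checks added to the swap conditions above make the padding slots behave like sentinel keys: in ascending order an index past ncols compares as larger than any real element, so all pad indices sink to the tail of the sequence and only the first ncols entries are written back to dst. A hedged CPU reference of the same ordering rule (an illustrative helper, not part of the tree):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // reference model of the padded argsort: pad indices (>= ncols) always sort last,
    // mirroring the comparison in kernel_argsort_f32_i32 for GGML_SORT_ORDER_ASC
    std::vector<int32_t> argsort_padded_asc(const float * x, int64_t ncols) {
        int64_t ncols_pad = 1;
        while (ncols_pad < ncols) ncols_pad *= 2;

        std::vector<int32_t> idx(ncols_pad);
        for (int64_t i = 0; i < ncols_pad; ++i) idx[i] = (int32_t) i;

        std::sort(idx.begin(), idx.end(), [&](int32_t a, int32_t b) {
            if (a >= ncols) return false; // pad never precedes anything
            if (b >= ncols) return true;  // real elements precede pad
            return x[a] < x[b];
        });

        idx.resize(ncols); // drop the padding, as the kernel does on writeback
        return idx;
    }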
@@ -5785,9 +5801,10 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const   uchar * ids,
+        device const   uchar * src0s,
         device const   uchar * src1,
         device         float * dst,
+        device const   uchar * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -5804,22 +5821,14 @@ kernel void kernel_mul_mm_id(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const   uchar * src00,
-        device const   uchar * src01,
-        device const   uchar * src02,
-        device const   uchar * src03,
-        device const   uchar * src04,
-        device const   uchar * src05,
-        device const   uchar * src06,
-        device const   uchar * src07,
         threadgroup    uchar * shared_memory [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
     // expert id
     const int32_t id = tgpig.z/(ne12*ne13);
+    device const uchar * src0 = src0s + id*nb02;
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
@@ -5834,7 +5843,7 @@ kernel void kernel_mul_mm_id(
     }
 
     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
-        src0s[id],
+        src0,
         src1,
         src1ids,
         dst,
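Every kernel_mul_mm_id and kernel_mul_mv_id variant below repeats the same two-line change: the fixed src00..src07 buffer arguments disappear, and the expert matrix is located by byte offset in the single merged tensor. Under the ggml convention that nb02 is the byte stride between consecutive slices along dim 2, the lookup reduces to one addition (illustrative form, not a function in the tree):

    // device const char * src0 = src0s + id*nb02;
    inline const char * select_expert(const char * src0s, uint64_t nb02, int32_t id) {
        return src0s + (uint64_t) id * nb02; // start of expert `id` in the 3D tensor
    }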
@@ -5960,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_m
 //
 
 typedef void (mat_mm_id_t)(
-        device const   uchar * ids,
+        device const   uchar * src0s,
         device const   uchar * src1,
         device         float * dst,
+        device const   uchar * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -5979,14 +5989,6 @@ typedef void (mat_mm_id_t)(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const   uchar * src00,
-        device const   uchar * src01,
-        device const   uchar * src02,
-        device const   uchar * src03,
-        device const   uchar * src04,
-        device const   uchar * src05,
-        device const   uchar * src06,
-        device const   uchar * src07,
         threadgroup    uchar *,
         uint3, uint, uint);
 
@@ -6022,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel
 
 [[host_name("kernel_mul_mv_id_f32_f32")]]
 kernel void kernel_mul_mv_id_f32_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6045,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_f32_f32_impl(
-        src0[id],
+        src0,
         src1 + bid*nb11,
         dst  + bid*ne0,
         ne00,
@@ -6091,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
 
 [[host_name("kernel_mul_mv_id_f16_f32")]]
 kernel void kernel_mul_mv_id_f16_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6114,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_f16_f32_impl(
-        src0[id],
+        src0,
         src1 + bid*nb11,
         dst  + bid*ne0,
         ne00,
@@ -6160,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
 
 [[host_name("kernel_mul_mv_id_q8_0_f32")]]
 kernel void kernel_mul_mv_id_q8_0_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6183,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q8_0_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6223,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
 
 [[host_name("kernel_mul_mv_id_q4_0_f32")]]
 kernel void kernel_mul_mv_id_q4_0_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6246,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6286,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
 
 [[host_name("kernel_mul_mv_id_q4_1_f32")]]
 kernel void kernel_mul_mv_id_q4_1_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6309,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6349,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
 
 [[host_name("kernel_mul_mv_id_q5_0_f32")]]
 kernel void kernel_mul_mv_id_q5_0_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6372,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6412,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
 
 [[host_name("kernel_mul_mv_id_q5_1_f32")]]
 kernel void kernel_mul_mv_id_q5_1_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6435,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6475,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
 
 [[host_name("kernel_mul_mv_id_q2_K_f32")]]
 kernel void kernel_mul_mv_id_q2_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6498,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q2_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6538,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
 
 [[host_name("kernel_mul_mv_id_q3_K_f32")]]
 kernel void kernel_mul_mv_id_q3_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6561,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q3_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6601,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
 
 [[host_name("kernel_mul_mv_id_q4_K_f32")]]
 kernel void kernel_mul_mv_id_q4_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6624,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q4_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6664,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
 
 [[host_name("kernel_mul_mv_id_q5_K_f32")]]
 kernel void kernel_mul_mv_id_q5_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6687,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q5_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6727,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
 
 [[host_name("kernel_mul_mv_id_q6_K_f32")]]
 kernel void kernel_mul_mv_id_q6_K_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6750,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_q6_K_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6790,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
 
 [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
 kernel void kernel_mul_mv_id_iq2_xxs_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6813,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_xxs_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6855,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
 
 [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
 kernel void kernel_mul_mv_id_iq2_xs_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6878,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_xs_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6920,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
 
 [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
 kernel void kernel_mul_mv_id_iq3_xxs_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6943,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq3_xxs_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -6985,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
 
 [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
 kernel void kernel_mul_mv_id_iq3_s_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7008,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq3_s_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7050,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
 
 [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
 kernel void kernel_mul_mv_id_iq2_s_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7073,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq2_s_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7115,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
 
 [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
 kernel void kernel_mul_mv_id_iq1_s_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7138,28 +7005,19 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq1_s_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7178,9 +7036,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
 
 [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
 kernel void kernel_mul_mv_id_iq1_m_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7201,28 +7060,19 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq1_m_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7241,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
 
 [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
 kernel void kernel_mul_mv_id_iq4_nl_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7264,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup float    * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
     kernel_mul_mv_iq4_nl_f32_impl(
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
@@ -7306,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
 
 [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
 kernel void kernel_mul_mv_id_iq4_xs_f32(
-        device const    char * ids,
+        device const    char * src0s,
         device const    char * src1,
         device         float * dst,
+        device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -7329,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
-        device const    char * src00,
-        device const    char * src01,
-        device const    char * src02,
-        device const    char * src03,
-        device const    char * src04,
-        device const    char * src05,
-        device const    char * src06,
-        device const    char * src07,
         threadgroup float    * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
     const int64_t bid = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    device const char * src0 = src0s + id*nb02;
 
 #if QK_K == 64
     kernel_mul_mv_iq4_nl_f32_impl(
 #else
     kernel_mul_mv_iq4_xs_f32_impl(
 #endif
-        src0[id],
+        src0,
         (device const float *) (src1 + bid*nb11),
         dst + bid*ne0,
         ne00,
diff --git a/ggml.c b/ggml.c
index 7471e792606c1524715d02d6f135c1fd4e846564..c9b0a6a0ef776af3a453d21c0575df97a8cc807a 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4573,45 +4573,38 @@ void ggml_mul_mat_set_prec(
 
 // ggml_mul_mat_id
 
+// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
+//       this will allow computing all the used experts in a single matrix multiplication
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * const as[],
-        int                   n_as,
+        struct ggml_tensor  * as,
         struct ggml_tensor  * ids,
         int                   id,
         struct ggml_tensor  * b) {
 
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
-    GGML_ASSERT(ids->ne[1] == b->ne[1]);
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+    GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
     GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
-    GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
-    GGML_ASSERT(id >= 0 && id < ids->ne[0]);
+    GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+    GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
 
     bool is_node = false;
 
-    if (as[0]->grad || b->grad) {
+    if (as->grad || b->grad) {
         is_node = true;
     }
 
-    const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     ggml_set_op_params_i32(result, 0, id);
-    ggml_set_op_params_i32(result, 1, n_as);
 
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = ids;
+    result->src[0] = as;
     result->src[1] = b;
-
-    for (int i = 0; i < n_as; i++) {
-        struct ggml_tensor * a = as[i];
-        GGML_ASSERT(ggml_are_same_shape(as[0], a));
-        GGML_ASSERT(ggml_can_mul_mat(a, b));
-        GGML_ASSERT(!ggml_is_transposed(a));
-        result->src[i + 2] = a;
-    }
+    result->src[2] = ids;
 
     return result;
 }
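For graph-building code, the signature change means passing one 3D tensor with the experts stacked along dim 2 instead of an array of 2D tensors. A hedged usage sketch (the sizes are made up for illustration, not taken from any particular model):

    #include "ggml.h"

    static struct ggml_tensor * build_moe_mul(struct ggml_context * ctx, struct ggml_tensor * b) {
        const int64_t n_embd = 4096, n_ff = 14336, n_expert = 8;
        const int64_t n_tokens = b->ne[1];

        // all experts merged into one 3D tensor: [n_embd, n_ff, n_expert]
        struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd, n_ff, n_expert);
        // one row of expert ids per b row; `id` selects which top-k slot to use
        struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, n_tokens);

        // ~= ggml_mul_mat(as[ids[id]], b), with `as` now a single tensor
        return ggml_mul_mat_id(ctx, as, ids, /*id =*/ 0, b);
    }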
@@ -10948,10 +10941,9 @@ static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
-
-    const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
+    const struct ggml_tensor * ids = dst->src[2];
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -10981,13 +10973,13 @@ static void ggml_compute_forward_mul_mat_id(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    // broadcast is not supported with mmid
+    assert(ne12 == 1);
+    assert(ne13 == 1);
 
     // row groups
     const int id   = ggml_get_op_params_i32(dst, 0);
-    const int n_as = ggml_get_op_params_i32(dst, 1);
+    const int n_as = src0->ne[2];
 
     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :
@@ -11047,7 +11039,7 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
+        size_t src0_offset = cur_a*src0->nb[2];
 
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11082,9 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        assert(ne12 % ne02 == 0);
-        assert(ne13 % ne03 == 0);
-
         // block-tiling attempt
         const int64_t blck_0 = 16;
         const int64_t blck_1 = 16;
@@ -11101,14 +11090,14 @@ static void ggml_compute_forward_mul_mat_id(
                     const int64_t  i11 = MMID_MATRIX_ROW(cur_a, _i11);
 
                     // broadcast src0 into src1
-                    const int64_t i03 = i13/r3;
-                    const int64_t i02 = i12/r2;
+                    //const int64_t i03 = i13/r3;
+                    //const int64_t i02 = i12/r2;
 
                     const int64_t i1 = i11;
                     const int64_t i2 = i12;
                     const int64_t i3 = i13;
 
-                    const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
+                    const char * src0_row = (const char *) src0->data + src0_offset;
 
                     // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                     //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -18464,13 +18453,13 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
             case GGML_OP_MUL_MAT_ID:
                 {
                     cur = 0;
-                    const struct ggml_tensor * src0 = node->src[2];
+                    const struct ggml_tensor * src0 = node->src[0];
                     const struct ggml_tensor * src1 = node->src[1];
                     const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
                     if (src1->type != vec_dot_type) {
                         cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
                     }
-                    const int n_as = ggml_get_op_params_i32(node, 1);
+                    const int n_as = src0->ne[2];
                     cur += GGML_PAD(cur, sizeof(int64_t));       // align
                     cur += n_as * sizeof(int64_t);               // matrix_row_counts
                     cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
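With the op param gone, the scratch size now follows the merged tensor directly. As a worked example with hypothetical sizes, n_as = 8 experts and src1->ne[1] = 32 rows give 8*8 = 64 bytes for matrix_row_counts plus 8*32*8 = 2048 bytes for matrix_rows, on top of the quantized copy of src1 when its type differs from vec_dot_type.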
diff --git a/ggml.h b/ggml.h
index 5d4a4ceb65c7e106bf2008ba82089fe8d4d7a83f..5cef45c0ba4ad13fefe8450f874918e6e74ed354 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1164,8 +1164,7 @@ extern "C" {
     //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * const as[],
-            int                   n_as,
+            struct ggml_tensor  * as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);
index 27eaf723cd85fdb9fb282b6b8a12f8379fffe15f..f468802d1d4030b7e576a72d4c2f118e30a7216a 100644 (file)
@@ -221,9 +221,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_DOWN:        "blk.{bid}.ffn_down",
     MODEL_TENSOR.FFN_UP:          "blk.{bid}.ffn_up",
     MODEL_TENSOR.FFN_ACT:         "blk.{bid}.ffn",
-    MODEL_TENSOR.FFN_GATE_EXP:    "blk.{bid}.ffn_gate.{xid}",
-    MODEL_TENSOR.FFN_DOWN_EXP:    "blk.{bid}.ffn_down.{xid}",
-    MODEL_TENSOR.FFN_UP_EXP:      "blk.{bid}.ffn_up.{xid}",
+    MODEL_TENSOR.FFN_GATE_EXP:    "blk.{bid}.ffn_gate_exps",
+    MODEL_TENSOR.FFN_DOWN_EXP:    "blk.{bid}.ffn_down_exps",
+    MODEL_TENSOR.FFN_UP_EXP:      "blk.{bid}.ffn_up_exps",
     MODEL_TENSOR.LAYER_OUT_NORM:  "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.SSM_IN:          "blk.{bid}.ssm_in",
     MODEL_TENSOR.SSM_CONV1D:      "blk.{bid}.ssm_conv1d",
index 11fd34b8b91038cd5aa10999125764a839852da1..93a5a455ee770dbfab70a9c5d511869fccec78ac 100644 (file)
@@ -231,9 +231,8 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
-            "layers.{bid}.feed_forward.experts.{xid}.w3",           # mixtral
-            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
-            "transformer.decoder_layer.{bid}.moe.{xid}.linear_v",   # Grok
+            "layers.{bid}.feed_forward.experts.w3",                 # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_v",         # Grok (merged)
         ),
 
         # AWQ-activation gate
@@ -252,9 +251,8 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
-            "layers.{bid}.feed_forward.experts.{xid}.w1",           # mixtral
-            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
-            "transformer.decoder_layer.{bid}.moe.{xid}.linear"      # Grok
+            "layers.{bid}.feed_forward.experts.w1",                 # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear"            # Grok (merged)
         ),
 
         # Feed-forward down
@@ -280,10 +278,8 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
-            "layers.{bid}.feed_forward.experts.{xid}.w2",           # mixtral
-            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
-            "transformer.decoder_layer.{bid}.moe.{xid}.linear_1",   # Grok
-
+            "layers.{bid}.feed_forward.experts.w2",                 # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_1",         # Grok (merged)
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
index 96396e04e8f3f4437b2264539123e2f4596674da..13cbfffbcabb17cab200d1ed913cb57b737addcf 100644 (file)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.8.0"
+version = "0.9.0"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [
index 21e7a067af65f8a07f13e572a0753a941c784d6c..2df03f99010fa38cd6419fdfd115066a66bc1644 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -426,9 +426,12 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN,
     LLM_TENSOR_FFN_UP,
     LLM_TENSOR_FFN_ACT,
-    LLM_TENSOR_FFN_DOWN_EXP,
+    LLM_TENSOR_FFN_DOWN_EXP,  // split experts for backward compatibility
     LLM_TENSOR_FFN_GATE_EXP,
     LLM_TENSOR_FFN_UP_EXP,
+    LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
+    LLM_TENSOR_FFN_GATE_EXPS,
+    LLM_TENSOR_FFN_UP_EXPS,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
     LLM_TENSOR_LAYER_OUT_NORM,
@@ -463,6 +466,9 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
             { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
             { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
     {
@@ -516,6 +522,9 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
             { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
             { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
             { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
             { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
         },
@@ -1864,9 +1873,9 @@ struct llama_layer {
 
     // ff MoE
     struct ggml_tensor * ffn_gate_inp;
-    struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS];
-    struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS];
-    struct ggml_tensor * ffn_up_exp  [LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_gate_exps;
+    struct ggml_tensor * ffn_down_exps;
+    struct ggml_tensor * ffn_up_exps;
 
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2
@@ -2868,19 +2877,19 @@ struct llama_model_loader {
 
     llama_mmaps mappings;
 
-    // Holds information on a model weights
-    struct llama_tensor_weights {
+    // Holds information on a model weight
+    struct llama_tensor_weight {
         uint16_t  idx; // source file index
         size_t   offs; // tensor data offset in the original file
 
         ggml_tensor * tensor;
 
-        llama_tensor_weights(uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
+        llama_tensor_weight(uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
             const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
             offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
         }
     };
-    std::vector<llama_tensor_weights> weights;
+    std::vector<llama_tensor_weight> weights;
 
     std::unordered_map<std::string, struct llama_model_kv_override> kv_overrides;
 
@@ -2920,7 +2929,7 @@ struct llama_model_loader {
         // For subsidiary files, `meta` tensor data offset must not be used,
         // so we build a unified tensor index for the weights.
         for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-            weights.emplace_back(llama_tensor_weights(0, cur->name, meta, cur));
+            weights.emplace_back(0, cur->name, meta, cur);
         }
         files.emplace_back(new llama_file(fname.c_str(), "rb"));
         contexts.emplace_back(ctx);
@@ -2960,7 +2969,7 @@ struct llama_model_loader {
 
                 // Save the tensor data offset info of the shard.
                 for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-                    weights.emplace_back(llama_tensor_weights(idx, cur->name, ctx_gguf, cur));
+                    weights.emplace_back(idx, cur->name, ctx_gguf, cur);
                 }
                 files.emplace_back(new llama_file(split_path, "rb"));
                 contexts.emplace_back(ctx);
@@ -3164,21 +3173,37 @@ struct llama_model_loader {
         return weights.at(i).tensor->name;
     }
 
-    const llama_tensor_weights & get_weights(const char * name) const {
+    const llama_tensor_weight * get_weight(const char * name) const {
         for (const auto & weight : weights) {
             if (strcmp(name, weight.tensor->name) == 0) {
-                return weight;
+                return &weight;
             }
         }
-        throw std::runtime_error(format("tensor %s not found", name));
+        return nullptr;
+    }
+
+    const llama_tensor_weight & require_weight(const char * name) const {
+        const llama_tensor_weight * weight = get_weight(name);
+        if (!weight) {
+            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
+        }
+        return *weight;
     }
 
     struct ggml_tensor * get_tensor_meta(const char * name) const {
-        try {
-            return get_weights(name).tensor;
-        } catch (const std::runtime_error & e) {
-            return NULL;
+        const auto * weight = get_weight(name);
+        if (!weight) {
+            return nullptr;
+        }
+        return weight->tensor;
+    }
+
+    struct ggml_tensor * require_tensor_meta(const char * name) const {
+        struct ggml_tensor * tensor = get_tensor_meta(name);
+        if (!tensor) {
+            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
         }
+        return tensor;
     }
 
     struct ggml_tensor * get_tensor_meta(int i) const {
@@ -3194,7 +3219,7 @@ struct llama_model_loader {
         return tensor;
     }
 
-    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, bool required = true) {
+    const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const {
         const struct ggml_tensor * cur = get_tensor_meta(name.c_str());
 
         if (cur == NULL) {
@@ -3206,8 +3231,8 @@ struct llama_model_loader {
 
         {
             bool is_ok = true;
-            for (size_t i = 0; i < ne.size(); ++i) {
-                if (ne[i] != cur->ne[i]) {
+            for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
+                if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
                     is_ok = false;
                     break;
                 }
@@ -3221,9 +3246,47 @@ struct llama_model_loader {
             }
         }
 
+        return cur;
+    }
+
+    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector<int64_t> & ne, bool required = true) {
+        const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
+
+        if (cur == NULL) {
+            return NULL;
+        }
+
         return create_tensor_for(ctx, cur);
     }
 
+    struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector<int64_t> & ne, size_t offset, bool required = true) {
+        const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
+
+        if (cur == NULL) {
+            return NULL;
+        }
+
+        if (cur->type != base->type) {
+            throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
+        }
+
+        std::array<int64_t, GGML_MAX_DIMS> dims;
+        for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
+            dims[i] = i < ne.size() ? ne[i] : 1;
+        }
+
+        struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
+                                        dims[0], dims[1], dims[2], dims[3],
+                                        cur->nb[1], cur->nb[2], cur->nb[3],
+                                        offset);
+
+        ggml_set_name(tensor, name.c_str());
+
+        n_created++;
+
+        return tensor;
+    }
+
     void done_getting_tensors() const {
         if (n_created != n_tensors) {
             throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
@@ -3236,7 +3299,7 @@ struct llama_model_loader {
             mmaps_used.reserve(files.size());
             for (const auto & file : files) {
                 std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa()));
-                mmaps_used.emplace_back(std::make_pair(mapping->size, 0));
+                mmaps_used.emplace_back(mapping->size, 0);
                 if (mlock_mmaps) {
                     std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
                     mlock_mmap->init(mapping->addr);
@@ -3260,18 +3323,25 @@ struct llama_model_loader {
         *last  = 0;
         *addr = mapping->addr;
         for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
-            const auto & w = get_weights(ggml_get_name(tensor));
-            if (w.idx != idx) {
-                continue;
+            try {
+                const auto * weight = get_weight(ggml_get_name(tensor));
+                if (!weight) {
+                    continue;
+                }
+                if (weight->idx != idx) {
+                    continue;
+                }
+                *first = std::min(*first, weight->offs);
+                *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
+            } catch(...) {
+                // the tensor is not in the model
             }
-            *first = std::min(*first, w.offs);
-            *last  = std::max(*last,  w.offs + ggml_nbytes(tensor));
         }
     }
 
     // for backwards compatibility, does not support ggml-backend
     void load_data_for(struct ggml_tensor * cur) const {
-        const auto & w = get_weights(ggml_get_name(cur));
+        const auto & w = require_weight(ggml_get_name(cur));
 
         if (use_mmap) {
             const auto & mapping = mappings.at(w.idx);
@@ -3304,44 +3374,49 @@ struct llama_model_loader {
 
         std::vector<no_init<uint8_t>> read_buf;
         for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
+            const auto * weight = get_weight(ggml_get_name(cur));
+            if (weight == nullptr) {
+                // this can happen with split experts models
+                continue;
+            }
+
             if (progress_callback) {
                 if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
                     return false;
                 }
             }
 
-            const auto & w = get_weights(ggml_get_name(cur));
             size_t n_size = ggml_nbytes(cur);
 
             if (use_mmap) {
-                const auto & mapping = mappings.at(w.idx);
+                const auto & mapping = mappings.at(weight->idx);
                 ggml_backend_buffer_t buf_mmap = nullptr;
-                if (bufs_mmap.count(w.idx)) {
-                    buf_mmap = bufs_mmap.at(w.idx);
+                if (bufs_mmap.count(weight->idx)) {
+                    buf_mmap = bufs_mmap.at(weight->idx);
                 }
                 GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
                 if (buf_mmap && cur->data == nullptr) {
-                    ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + w.offs);
+                    ggml_backend_tensor_alloc(buf_mmap, cur, (uint8_t *) mapping->addr + weight->offs);
                     if (lmlocks) {
-                        const auto & lmlock = lmlocks->at(w.idx);
-                        lmlock->grow_to(w.offs + ggml_nbytes(cur));
+                        const auto & lmlock = lmlocks->at(weight->idx);
+                        lmlock->grow_to(weight->offs + ggml_nbytes(cur));
                     }
 
-                    auto & mmap_used = mmaps_used[w.idx];
-                    mmap_used.first  = std::min(mmap_used.first,  w.offs);
-                    mmap_used.second = std::max(mmap_used.second, w.offs + n_size);
+                    auto & mmap_used = mmaps_used[weight->idx];
+                    mmap_used.first  = std::min(mmap_used.first,  weight->offs);
+                    mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
                 } else {
-                    ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + w.offs, 0, n_size);
+                    ggml_backend_tensor_set(cur, (uint8_t *) mapping->addr + weight->offs, 0, n_size);
                 }
             } else {
-                GGML_ASSERT(w.idx < files.size());
-                const auto & file = files.at(w.idx);
+                GGML_ASSERT(weight->idx < files.size());
+                const auto & file = files.at(weight->idx);
                 if (ggml_backend_buffer_is_host(cur->buffer)) {
-                    file->seek(w.offs, SEEK_SET);
+                    file->seek(weight->offs, SEEK_SET);
                     file->read_raw(cur->data, ggml_nbytes(cur));
                 } else {
                     read_buf.resize(ggml_nbytes(cur));
-                    file->seek(w.offs, SEEK_SET);
+                    file->seek(weight->offs, SEEK_SET);
                     file->read_raw(read_buf.data(), ggml_nbytes(cur));
                     ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
                 }
@@ -4270,6 +4345,7 @@ static bool llm_load_tensors(
 
     const int64_t n_layer     = hparams.n_layer;
     const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
+    bool use_mmap_buffer = true;
 
     // there is very little benefit to offloading the input layer, so always keep it on the CPU
     model.buft_input = llama_default_buffer_type_cpu(true);
@@ -4358,6 +4434,10 @@ static bool llm_load_tensors(
 
     // create one context per buffer type
     size_t ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output
+
+    // for moe merged tensors
+    ctx_size += ggml_tensor_overhead()*hparams.n_expert*n_layer;
+
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     for (auto & it : buft_layer_count) {
         struct ggml_init_params params = {
@@ -4384,6 +4464,11 @@ static bool llm_load_tensors(
         const int64_t n_vocab      = hparams.n_vocab;
         const int64_t n_vocab_type = hparams.n_vocab_type;
         const int64_t n_ff         = hparams.n_ff;
+        const int64_t n_expert     = hparams.n_expert;
+
+        if (n_expert > 0 && hparams.n_expert_used == 0) {
+            throw std::runtime_error("model has expert layers but no expert layers are used");
+        }
 
         GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
 
@@ -4438,30 +4523,50 @@ static bool llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, false);
-
-                        if (layer.ffn_gate_inp == nullptr) {
-                            GGML_ASSERT(hparams.n_expert      == 0);
-                            GGML_ASSERT(hparams.n_expert_used == 0);
-
+                        if (n_expert == 0) {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                             layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                             layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                         } else {
-                            GGML_ASSERT(hparams.n_expert      > 0);
-                            GGML_ASSERT(hparams.n_expert_used > 0);
-
-                            // MoE branch
-                            for (uint32_t x = 0; x < hparams.n_expert; ++x) {
-                                layer.ffn_gate_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd,   n_ff});
-                                layer.ffn_down_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), {  n_ff, n_embd});
-                                layer.ffn_up_exp[x]   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), {n_embd,   n_ff});
+                            layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
+
+                            layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
+                            if (layer.ffn_gate_exps) {
+                                layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
+                                layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
+                            } else {
+                                // merge the split experts into a single tensor for compatibility with older models
+                                // requires disabling mmap
+                                use_mmap_buffer = false;
+
+                                ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
+                                ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
+                                ggml_type type_up   = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, 0).c_str())->type;
+
+                                layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd,   n_ff, n_expert);
+                                layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down,   n_ff, n_embd, n_expert);
+                                layer.ffn_up_exps   = ggml_new_tensor_3d(ctx_split, type_up,   n_embd,   n_ff, n_expert);
+
+                                ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str());
+                                ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str());
+                                ggml_set_name(layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i).c_str());
+
+                                for (uint32_t x = 0; x < n_expert; ++x) {
+                                    // the individual experts are loaded into a view of the merged tensor
+                                    ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x);
+                                    ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x);
+                                    ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x);
+                                }
                             }
                         }
                     }
                 } break;
             case LLM_ARCH_GROK:
                 {
+                    if (n_expert == 0) {
+                        throw std::runtime_error("Grok model cannot have zero experts");
+                    }
+
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
                     // output
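
The LLM_ARCH_LLAMA branch above first asks for the merged *_exps tensors with required = false; when they are absent the file is an older split-experts model, so the loader allocates the merged 3D tensors itself and registers each legacy per-expert name as a view into them. load_all_data then fills the views, which is why use_mmap_buffer is cleared: the merged tensor is not contiguous in the file and cannot be backed by the mapping, so the data must be copied instead. Each view also counts toward n_created, which the earlier ctx_size += ggml_tensor_overhead()*hparams.n_expert*n_layer reservation accounts for. A condensed sketch of the detection step (not the exact loader code):

    // try the merged tensor first; nullptr means an old split-experts file
    ggml_tensor * gate = ml.create_tensor(ctx_split,
            tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
    if (gate == nullptr) {
        use_mmap_buffer = false; // the views are filled by copies, not by the file mapping
    }
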
@@ -4493,16 +4598,35 @@ static bool llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd});
+                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
 
-                        GGML_ASSERT(hparams.n_expert      > 0);
-                        GGML_ASSERT(hparams.n_expert_used > 0);
-
-                        // MoE branch
-                        for (uint32_t x = 0; x < hparams.n_expert; ++x) {
-                            layer.ffn_gate_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd,   n_ff});
-                            layer.ffn_down_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), {  n_ff, n_embd});
-                            layer.ffn_up_exp[x]   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), {n_embd,   n_ff});
+                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
+                        if (layer.ffn_gate_exps) {
+                            layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
+                            layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
+                        } else {
+                            // merge the split experts into a single tensor for compatibility with older models
+                            // requires disabling mmap
+                            use_mmap_buffer = false;
+
+                            ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
+                            ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
+                            ggml_type type_up   = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, 0).c_str())->type;
+
+                            layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd,   n_ff, n_expert);
+                            layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down,   n_ff, n_embd, n_expert);
+                            layer.ffn_up_exps   = ggml_new_tensor_3d(ctx_split, type_up,   n_embd,   n_ff, n_expert);
+
+                            ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str());
+                            ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str());
+                            ggml_set_name(layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i).c_str());
+
+                            for (uint32_t x = 0; x < n_expert; ++x) {
+                                // the individual experts are loaded into a view of the merged tensor
+                                ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x);
+                                ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x);
+                                ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x);
+                            }
                         }
 
                         layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
@@ -5308,7 +5432,7 @@ static bool llm_load_tensors(
         // only the mmap region containing the tensors in the model is mapped to the backend buffer
         // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
         // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
-        if (ml.use_mmap && buft == llama_default_buffer_type_cpu(true)) {
+        if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(true)) {
             for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
                 void * addr = nullptr;
                 size_t first, last;
@@ -5332,7 +5456,7 @@ static bool llm_load_tensors(
             }
         }
 #ifdef GGML_USE_METAL
-        else if (ml.use_mmap && buft == ggml_backend_metal_buffer_type()) {
+        else if (ml.use_mmap && use_mmap_buffer && buft == ggml_backend_metal_buffer_type()) {
             for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
                 const size_t max_size = ggml_get_max_tensor_size(ctx);
                 void * addr = nullptr;
@@ -5415,8 +5539,10 @@ static bool llm_load_tensors(
         }
     }
 
-    for (auto & mapping : ml.mappings) {
-        model.mappings.emplace_back(std::move(mapping));
+    if (use_mmap_buffer) {
+        for (auto & mapping : ml.mappings) {
+            model.mappings.emplace_back(std::move(mapping));
+        }
     }
 
     // loading time will be recalculated after the first eval, so
@@ -6284,19 +6410,19 @@ struct llm_build_context {
                 for (int i = 0; i < n_expert_used; ++i) {
                     ggml_tensor * cur_expert;
 
-                    ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
+                    ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
                     cb(cur_up, "ffn_moe_up", il);
 
-                    ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
+                    ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
                     cb(cur_gate, "ffn_moe_gate", il);
 
                     cur_gate = ggml_silu(ctx0, cur_gate);
                     cb(cur_gate, "ffn_moe_silu", il);
 
-                    cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
+                    cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
                     cb(cur_expert, "ffn_moe_gate_par", il);
 
-                    cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                    cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
                     cb(cur_expert, "ffn_moe_down", il);
 
                     cur_expert = ggml_mul(ctx0, cur_expert,
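
This hunk shows the API change the commit is named for: ggml_mul_mat_id previously took an array of per-expert 2D matrices plus their count, and now takes a single 3D tensor whose third dimension indexes the experts. A before/after sketch using the member names from the loader above:

    // before: one 2D tensor per expert, passed as an array of n_expert matrices
    //   cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
    // after: a single [n_embd, n_ff, n_expert] tensor; the expert count is implied by ne[2]
    cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
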
@@ -6818,20 +6944,20 @@ struct llm_build_context {
             for (int i = 0; i < n_expert_used; ++i) {
                 ggml_tensor * cur_expert;
 
-                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
+                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
                 cb(cur_up, "ffn_moe_up", il);
 
-                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
+                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
                 cb(cur_gate, "ffn_moe_gate", il);
 
                 //GeLU
                 cur_gate = ggml_gelu(ctx0, cur_gate);
                 cb(cur_gate, "ffn_moe_gelu", il);
 
-                cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
+                cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
                 cb(cur_expert, "ffn_moe_gate_par", il);
 
-                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
                 cb(cur_expert, "ffn_moe_down", il);
 
                 cur_expert = ggml_mul(ctx0, cur_expert,
@@ -12902,7 +13028,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
             // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
             // for getting the current layer as I initially thought, and we need to resort to parsing the
             // tensor name.
-            n_layer /= n_expert;
             if (sscanf(name, "blk.%d.", &i_layer) != 1) {
                 throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
             }
@@ -13264,7 +13389,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         kv_overrides = v->data();
     }
     llama_model_loader ml(fname_inp, use_mmap, kv_overrides);
-    ml.init_mappings(false); // no prefetching?
+    ml.init_mappings(false); // no prefetching
 
     llama_model model;
     llm_load_arch(ml, model);
@@ -13316,20 +13441,15 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // TODO: avoid hardcoded tensor names - use the TN_* constants
         if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) {
             ++qs.n_attention_wv;
-        } else if (name.find("ffn_down") != std::string::npos) {
-            ++qs.n_ffn_down;
-        } else if (name.find("ffn_gate") != std::string::npos) {
-            ++qs.n_ffn_gate;
-        } else if (name.find("ffn_up") != std::string::npos) {
-            ++qs.n_ffn_up;
         } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
             qs.has_output = true;
         }
     }
-    if (qs.n_attention_wv != qs.n_ffn_down || (uint32_t) qs.n_attention_wv != model.hparams.n_layer) {
-        LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_ffn_down = %d, hparams.n_layer = %d\n",
-                __func__, qs.n_attention_wv, qs.n_ffn_down, model.hparams.n_layer);
-    }
+
+    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
+
+    // sanity checks
+    GGML_ASSERT(qs.n_attention_wv == (int)model.hparams.n_layer && "n_attention_wv != n_layer is unexpected");
 
     size_t total_size_org = 0;
     size_t total_size_new = 0;
@@ -13359,6 +13479,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     // placeholder for the meta data
     ::zeros(fout, meta_size);
 
+    const auto tn = LLM_TN(model.arch);
+
     for (int i = 0; i < ml.n_tensors; ++i) {
         struct ggml_tensor * tensor = ml.get_tensor_meta(i);
 
@@ -13381,8 +13503,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // This used to be a regex, but <regex> has an extreme cost to compile times.
         bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
 
-        // quantize only 2D tensors
-        quantize &= (ggml_n_dims(tensor) == 2);
+        // quantize only 2D and 3D tensors (experts)
+        quantize &= (ggml_n_dims(tensor) >= 2);
         quantize &= params->quantize_output_tensor || name != "output.weight";
         quantize &= !params->only_copy;
 
@@ -13437,11 +13559,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 if (it == imatrix_data->end()) {
                     LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
                 } else {
-                    if (it->second.size() == (size_t)tensor->ne[0]) {
+                    if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
                         imatrix = it->second.data();
                     } else {
                         LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
-                                int(it->second.size()), int(tensor->ne[0]), tensor->name);
+                                int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
+
+                        // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
+                        // this is a significant error and it may be a good idea to abort the process if this happens,
+                        // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
+                        // tok_embd should be ignored in this case, since it always causes this warning
+                        if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
+                            throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
+                                    int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
+                        }
                     }
                 }
             }
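
With merged experts the expected importance-matrix length becomes ne[0]*ne[2]: one row of n_per_row importance values per expert. For plain 2D tensors ne[2] == 1, so the check reduces to the old ne[0] comparison:

    // expected imatrix entries: one value per input column, per expert slice
    const size_t expected = (size_t) tensor->ne[0] * tensor->ne[2];
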
@@ -13478,15 +13609,24 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             new_data = work.data();
 
             const int n_per_row = tensor->ne[0];
-            const int nrows = nelements / n_per_row;
+            const int nrows = tensor->ne[1];
 
             static const int min_chunk_size = 32 * 512;
             const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
 
-            const int nchunk = (nelements + chunk_size - 1)/chunk_size;
+            const int nelements_matrix = tensor->ne[0] * tensor->ne[1];
+            const int nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
             const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;
-            new_size = llama_tensor_quantize_internal(new_type, f32_data, new_data, chunk_size, nrows, n_per_row, imatrix, workers, nthread_use);
 
+            // quantize each expert separately since they have different importance matrices
+            new_size = 0;
+            for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
+                const float * f32_data_03 = f32_data + i03 * nelements_matrix;
+                void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
+                const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
+
+                new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
+            }
             LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
         }
         total_size_org += ggml_nbytes(tensor);
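
Quantizing expert by expert keeps each slice paired with its own section of the imatrix, which a single call over the merged tensor would mix together. The strides used by the loop, restated as a sketch (the same arithmetic as above, with ggml_row_size giving the byte length of one quantized row):

    const int64_t stride_src  = (int64_t) n_per_row * nrows;                 // floats per expert in f32_data
    const size_t  stride_dst  = ggml_row_size(new_type, n_per_row) * nrows;  // bytes per expert in new_data
    const int64_t stride_imat = n_per_row;                                   // floats per expert in the imatrix
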
index 5dfea5662eb0b5a9267d7e5a2a3a2a05e1803423..51b3487b2a948c70e21cf1dab4fd689beec29dae 100644 (file)
@@ -979,17 +979,13 @@ struct test_mul_mat_id : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
-        std::vector<ggml_tensor *> mats;
-        for (int i = 0; i < n_mats; i++) {
-            ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m);
-            mats.push_back(a);
-        }
+        ggml_tensor * mats = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
         ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
         if (v) {
             ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
         }
         ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
-        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b);
+        ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, v ? id/2 : id, b);
         return out;
     }
 
@@ -1477,91 +1473,6 @@ struct test_leaky_relu : public test_case {
     }
 };
 
-// Mixtral MOE
-struct test_moe : public test_case {
-    const int n_experts;
-    const int n_experts_per_tok;
-    const int n_tokens;
-    const int n_embd;
-    const int n_ff;
-
-    std::string op_desc(ggml_tensor * t) override {
-        return "MOE";
-
-        GGML_UNUSED(t);
-    }
-
-    std::string vars() override {
-        return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
-    }
-
-    test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
-        : n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
-    }
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
-
-        std::vector<ggml_tensor *> ffn_up_exp(n_experts);
-        std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
-        std::vector<ggml_tensor *> ffn_down_exp(n_experts);
-
-        for (int i = 0; i < n_experts; ++i) {
-            ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
-            ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
-            ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
-        }
-
-        ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
-
-        ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
-        ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd), 0.0f);
-
-        // select experts
-        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
-
-        ggml_tensor * weights = ggml_get_rows(ctx,
-                ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
-
-        weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens);
-
-        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
-
-        weights = ggml_div(ctx, weights, weights_sum);
-
-        // compute expert outputs
-        ggml_tensor * moe_out = nullptr;
-
-        for (int i = 0; i < n_experts_per_tok; ++i) {
-            ggml_tensor * cur_expert;
-
-            ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
-
-            ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
-
-            cur_gate = ggml_silu(ctx, cur_gate);
-
-            cur_expert = ggml_mul(ctx, cur_up, cur_gate);
-
-            cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert);
-
-            cur_expert = ggml_mul(ctx, cur_expert,
-                    ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
-
-            if (i == 0) {
-                moe_out = cur_expert;
-            } else {
-                moe_out = ggml_add(ctx, moe_out, cur_expert);
-            }
-        }
-
-        cur = moe_out;
-
-        return cur;
-    }
-};
-
-
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -2169,6 +2080,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
     }
 
     test_cases.emplace_back(new test_sum_rows());
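
The added {60, 10, 10, 10} case exercises a row length that is not a power of two (tagged // qwen), matching the non-pow-2 argsort support this commit adds to the CUDA and Metal backends. Bitonic-style sorts want pow-2 lengths, so a natural approach, assumed here rather than quoted from the kernels, is to round the row up and let the padding sort as sentinels:

    // generic sketch of pow-2 padding for a bitonic-style argsort;
    // an assumption about the approach, not the backend kernel code
    static int next_pow2(int x) {
        int n = 1;
        while (n < x) n *= 2;
        return n; // next_pow2(60) == 64; pad indices compare as +inf sentinels
    }
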
@@ -2182,11 +2094,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
-#if !defined(__SANITIZE_THREAD__)
-    // FIXME: these tests use too much memory with thread sanitizer
-    test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
-    //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
-#endif
     test_cases.emplace_back(new test_llama(1));
     test_cases.emplace_back(new test_llama(2));
     test_cases.emplace_back(new test_falcon(1));