llama : add Mixtral support (#4406)
author     slaren <redacted>
Wed, 13 Dec 2023 12:04:25 +0000 (13:04 +0100)
committer  GitHub <redacted>
Wed, 13 Dec 2023 12:04:25 +0000 (14:04 +0200)
* convert : support Mixtral as LLAMA arch

* convert : fix n_ff typo

* llama : model loading

* ggml : sync latest ggml_mul_mat_id

* llama : update graph to support MoE (see the routing sketch below)

* llama : fix cur -> cur_expert

* llama : first working version

* llama : fix expert weighting in the FFN

* ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu only)

* ggml : add n_as argument to ggml_mul_mat_id

* ggml : fix ggml_get_rows to take into account ne02 / ne11

* metal : add more general support for ggml_get_rows + tests

* llama : add basic support for offloading moe with CUDA

* metal : add/mul/div use general kernel when src1 not cont

* metal : reduce the kernel launches for ggml_mul_mat_id

* ggml : get_rows : support non-contiguous tensors with gaps, generalize up to 3D

* ggml : update get_rows f16 and q

* cuda : support non-contiguous src1 in get_rows

* llama : offload missing ffn_moe_silu

* metal : fix ggml_get_rows to work with non-cont src1

* metal : add indirect mat-vec kernels for all quantization types

* llama : do not quantize expert gating tensors

* llama : add n_expert and n_expert_used to hparams + change quants

* test-backend-ops : add moe test

* cuda : fix get_rows when ncols is odd

* convert : determine n_ctx correctly

* metal : fix ggml_mul_mat_id for F32

* test-backend-ops : make experts more evenly probable (test_moe)

* test-backend-ops : cleanup, add moe test for batches

* test-backend-ops : add cpy from f32 -> all types test

* test-backend-ops : fix dequantize block offset

* llama : fix hard-coded number of experts

* test-backend-ops : simplify and disable slow tests to avoid CI timeout

* test-backend-ops : disable MOE test with thread sanitizer

* cuda : fix mul_mat_id with multi gpu

* convert : use 1e6 rope_freq_base for mixtral

* convert : fix style

* convert : support safetensors format

* gguf-py : bump version

* metal : add cpy f16 -> f32 kernel

* metal : fix binary ops for ne10 % 4 != 0

* test-backend-ops : add one more sum_rows test

* ggml : do not use BLAS with ggml_mul_mat_id

* convert-hf : support for mixtral-instruct (#4428)

* convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct

* convert : use sentencepiece tokenizer for Mixtral-instruct

* convert : make flake8 happy

* metal : fix soft_max kernels

ref: https://github.com/ggerganov/ggml/pull/621/commits/1914017863d2f9ab8ecc0281cc2a56d683668b92

* metal : limit kernels to not use more than the allowed threads

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Radek Pilar <redacted>
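
As a quick orientation before the diffs, this is the MoE FFN routing that the new graph implements: for each token, softmax over the gating logits, keep the top n_expert_used experts, renormalize their weights, and mix their FFN outputs (the "fix expert weighting in the FFN" item above). The following is an illustrative NumPy sketch, not the ggml graph code; all names in it are made up for clarity.

    # Illustrative sketch of per-token MoE routing (NumPy pseudo-reference, not ggml code).
    import numpy as np

    def moe_ffn(x, gate_w, experts, n_expert_used=2):
        # x:       [n_embd]            activations of one token
        # gate_w:  [n_expert, n_embd]  gating weights (the expert gating tensors kept unquantized)
        # experts: list of per-expert FFN callables
        logits = gate_w @ x                          # [n_expert] gating logits
        probs  = np.exp(logits - logits.max())
        probs /= probs.sum()                         # softmax over all experts
        top    = np.argsort(-probs)[:n_expert_used]  # indices of the selected experts
        w      = probs[top] / probs[top].sum()       # renormalized expert weights
        return sum(wi * experts[i](x) for wi, i in zip(w, top))

    # Tiny self-contained usage with random SwiGLU-style experts (shapes are illustrative):
    rng = np.random.default_rng(0)
    n_embd, n_ff, n_expert = 8, 16, 4

    def make_expert():
        w1 = rng.standard_normal((n_ff, n_embd))
        w2 = rng.standard_normal((n_embd, n_ff))
        w3 = rng.standard_normal((n_ff, n_embd))
        silu = lambda v: v / (1.0 + np.exp(-v))
        return lambda x: w2 @ (silu(w1 @ x) * (w3 @ x))

    experts = [make_expert() for _ in range(n_expert)]
    print(moe_ffn(rng.standard_normal(n_embd), rng.standard_normal((n_expert, n_embd)), experts).shape)
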
14 files changed:
Makefile
convert-hf-to-gguf.py
convert.py
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
gguf-py/pyproject.toml
llama.cpp
tests/test-backend-ops.cpp

index e775959524c5bc7772ccdcfda213affdfd4dce8e..b7afda2b570e501dad2e734b44637c538ed497b8 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -399,6 +399,11 @@ ifdef LLAMA_CUBLAS
        MK_LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
        OBJS         += ggml-cuda.o
        NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
+
+ifdef LLAMA_DEBUG
+       NVCCFLAGS += -lineinfo
+endif
+
 ifdef LLAMA_CUDA_NVCC
        NVCC = $(LLAMA_CUDA_NVCC)
 else
index bced1f5617a0f9c52d89959abcfeffd668f7299c..e46a7813a78e98f42b0611179ccd3df312930896 100755 (executable)
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -77,8 +77,18 @@ class Model:
             self.gguf_writer.add_embedding_length(n_embd)
         if (n_ff := self.hparams.get("intermediate_size")) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
-        if (n_head := self.hparams.get("num_attention_head")) is not None:
+        if (n_head := self.hparams.get("num_attention_heads")) is not None:
             self.gguf_writer.add_head_count(n_head)
+        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
+            self.gguf_writer.add_head_count_kv(n_head_kv)
+
+        if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+            self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
+        if (n_experts := self.hparams.get("num_local_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+
         self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
 
     def write_tensors(self):
@@ -170,6 +180,8 @@ class Model:
             return StableLMModel
         if model_architecture == "QWenLMHeadModel":
             return QwenModel
+        if model_architecture == "MixtralForCausalLM":
+            return MixtralModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -207,6 +219,8 @@ class Model:
             return gguf.MODEL_ARCH.STABLELM
         if arch == "QWenLMHeadModel":
             return gguf.MODEL_ARCH.QWEN
+        if arch == "MixtralForCausalLM":
+            return gguf.MODEL_ARCH.LLAMA
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -837,6 +851,11 @@ class StableLMModel(Model):
         self.gguf_writer.add_layer_norm_eps(1e-5)
 
 
+class MixtralModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+
 class QwenModel(Model):
     @staticmethod
     def token_bytes_to_string(b):
index a6fc6b8ea893309b421d97ed3ecb2bcd79000ac5..e4b69d172f728bda274889e38defded1fef2ac4b 100755 (executable)
--- a/convert.py
+++ b/convert.py
@@ -42,6 +42,7 @@ NDArray: TypeAlias = 'np.ndarray[Any, Any]'
 ARCH = gguf.MODEL_ARCH.LLAMA
 
 DEFAULT_CONCURRENCY = 8
+
 #
 # data types
 #
@@ -62,10 +63,10 @@ class UnquantizedDataType(DataType):
     pass
 
 
-DT_F16  = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
-DT_F32  = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
-DT_I32  = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
-DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
+DT_F16  = UnquantizedDataType('F16',  dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32  = UnquantizedDataType('F32',  dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32  = UnquantizedDataType('I32',  dtype = np.dtype(np.int16),   valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16),  valid_conversions = ['F32', 'F16', 'Q8_0'])
 
 
 @dataclass(frozen=True)
@@ -151,14 +152,16 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
 
 @dataclass
 class Params:
-    n_vocab:    int
-    n_embd:     int
-    n_layer:    int
-    n_ctx:      int
-    n_ff:       int
-    n_head:     int
-    n_head_kv:  int
-    f_norm_eps: float
+    n_vocab:        int
+    n_embd:         int
+    n_layer:        int
+    n_ctx:          int
+    n_ff:           int
+    n_head:         int
+    n_head_kv:      int
+    n_experts:      int | None = None
+    n_experts_used: int | None = None
+    f_norm_eps:     float | None = None
 
     rope_scaling_type: gguf.RopeScalingType | None = None
     f_rope_freq_base: float | None = None
@@ -233,6 +236,13 @@ class Params:
             raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
 
+        n_experts      = None
+        n_experts_used = None
+
+        if "num_local_experts" in config:
+            n_experts = config["num_local_experts"]
+            n_experts_used = config["num_experts_per_tok"]
+
         return Params(
             n_vocab           = config["vocab_size"],
             n_embd            = config["hidden_size"],
@@ -241,6 +251,8 @@ class Params:
             n_ff              = config["intermediate_size"],
             n_head            = (n_head := config["num_attention_heads"]),
             n_head_kv         = config.get("num_key_value_heads", n_head),
+            n_experts         = n_experts,
+            n_experts_used    = n_experts_used,
             f_norm_eps        = config["rms_norm_eps"],
             f_rope_freq_base  = config.get("rope_theta"),
             rope_scaling_type = rope_scaling_type,
@@ -255,8 +267,15 @@ class Params:
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))
 
+        n_experts      = None
+        n_experts_used = None
+        f_rope_freq_base = None
+
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
-        if config.get("rope_theta") == 1000000:
+        if config.get("moe"):
+            # Mixtral
+            n_ctx = 32768
+        elif config.get("rope_theta") == 1000000:
             # CodeLlama
             n_ctx = 16384
         elif config["norm_eps"] == 1e-05:
@@ -266,16 +285,27 @@ class Params:
             # LLaMA v1
             n_ctx = 2048
 
+        if "layers.0.feed_forward.w1.weight" in model:
+            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+
+        if config.get("moe"):
+            n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
+            n_experts      = config["moe"]["num_experts"]
+            n_experts_used = config["moe"]["num_experts_per_tok"]
+            f_rope_freq_base = 1e6
+
         return Params(
             n_vocab          = model["tok_embeddings.weight"].shape[0],
             n_embd           = config["dim"],
             n_layer          = config["n_layers"],
             n_ctx            = n_ctx,
-            n_ff             = model["layers.0.feed_forward.w1.weight"].shape[0],
+            n_ff             = n_ff,
             n_head           = (n_head := config["n_heads"]),
             n_head_kv        = config.get("n_kv_heads", n_head),
+            n_experts        = n_experts,
+            n_experts_used   = n_experts_used,
             f_norm_eps       = config["norm_eps"],
-            f_rope_freq_base = config.get("rope_theta"),
+            f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
         )
 
     @staticmethod
@@ -832,7 +862,17 @@ class OutputFile:
         self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
         self.gguf.add_head_count          (params.n_head)
         self.gguf.add_head_count_kv       (params.n_head_kv)
-        self.gguf.add_layer_norm_rms_eps  (params.f_norm_eps)
+
+        if params.n_experts:
+            self.gguf.add_expert_count(params.n_experts)
+
+        if params.n_experts_used:
+            self.gguf.add_expert_used_count(params.n_experts_used)
+
+        if params.f_norm_eps:
+            self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
+        else:
+            raise ValueError('f_norm_eps is None')
 
         if params.f_rope_freq_base is not None:
             self.gguf.add_rope_freq_base(params.f_rope_freq_base)
@@ -956,7 +996,7 @@ class OutputFile:
 
 
 def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
-    wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type
+    wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
 
     if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
         return GGMLFileType.AllF32
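
For orientation, the Mixtral-specific fields that the converter hunks above read from the HF config.json map to GGUF metadata roughly as follows. This is a sketch expressed as a Python dict; the numeric values are the ones published for Mixtral-8x7B and are shown for illustration only.

    # Subset of a Mixtral-8x7B config.json (values assumed from the released model card).
    # Comments note the gguf writer call each key feeds in the converters above.
    mixtral_config = {
        "vocab_size":              32000,
        "hidden_size":             4096,       # add_embedding_length
        "intermediate_size":       14336,      # add_feed_forward_length
        "num_hidden_layers":       32,         # add_block_count
        "num_attention_heads":     32,         # add_head_count
        "num_key_value_heads":     8,          # add_head_count_kv
        "rms_norm_eps":            1e-05,      # add_layer_norm_rms_eps
        "rope_theta":              1000000.0,  # add_rope_freq_base ("1e6 rope_freq_base for mixtral")
        "num_local_experts":       8,          # add_expert_count
        "num_experts_per_tok":     2,          # add_expert_used_count
        "max_position_embeddings": 32768,      # used when determining n_ctx
    }
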
index 85f7a293783bed95758dd729fb49fb6c97d27017..9e1acd3f19e5f6d08e2ea1ac8335186ca980fff2 100644 (file)
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1,13 +1,15 @@
 #include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
 #include <cstddef>
 #include <cstdint>
-#include <cinttypes>
 #include <float.h>
 #include <limits>
 #include <stdint.h>
 #include <stdio.h>
-#include <atomic>
-#include <assert.h>
+#include <vector>
+
 
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -1684,31 +1686,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
 }
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
-    const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-
-    if (col >= ncols) {
+static __global__ void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
         return;
     }
 
-    const int r = y[row];
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    // copy x[r*ncols + col] to dst[row*ncols + col]
-    const int xi = r*ncols + col;
-    const int di = row*ncols + col;
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    const int ib = xi/qk; // block index
-    const int iqs = (xi%qk)/qr; // quant index
-    const int iybs = di - di%qk; // y block start index
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
-    dequantize_kernel(x, ib, iqs, v);
+    dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst[iybs + iqs + 0]        = v.x;
-    dst[iybs + iqs + y_offset] = v.y;
+    dst_row[iybs + iqs + 0]        = v.x;
+    dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+            const src0_t * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+    dst_row[i00] = src0_row[i00];
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -5053,11 +5089,69 @@ static  __global__ void im2col_f32_f16(
 }
 
 template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    (void) dst;
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(block_num_x, nrows, 1);
-    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    (void) dst;
 }
 
 template<float (*bin_op)(const float, const float)>
@@ -5069,7 +5163,6 @@ struct bin_bcast_cuda {
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-
         int nr0 = ne10/ne0;
         int nr1 = ne11/ne1;
         int nr2 = ne12/ne2;
@@ -5117,26 +5210,28 @@ struct bin_bcast_cuda {
             int64_t ne12 = cne1[2];
             int64_t ne13 = cne1[3];
 
-            //size_t nb0 = cnb0[0];
+            size_t nb0 = cnb0[0];
             size_t nb1 = cnb0[1];
             size_t nb2 = cnb0[2];
             size_t nb3 = cnb0[3];
 
-            //size_t nb10 = cnb1[0];
+            size_t nb10 = cnb1[0];
             size_t nb11 = cnb1[1];
             size_t nb12 = cnb1[2];
             size_t nb13 = cnb1[3];
 
-            //size_t s0 = nb0 / sizeof(src1_t);
+            size_t s0 = nb0 / sizeof(src1_t);
             size_t s1 = nb1 / sizeof(src1_t);
             size_t s2 = nb2 / sizeof(src1_t);
             size_t s3 = nb3 / sizeof(src1_t);
 
-            //size_t s10 = nb10 / sizeof(src1_t);
+            size_t s10 = nb10 / sizeof(src1_t);
             size_t s11 = nb11 / sizeof(src1_t);
             size_t s12 = nb12 / sizeof(src1_t);
             size_t s13 = nb13 / sizeof(src1_t);
 
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
 
             const int block_size = 128;
 
@@ -6447,36 +6542,34 @@ static void ggml_cuda_op_get_rows(
 
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
-    const int ncols = src0->ne[0];
-    const int nrows = ggml_nelements(src1);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
 
     const int32_t * src1_i32 = (const int32_t *) src1_d;
 
     switch (src0->type) {
         case GGML_TYPE_F16:
-            get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         default:
             // TODO: k-quants
@@ -8234,36 +8327,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
 }
 #endif
 
-static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 #if 0
-//#ifdef CUDA_USE_TENSOR_CORES
-//    const bool use_tensor_cores = true;
-//#else
-//    const bool use_tensor_cores = false;
-//#endif
-
     ggml_cuda_mul_mat_id_cublas(dst);
-
     // TODO: mmq/mmv support
-#else
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const int id = dst->op_params[0];
+#endif
 
-    int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
 
-    int32_t a_id;
-    CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    const struct ggml_tensor * ids = src0;
+    const int32_t id = ((int32_t *) dst->op_params)[0];
+    const int32_t n_as = ((int32_t *) dst->op_params)[1];
 
-    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
-    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+    std::vector<char> ids_host(ggml_nbytes(ids));
+
+    if (ids->backend == GGML_BACKEND_GPU) {
+        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    } else {
+        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+    }
+
+    const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+    const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
+    ggml_tensor_extra_gpu src1_row_extra;
+    ggml_tensor_extra_gpu dst_row_extra;
+
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row = *dst;
+
+    src1_row.ne[1] = 1;
+    dst_row.ne[1] = 1;
+
+    src1_row.nb[2] = src1_row.nb[1];
+    dst_row.nb[2] = dst_row.nb[1];
+
+    src1_row.nb[3] = src1_row.nb[1];
+    dst_row.nb[3] = dst_row.nb[1];
+
+    src1_row.extra = &src1_row_extra;
+    dst_row.extra = &dst_row_extra;
 
-    ggml_cuda_mul_mat(src0, src1, dst);
-#endif
 
-    (void) _src0;
-    (void) _src1;
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        //int32_t row_id;
+        //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+        const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+        GGML_ASSERT(row_id >= 0 && row_id < n_as);
+
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        src1_row.data = (char *) src1->data + i01*src1->nb[1];
+
+        dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+        dst_row.data = (char *) dst->data + i01*dst->nb[1];
+
+        ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+    }
 }
 
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9181,6 +9307,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                 }
                 return true;
             } break;
+        case GGML_OP_GET_ROWS:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                return false;
+            } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -9188,7 +9353,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_REPEAT:
-        case GGML_OP_GET_ROWS:
         case GGML_OP_DUP:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
@@ -9197,7 +9361,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
-        case GGML_OP_CPY:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
@@ -9264,7 +9427,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
     UNUSED(params);
 }
 
-extern "C" int ggml_backend_cuda_reg_devices() {
+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
     int device_count = ggml_cuda_get_device_count();
     //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
     for (int i = 0; i < device_count; i++) {
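
Conceptually, the rewritten ggml_cuda_mul_mat_id above copies the expert-id tensor to the host and then issues one ordinary matrix multiplication per token row against that row's selected expert weights. A minimal NumPy sketch of the same routing (illustrative names only, not the CUDA API):

    import numpy as np

    def mul_mat_id_rows(experts_w, ids, x):
        # experts_w: [n_as, n_out, n_in] stacked expert weight matrices
        # ids:       [n_tokens] int32, expert chosen for each token row
        # x:         [n_tokens, n_in]   per-token inputs (one src1 row each)
        out = np.empty((x.shape[0], experts_w.shape[1]), dtype=x.dtype)
        for i01, row_id in enumerate(ids):
            assert 0 <= row_id < experts_w.shape[0]   # mirrors GGML_ASSERT(row_id >= 0 && row_id < n_as)
            out[i01] = experts_w[row_id] @ x[i01]     # one regular mul_mat per row
        return out

    print(mul_mat_id_rows(np.ones((4, 3, 2)), np.array([0, 2, 1]), np.ones((3, 2))).shape)  # (3, 3)
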
index f9bd69dc84bbe78128c5b9a01917decc40b627f0..1dcfa6eddbfa570dbbec056c6370ef0140f20681 100644 (file)
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -102,6 +102,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -140,6 +155,7 @@ struct ggml_metal_context {
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f32);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
     GGML_METAL_DECL_KERNEL(sum_rows);
@@ -177,6 +193,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
             ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
         } else {
             char* buffer2 = malloc(len+1);
+            va_end(args);
+            va_start(args, format);
             vsnprintf(buffer2, len+1, format, args);
             buffer2[len] = 0;
             ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
@@ -352,6 +370,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
         if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
             GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -392,6 +425,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
         //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f32);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
         GGML_METAL_ADD_KERNEL(sum_rows);
@@ -452,6 +486,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
     if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
         GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -492,6 +541,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f32);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
     GGML_METAL_DEL_KERNEL(sum_rows);
@@ -803,8 +853,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
-        case GGML_OP_TRANSPOSE:
         case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_GET_ROWS:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
@@ -819,14 +870,38 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
         case GGML_OP_ARGSORT:
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return true;
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                        switch (op->type) {
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_F32:
+                           case GGML_TYPE_Q8_0:
+                           case GGML_TYPE_Q4_0:
+                           case GGML_TYPE_Q4_1:
+                                return true;
+                           default:
+                                return false;
+                        }
+                    case GGML_TYPE_F16:
+                        switch (op->type) {
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_F32:
+                                return true;
+                           default:
+                                return false;
+                        }
+                    default:
+                        return false;
+                };
+            }
         case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_GET_ROWS:
             {
                 return op->ne[0] % 4 == 0;
             }
@@ -1001,34 +1076,37 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                     case GGML_OP_DIV:
                         {
-                            GGML_ASSERT(ggml_is_contiguous(src0));
-                            GGML_ASSERT(ggml_is_contiguous(src1));
-
                             bool bcast_row = false;
 
                             int64_t nb = ne00;
 
-                            if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+                            id<MTLComputePipelineState> pipeline = nil;
+
+                            if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                                GGML_ASSERT(ggml_is_contiguous(src0));
+
                                 // src1 is a row
                                 GGML_ASSERT(ne11 == 1);
 
                                 nb = ne00 / 4;
                                 switch (dst->op) {
-                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
-                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
-                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+                                    case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+                                    case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+                                    case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
                                     default: GGML_ASSERT(false);
                                 }
 
                                 bcast_row = true;
                             } else {
                                 switch (dst->op) {
-                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
-                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
-                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+                                    case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+                                    case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+                                    case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
                                     default: GGML_ASSERT(false);
                                 }
                             }
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1063,7 +1141,7 @@ void ggml_metal_graph_compute(
 
                                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } else {
-                                const int nth = MIN(1024, ne0);
+                                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
@@ -1193,7 +1271,11 @@ void ggml_metal_graph_compute(
                             const float scale = ((float *) dst->op_params)[0];
 
                             [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                            if (id_src1) {
+                                [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                            } else {
+                                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
+                            }
                             [encoder setBuffer:id_dst  offset:offs_dst    atIndex:2];
                             [encoder setBytes:&ne00  length:sizeof(ne00)  atIndex:3];
                             [encoder setBytes:&ne01  length:sizeof(ne01)  atIndex:4];
@@ -1444,7 +1526,7 @@ void ggml_metal_graph_compute(
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
-                                    int64_t ny = (ne11 + nrows - 1)/nrows;
+                                    const int64_t ny = (ne11 + nrows - 1)/nrows;
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
@@ -1456,7 +1538,7 @@ void ggml_metal_graph_compute(
 
                             GGML_ASSERT(src0t == GGML_TYPE_I32);
 
-                            const int n_as = ne00;
+                            const int n_as = ((int32_t *) dst->op_params)[1];
 
                             // TODO: make this more general
                             GGML_ASSERT(n_as <= 8);
@@ -1488,14 +1570,22 @@ void ggml_metal_graph_compute(
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
-                            int ne11_mm_min = 0;
+                            int ne11_mm_min = 1;
 
                             const int idx = ((int32_t *) dst->op_params)[0];
 
+                            // batch size
+                            GGML_ASSERT(ne01 == ne11);
+
+                            const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                                ne11 > ne11_mm_min) {
+                            // !!!
+                            // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+                            //       indirect matrix multiplication
+                            // !!!
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
                                 switch (src2->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
@@ -1514,19 +1604,22 @@ void ggml_metal_graph_compute(
                                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:3];
-                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:4];
-                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:5];
-                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:6];
-                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
-                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
-                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
-                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
-                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
-                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
-                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
-                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:15];
+                                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:3];
+                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
+                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:5];
+                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
+                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:7];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
+                                [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:9];
+                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:10];
+                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:11];
+                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:12];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
+                                [encoder setBytes:&_ne1    length:sizeof(_ne1) atIndex:14];
+                                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:15];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:16];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17];
+                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:18];
                                 // TODO: how to make this an array? read Metal docs
                                 for (int j = 0; j < n_as; ++j) {
                                     struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1534,11 +1627,157 @@ void ggml_metal_graph_compute(
                                     size_t offs_src_cur = 0;
                                     id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
 
-                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                                 }
 
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+                                // TODO: processing one row at a time (ne11 -> 1) is not efficient
+                                [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            } else {
+                                int nth0 = 32;
+                                int nth1 = 1;
+                                int nrows = 1;
+                                //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                                // use custom matrix x vector kernel
+                                switch (src2t) {
+                                    case GGML_TYPE_F32:
+                                        {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                                        } break;
+                                    case GGML_TYPE_F16:
+                                        {
+                                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                            nth0 = 32;
+                                            nth1 = 1;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_1:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_1:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                                        } break;
+                                    case GGML_TYPE_Q8_0:
+                                        {
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q2_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q3_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q4_K:
+                                        {
+                                            nth0 = 4; //1;
+                                            nth1 = 8; //32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                                        } break;
+                                    case GGML_TYPE_Q6_K:
+                                        {
+                                            nth0 = 2;
+                                            nth1 = 32;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                                        } break;
+                                    default:
+                                        {
+                                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                            GGML_ASSERT(false && "not implemented");
+                                        }
+                                };
+
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                                [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                                [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                                [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+                                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+                                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+                                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+                                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+                                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+                                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+                                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
+                                [encoder setBytes:&r2   length:sizeof(r2)   atIndex:20];
+                                [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
+                                [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
+                                // TODO: how to make this an array? read Metal docs
+                                for (int j = 0; j < n_as; ++j) {
+                                    struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                                    size_t offs_src_cur = 0;
+                                    id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+                                }
+
+                                if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+                                    src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                                    src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                                }
+                                else if (src2t == GGML_TYPE_Q5_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src2t == GGML_TYPE_Q6_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                } else {
+                                    const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
                             }
                         } break;
                     case GGML_OP_GET_ROWS:
@@ -1559,16 +1798,19 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:5];
-
-                            const int64_t n = ggml_nelements(src1);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+                            [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+                            [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+                            [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+
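+                            // one threadgroup per (i10, i11) entry of the index tensor src1; its 32 threads
+                            // stride across the ne00 elements of the src0 row selected by that entry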
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                         } break;
                     case GGML_OP_RMS_NORM:
                         {
@@ -1813,7 +2055,7 @@ void ggml_metal_graph_compute(
                                     {
                                         switch (dstt) {
                                             case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
index 2f8ea22d66226040f9b9b1de4b70466994f6ff94..773fac124b0c486f21e328fc9f0cb2ea99289e47 100644 (file)
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -347,9 +347,9 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 ? src1                                      + i01*ne00 : nullptr;
-    device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 =         src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 != src0 ? src1                               + i01*ne00 : nullptr;
+    device       float * pdst  =         dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
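+    // note: src1 == src0 signals "no mask"; the host presumably binds src0 in the mask
+    // slot when no mask tensor is provided, so a null check can no longer be used here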
 
     // parallel max
     float lmax = -INFINITY;
@@ -385,7 +385,12 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
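+    // (threadgroup_barrier with mem_none is an execution-only barrier: it synchronizes the
+    // threads of the threadgroup without imposing any memory ordering)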
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -428,9 +433,9 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 =        (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
-    device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 =                (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
+    device       float4 * pdst4 =                (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
     float4 lmax4 = -INFINITY;
@@ -468,7 +473,13 @@ kernel void kernel_soft_max_4(
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
     float sum = simd_sum(lsum);
+
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
             buf[tiisg] = 0.0f;
@@ -731,7 +742,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 //      guard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(
+void mul_vec_q_n_f32_impl(
         device const void  * src0,
         device const float * src1,
         device       float * dst,
@@ -813,7 +824,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -832,7 +843,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -851,7 +862,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -870,28 +881,28 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mv_q8_0_f32(
+void kernel_mul_mv_q8_0_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     const int nr  = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw  = N_SIMDWIDTH;
@@ -945,9 +956,29 @@ kernel void kernel_mul_mv_q8_0_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
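+
+// The mul-mv kernel bodies are split into plain *_impl functions so that the indirect
+// (kernel_mul_mv_id_*) kernels further below can reuse them; the [[host_name(...)]]
+// wrappers preserve the kernel names that the host code looks up.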
+
 #define N_F32_F32 4
 
-kernel void kernel_mul_mv_f32_f32(
+void kernel_mul_mv_f32_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -965,8 +996,8 @@ kernel void kernel_mul_mv_f32_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1025,6 +1056,32 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 #define N_F16_F16 4
 
 kernel void kernel_mul_mv_f16_f16(
@@ -1105,7 +1162,7 @@ kernel void kernel_mul_mv_f16_f16(
     }
 }
 
-kernel void kernel_mul_mv_f16_f32_1row(
+void kernel_mul_mv_f16_f32_1row_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1123,8 +1180,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1161,12 +1218,37 @@ kernel void kernel_mul_mv_f16_f32_1row(
             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         }
     }
+}
 
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
 }
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mv_f16_f32(
+void kernel_mul_mv_f16_f32_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1184,8 +1266,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1244,6 +1326,32 @@ kernel void kernel_mul_mv_f16_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
 // Assumes row size (ne00) is a multiple of 4
 kernel void kernel_mul_mv_f16_f32_l4(
         device const  char * src0,
@@ -1601,8 +1709,8 @@ template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_ar
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
 
 kernel void kernel_cpy_f16_f16(
-        device const half * src0,
-        device       half * dst,
+        device  const half * src0,
+        device        half * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -1641,6 +1749,47 @@ kernel void kernel_cpy_f16_f16(
     }
 }
 
+kernel void kernel_cpy_f16_f32(
+        device  const half * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device        half * dst,
@@ -2064,19 +2213,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mv_q2_K_f32(
+void kernel_mul_mv_q2_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2214,8 +2363,8 @@ kernel void kernel_mul_mv_q2_K_f32(
     }
 }
 
-#if QK_K == 256
-kernel void kernel_mul_mv_q3_K_f32(
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2229,8 +2378,29 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant   uint    & r2  [[buffer(17)]],
         constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
@@ -2373,19 +2543,19 @@ kernel void kernel_mul_mv_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mv_q3_K_f32(
+void kernel_mul_mv_q3_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2450,20 +2620,41 @@ kernel void kernel_mul_mv_q3_K_f32(
 }
 #endif
 
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 #if QK_K == 256
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01 [[buffer(4)]],
-        constant   int64_t & ne02 [[buffer(5)]],
-        constant   int64_t & ne10 [[buffer(9)]],
-        constant   int64_t & ne12 [[buffer(11)]],
-        constant   int64_t & ne0  [[buffer(15)]],
-        constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2564,21 +2755,21 @@ kernel void kernel_mul_mv_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int ix = tiisg/4;  // 0...7
@@ -2660,7 +2851,8 @@ kernel void kernel_mul_mv_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mv_q5_K_f32(
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2677,6 +2869,26 @@ kernel void kernel_mul_mv_q5_K_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
@@ -2836,10 +3048,10 @@ kernel void kernel_mul_mv_q5_K_f32(
             dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         }
     }
-
 }
 
-kernel void kernel_mul_mv_q6_K_f32(
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2853,8 +3065,28 @@ kernel void kernel_mul_mv_q6_K_f32(
         constant   uint    & r2  [[buffer(17)]],
         constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -2945,6 +3177,27 @@ kernel void kernel_mul_mv_q6_K_f32(
     }
 }
 
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
 //============================= templates and their specializations =============================
 
 // NOTE: this is not dequantizing - we are simply fitting the template
@@ -3219,22 +3472,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows(
         device const  void * src0,
-        device const   int * src1,
+        device const  char * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
         constant  uint64_t & nb1,
-        uint                 tgpig[[threadgroup_position_in_grid]],
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
-        uint                 tptg[[threads_per_threadgroup]]) {
-    const int i = tgpig;
-    const int r = ((device int32_t *) src1)[i];
+        uint3                tptg [[threads_per_threadgroup]]) {
+    //const int64_t i = tgpig;
+    //const int64_t r = ((device int32_t *) src1)[i];
 
-    for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
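+    // src1 is treated as a 2D tensor of int32 row ids: (i10, i11) addresses one id, which
+    // selects the src0 row, while i11 additionally selects the src0 slice (i02); this is
+    // what generalizes get_rows to [ne10, ne11] indexing over a (possibly 3D) src0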
+
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         float4x4 temp;
         dequantize_func(
-            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+kernel void kernel_get_rows_f32(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        device const  void * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
     }
 }
 
@@ -3426,19 +3747,22 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const int32_t * ids,
+        device const   uchar * ids,
         device const   uchar * src1,
-        device         float * dst,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant     int64_t & nb01,
         constant     int64_t & nb02,
         constant     int64_t & ne12,
+        constant     int64_t & ne13,
         constant     int64_t & nb10,
         constant     int64_t & nb11,
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
+        constant     int64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -3456,10 +3780,16 @@ kernel void kernel_mul_mm_id(
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
     device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
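+    // tgpig.z packs an extra leading dimension: bid walks the rows of the ids matrix (one
+    // row of expert ids per src1 row), idx selects the slot within that row, and the
+    // remainder carries the usual ne12*ne13 broadcast dims; src1 and dst are then offset
+    // per bid so that each row is multiplied by its own selected expert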
+
     kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[ids[idx]],
-        src1,
-        dst,
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
         ne00,
         ne02,
         nb01,
@@ -3484,17 +3814,26 @@ kernel void kernel_mul_mm_id(
 #define QK_NL 4
 #endif
 
+//
+// get rows
+//
+
 typedef void (get_rows_t)(
         device const void * src0,
-        device const  int * src1,
+        device const char * src1,
         device      float * dst,
         constant  int64_t & ne00,
         constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant  int64_t & ne10,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
         constant uint64_t & nb1,
-        uint, uint, uint);
+        constant uint64_t & nb2,
+        uint3, uint, uint3);
 
-template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
-template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
+//template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
 template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
@@ -3506,6 +3845,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
+//
+// matrix-matrix multiplication
+//
+
 typedef void (mat_mm_t)(
         device const  uchar * src0,
         device const  uchar * src1,
@@ -3538,20 +3881,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
 
+//
+// indirect matrix-matrix multiplication
+//
+
 typedef void (mat_mm_id_t)(
-        device const int32_t * ids,
+        device const   uchar * ids,
         device const   uchar * src1,
-        device         float * dst,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant     int64_t & nb01,
         constant     int64_t & nb02,
         constant     int64_t & ne12,
+        constant     int64_t & ne13,
         constant     int64_t & nb10,
         constant     int64_t & nb11,
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
+        constant     int64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -3578,3 +3928,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
 template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
 template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
+
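+// Each kernel_mul_mv_id_* wrapper below follows the same pattern: decode the batch row (bid)
+// from tgpig.z, look up the selected expert id in the ids matrix, offset src1/dst by bid, and
+// forward to the corresponding *_impl (or mul_vec_q_n_f32_impl) matrix-vector function.
+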
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f32_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_f16_f32_impl(
+        src0[id],
+        src1 + bid*nb11,
+        (device float *) (dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q8_0_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q2_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q3_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q4_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q5_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+        device const    char * ids,
+        device const    char * src1,
+        device         uchar * dst,
+        constant     int64_t & nbi1,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant     int64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        constant         int & idx,
+        device const    char * src00,
+        device const    char * src01,
+        device const    char * src02,
+        device const    char * src03,
+        device const    char * src04,
+        device const    char * src05,
+        device const    char * src06,
+        device const    char * src07,
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    const int64_t bid = tgpig.z/(ne12*ne13);
+
+    tgpig.z = tgpig.z%(ne12*ne13);
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+    kernel_mul_mv_q6_K_f32_impl(
+        src0[id],
+        (device const float *) (src1 + bid*nb11),
+        (device       float *) ( dst + bid*nb1),
+        ne00,
+        ne01,
+        ne02,
+        ne10,
+        ne12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg,
+        sgitg);
+}
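All of the kernel_mul_mv_id_* wrappers above share the same preamble: recover the batch index from tgpig.z, read the routed expert id out of the ids tensor, and hand the matching srcNN buffer to the underlying *_impl kernel. A minimal stand-alone C++ sketch of that lookup (the helper name and free-standing form are illustrative, not part of the commit):

    #include <cstdint>

    // Sketch of the per-batch expert lookup performed by each kernel_mul_mv_id_* wrapper above.
    static const char * select_expert_src(
            const char * const src0[8],  // src00..src07: the per-expert weight tensors
            const char *       ids,      // I32 expert indices, one row per src1 row
            int64_t            nbi1,     // row stride of ids, in bytes
            int                idx,      // which selected-expert slot this launch handles
            int64_t            bid) {    // batch/row index recovered from tgpig.z
        const int32_t id = ((const int32_t *) (ids + bid*nbi1))[idx];
        return src0[id];                 // this expert's data is then passed as src0 to the impl
    }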
diff --git a/ggml.c b/ggml.c
index eb7989dc45cefed7f51089148c0be21911fa2422..66658ff4b24cef452df166f68aeac763106460ba 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4075,17 +4075,18 @@ struct ggml_tensor * ggml_mul_mat(
 
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
-        struct ggml_tensor  * as[],
+        struct ggml_tensor  * const as[],
+        int                   n_as,
         struct ggml_tensor  * ids,
         int                   id,
         struct ggml_tensor  * b) {
 
-    int64_t n_as = ids->ne[0];
-
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_is_vector(ids));
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
+    GGML_ASSERT(ids->ne[1] == b->ne[1]);
+    GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
     GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
-    GGML_ASSERT(id >= 0 && id < n_as);
+    GGML_ASSERT(id >= 0 && id < ids->ne[0]);
 
     bool is_node = false;
 
@@ -4097,13 +4098,14 @@ struct ggml_tensor * ggml_mul_mat_id(
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
 
     ggml_set_op_params_i32(result, 0, id);
+    ggml_set_op_params_i32(result, 1, n_as);
 
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = ids;
     result->src[1] = b;
 
-    for (int64_t i = 0; i < n_as; i++) {
+    for (int i = 0; i < n_as; i++) {
         struct ggml_tensor * a = as[i];
         GGML_ASSERT(ggml_are_same_shape(as[0], a));
         GGML_ASSERT(ggml_can_mul_mat(a, b));
@@ -4731,7 +4733,9 @@ struct ggml_tensor * ggml_get_rows(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
 
     bool is_node = false;
 
@@ -4741,7 +4745,7 @@ struct ggml_tensor * ggml_get_rows(
 
     // TODO: implement non F32 return
     //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
-    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
 
     result->op   = GGML_OP_GET_ROWS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -9504,8 +9508,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
+    // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
+    //       all the experts for each batch element and the processing would become incredibly slow
     // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
+    if (dst->op != GGML_OP_MUL_MAT_ID &&
+        ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
       //src0->type == GGML_TYPE_F32 &&
         src1->type == GGML_TYPE_F32 &&
@@ -9519,11 +9526,16 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
+// off1 = offset in i11 and i1
+// cne1 = ne11 and ne1
+// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
+// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              int64_t off1, int64_t cne1) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
@@ -9591,10 +9603,9 @@ static void ggml_compute_forward_mul_mat(
                 const int64_t i03 = i13/r3;
                 const int64_t i02 = i12/r2;
 
-                const void  * x = (char *)            src0->data + i02*nb02 + i03*nb03;
-                const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-
-                float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+                const void  * x = (char *)            src0->data +             i02*nb02 + i03*nb03;
+                const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
+                      float * d = (float *) ((char *)  dst->data + off1*nb1  + i12*nb2  + i13*nb3);
 
                 if (type != GGML_TYPE_F32) {
                             float * const wdata    = params->wdata;
@@ -9611,10 +9622,10 @@ static void ggml_compute_forward_mul_mat(
                 }
 
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne00,
-                        0.0f,    d, ne01);
+                         cne1, ne01, ne10,
+                         1.0f,    y, ne10,
+                                  x, ne00,
+                         0.0f,    d, ne01);
             }
         }
 
@@ -9630,6 +9641,7 @@ static void ggml_compute_forward_mul_mat(
             const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
             assert(params->wsize >= ne11*ne12*ne13*row_size);
+            assert(src1->type == GGML_TYPE_F32);
 
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -9652,7 +9664,7 @@ static void ggml_compute_forward_mul_mat(
     const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
     const int64_t nr0 = ne01;           // src0 rows
-    const int64_t nr1 = ne11*ne12*ne13; // src1 rows
+    const int64_t nr1 = cne1*ne12*ne13; // src1 rows
 
     //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
 
@@ -9694,9 +9706,9 @@ static void ggml_compute_forward_mul_mat(
     for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
         for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                const int64_t i13 = (ir1/(ne12*ne11));
-                const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
-                const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+                const int64_t i13 = (ir1/(ne12*cne1));
+                const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+                const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
 
                 // broadcast src0 into src1
                 const int64_t i03 = i13/r3;
@@ -9736,20 +9748,28 @@ static void ggml_compute_forward_mul_mat(
 
 static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int id = ggml_get_op_params_i32(dst, 0);
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
+        ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
+        return;
+    }
 
-    const int a_id = ((int32_t *)ids->data)[id];
+    const struct ggml_tensor * ids = src0;
+    const int id   = ggml_get_op_params_i32(dst, 0);
+    const int n_as = ggml_get_op_params_i32(dst, 1);
 
-    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+    for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+        const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
 
-    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+        GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
-    ggml_compute_forward_mul_mat(params, src0, src1, dst);
+        const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+        ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+    }
 }
 
 // ggml_compute_forward_out_prod
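The new off1/cne1 arguments let GGML_OP_MUL_MAT_ID reuse the regular mat-mul path one src1 row at a time: for token i01 the routed expert's matrix is applied with off1 = i01 and cne1 = 1, while a plain GGML_OP_MUL_MAT passes off1 = 0 and cne1 = ne1. A stand-alone sketch of the dispatch (illustrative values; the ggml calls appear only in comments):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int n_tokens = 3;
        const int id = 0; // op param 0: which selected-expert slot this op reads
        // per-token expert choices, n_expert_used = 2 (illustrative values)
        const int32_t ids[3][2] = {{3, 1}, {0, 2}, {3, 7}};

        for (int i01 = 0; i01 < n_tokens; ++i01) {
            const int32_t row_id = ids[i01][id];
            // corresponds to: ggml_compute_forward_mul_mat(params, dst->src[row_id + 2], src1, dst,
            //                                              /*off1=*/i01, /*cne1=*/1);
            std::printf("token %d -> expert %d (dst->src[%d]), off1 = %d, cne1 = 1\n",
                        i01, (int) row_id, (int) row_id + 2, i01);
        }
        return 0;
    }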
@@ -10325,21 +10345,30 @@ static void ggml_compute_forward_get_rows_q(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == ggml_type_size(type));
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
+    assert(ggml_nrows(dst) == nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        dequantize_row_q(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
-                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+                dequantize_row_q(
+                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+            }
+        }
     }
 }
 
 
@@ -10354,19 +10383,26 @@ static void ggml_compute_forward_get_rows_f16(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
+    assert(ggml_nrows(dst) == nr);
 
-        for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                ggml_fp16_to_fp32_row(
+                        (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                             (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+            }
         }
     }
 }
@@ -10382,19 +10418,27 @@ static void ggml_compute_forward_get_rows_f32(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    assert( dst->ne[0] == nc);
-    assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == sizeof(float));
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(dst) == nr);
 
-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *)  dst->data + i*dst->nb[1]),
-                (float *) ((char *) src0->data + r*src0->nb[1]));
+    // TODO: multi-thread
+    for (int64_t i12 = 0; i12 < ne12; ++i12) {
+        for (int64_t i11 = 0; i11 < ne11; ++i11) {
+            for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                ggml_vec_cpy_f32(nc,
+                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+            }
+        }
     }
 }
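With this change ggml_get_rows accepts a 2D/3D I32 index tensor and gathers each row from the matching slice of src0, which is what lets the MoE graph pick one routing weight per (expert slot, token) pair. A hedged sketch of the resulting shapes (the helper name is illustrative):

    #include "ggml.h"

    // src0: [ne00, ne01, ne02] with ne02 == rows->ne[1]; rows: I32 [ne10, ne11(, ne12)]
    // dst : F32 [ne00, ne10, ne11, ne12] -- entry (i10, i11, i12) is row rows[i10, i11, i12]
    //       of slice (i11, i12) of src0, converted/dequantized to F32 by the loops above.
    static struct ggml_tensor * gather_rows_3d(struct ggml_context * ctx,
            struct ggml_tensor * a, struct ggml_tensor * rows) {
        return ggml_get_rows(ctx, a, rows);
    }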
 
@@ -14037,11 +14081,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
-                ggml_compute_forward_mul_mat_id(params, tensor);
+                ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_OUT_PROD:
             {
diff --git a/ggml.h b/ggml.h
index 41a075e92d506d7de3b2129c8081b68d1bb39db4..32f256481615ce3703d08997bfff022a74283cff 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_MAX_DIMS           4
 #define GGML_MAX_PARAMS         2048
 #define GGML_MAX_CONTEXTS       64
-#define GGML_MAX_SRC            6
+#define GGML_MAX_SRC            10
 #define GGML_MAX_NAME           64
 #define GGML_MAX_OP_PARAMS      64
 #define GGML_DEFAULT_N_THREADS  4
@@ -1051,7 +1051,8 @@ extern "C" {
     //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
-            struct ggml_tensor  * as[],
+            struct ggml_tensor  * const as[],
+            int                   n_as,
             struct ggml_tensor  * ids,
             int                   id,
             struct ggml_tensor  * b);
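A hedged usage sketch of the updated signature: the expert array now comes with an explicit n_as count, and ids is a 2D tensor holding one set of expert indices per column of b (the wrapper name and shapes are illustrative, not part of the diff):

    #include "ggml.h"

    // with the 2D ids introduced here, column t of the result uses as[ids[t][i]] * b column t
    static struct ggml_tensor * moe_matmul_one_slot(struct ggml_context * ctx,
            struct ggml_tensor * const experts[], int n_expert, // n_expert 2D weight matrices
            struct ggml_tensor * ids,                           // I32 [n_expert_used, n_tokens]
            int i,                                              // which selected slot to apply
            struct ggml_tensor * b) {                           // F32 [n_embd, n_tokens]
        return ggml_mul_mat_id(ctx, experts, n_expert, ids, i, b);
    }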
@@ -1263,6 +1264,7 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // supports 3D: a->ne[2] == b->ne[1]
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 685c88f1a33972035e1f39e2319d3f227977f73a..12133882be2c4065f20091be9fe0ea8c3ae73adc 100644 (file)
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -38,6 +38,8 @@ class Keys:
         FEED_FORWARD_LENGTH   = "{arch}.feed_forward_length"
         USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
         TENSOR_DATA_LAYOUT    = "{arch}.tensor_data_layout"
+        EXPERT_COUNT          = "{arch}.expert_count"
+        EXPERT_USED_COUNT     = "{arch}.expert_used_count"
 
     class Attention:
         HEAD_COUNT        = "{arch}.attention.head_count"
@@ -111,10 +113,14 @@ class MODEL_TENSOR(IntEnum):
     ATTN_NORM       = auto()
     ATTN_NORM_2     = auto()
     ATTN_ROT_EMBD   = auto()
+    FFN_GATE_INP    = auto()
+    FFN_NORM        = auto()
     FFN_GATE        = auto()
     FFN_DOWN        = auto()
     FFN_UP          = auto()
-    FFN_NORM        = auto()
+    FFN_GATE_EXP    = auto()
+    FFN_DOWN_EXP    = auto()
+    FFN_UP_EXP      = auto()
     ATTN_Q_NORM     = auto()
     ATTN_K_NORM     = auto()
 
@@ -154,10 +160,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ATTN_ROT_EMBD:   "blk.{bid}.attn_rot_embd",
     MODEL_TENSOR.ATTN_Q_NORM:     "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM:     "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.FFN_GATE_INP:    "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_NORM:        "blk.{bid}.ffn_norm",
     MODEL_TENSOR.FFN_GATE:        "blk.{bid}.ffn_gate",
     MODEL_TENSOR.FFN_DOWN:        "blk.{bid}.ffn_down",
     MODEL_TENSOR.FFN_UP:          "blk.{bid}.ffn_up",
+    MODEL_TENSOR.FFN_GATE_EXP:    "blk.{bid}.ffn_gate.{xid}",
+    MODEL_TENSOR.FFN_DOWN_EXP:    "blk.{bid}.ffn_down.{xid}",
+    MODEL_TENSOR.FFN_UP_EXP:      "blk.{bid}.ffn_up.{xid}",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -172,10 +182,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
     ],
     MODEL_ARCH.GPTNEOX: [
         MODEL_TENSOR.TOKEN_EMBD,
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index b8ec977c8f3fa52cdd4aab415fb8579c637f5e6e..73e02160750b252140d247e8b018904a31e21fe8 100644 (file)
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -339,6 +339,12 @@ class GGUFWriter:
     def add_clamp_kqv(self, value: float) -> None:
         self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
 
+    def add_expert_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
+
+    def add_expert_used_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index cc6236014eb723c55c3f5969f2eb2d0411904b79..0115ea1c605b1f796ba870fd00522a0bf7414a91 100644 (file)
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -149,6 +149,11 @@ class TensorNameMap:
             "model.layers.{bid}.ln2",                                        # yi
         ),
 
+        MODEL_TENSOR.FFN_GATE_INP: (
+            "layers.{bid}.feed_forward.gate",           # mixtral
+            "model.layers.{bid}.block_sparse_moe.gate", # mixtral
+        ),
+
         # Feed-forward up
         MODEL_TENSOR.FFN_UP: (
             "gpt_neox.layers.{bid}.mlp.dense_h_to_4h",                # gptneox
@@ -164,11 +169,21 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.w1",                             # qwen
         ),
 
+        MODEL_TENSOR.FFN_UP_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w3",           # mixtral
+            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
+        ),
+
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
-            "model.layers.{bid}.mlp.gate_proj",  # llama-hf refact
-            "layers.{bid}.feed_forward.w1",      # llama-pth
-            "transformer.h.{bid}.mlp.w2",        # qwen
+            "model.layers.{bid}.mlp.gate_proj",           # llama-hf refact
+            "layers.{bid}.feed_forward.w1",               # llama-pth
+            "transformer.h.{bid}.mlp.w2",                 # qwen
+        ),
+
+        MODEL_TENSOR.FFN_GATE_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w1",           # mixtral
+            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral
         ),
 
         # Feed-forward down
@@ -185,6 +200,11 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
         ),
 
+        MODEL_TENSOR.FFN_DOWN_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w2",           # mixtral
+            "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral
+        ),
+
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
         ),
@@ -213,11 +233,14 @@ class TensorNameMap:
             for tensor, keys in self.block_mappings_cfg.items():
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
-                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
-                self.mapping[tensor_name] = (tensor, tensor_name)
-                for key in keys:
-                    key = key.format(bid = bid)
-                    self.mapping[key] = (tensor, tensor_name)
+                # TODO: make this configurable
+                n_experts = 8
+                for xid in range(n_experts):
+                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
+                    self.mapping[tensor_name] = (tensor, tensor_name)
+                    for key in keys:
+                        key = key.format(bid = bid, xid = xid)
+                        self.mapping[key] = (tensor, tensor_name)
 
     def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
         result = self.mapping.get(key)
diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml
index e6374bfe898a4b86a55e0a06c2e8768fdb94707d..9789c2c877165edc73697e38c71e10edebfb0dd9 100644 (file)
--- a/gguf-py/pyproject.toml
+++ b/gguf-py/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.6.0"
+version = "0.7.0"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML <ggml@ggml.ai>"]
 packages = [
diff --git a/llama.cpp b/llama.cpp
index 54fa9e43eb605c10859fa141ca396717417d7105..0e5ab044cdfa4fa8c471d76846726abeaa4bafbc 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -91,7 +91,8 @@
 #define LLAMA_ATTRIBUTE_FORMAT(...)
 #endif
 
-#define LLAMA_MAX_NODES 8192
+#define LLAMA_MAX_NODES   8192
+#define LLAMA_MAX_EXPERTS 8
 
 //
 // logging
@@ -231,6 +232,8 @@ enum llm_kv {
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
+    LLM_KV_EXPERT_COUNT,
+    LLM_KV_EXPERT_USED_COUNT,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -281,6 +284,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_FEED_FORWARD_LENGTH,           "%s.feed_forward_length"   },
     { LLM_KV_USE_PARALLEL_RESIDUAL,         "%s.use_parallel_residual" },
     { LLM_KV_TENSOR_DATA_LAYOUT,            "%s.tensor_data_layout"    },
+    { LLM_KV_EXPERT_COUNT,                  "%s.expert_count"          },
+    { LLM_KV_EXPERT_USED_COUNT,             "%s.expert_used_count"     },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,          "%s.attention.head_count"             },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,       "%s.attention.head_count_kv"          },
@@ -338,10 +343,14 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_NORM,
     LLM_TENSOR_ATTN_NORM_2,
     LLM_TENSOR_ATTN_ROT_EMBD,
+    LLM_TENSOR_FFN_GATE_INP,
+    LLM_TENSOR_FFN_NORM,
     LLM_TENSOR_FFN_GATE,
     LLM_TENSOR_FFN_DOWN,
     LLM_TENSOR_FFN_UP,
-    LLM_TENSOR_FFN_NORM,
+    LLM_TENSOR_FFN_DOWN_EXP,
+    LLM_TENSOR_FFN_GATE_EXP,
+    LLM_TENSOR_FFN_UP_EXP,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
 };
@@ -360,10 +369,14 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
             { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
             { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
             { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
             { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
         },
     },
     {
@@ -585,6 +598,10 @@ struct LLM_TN {
     std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
         return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix;
     }
+
+    std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const {
+        return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix;
+    }
 };
 
 //
@@ -1164,6 +1181,8 @@ struct llama_hparams {
     uint32_t n_layer;
     uint32_t n_rot;
     uint32_t n_ff;
+    uint32_t n_expert = 0;
+    uint32_t n_expert_used = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -1178,15 +1197,18 @@ struct llama_hparams {
     float f_max_alibi_bias;
 
     bool operator!=(const llama_hparams & other) const {
-        if (this->vocab_only  != other.vocab_only)  return true;
-        if (this->n_vocab     != other.n_vocab)     return true;
-        if (this->n_ctx_train != other.n_ctx_train) return true;
-        if (this->n_embd      != other.n_embd)      return true;
-        if (this->n_head      != other.n_head)      return true;
-        if (this->n_head_kv   != other.n_head_kv)   return true;
-        if (this->n_layer     != other.n_layer)     return true;
-        if (this->n_rot       != other.n_rot)       return true;
-        if (this->n_ff        != other.n_ff)        return true;
+        if (this->vocab_only    != other.vocab_only)    return true;
+        if (this->n_vocab       != other.n_vocab)       return true;
+        if (this->n_ctx_train   != other.n_ctx_train)   return true;
+        if (this->n_embd        != other.n_embd)        return true;
+        if (this->n_head        != other.n_head)        return true;
+        if (this->n_head_kv     != other.n_head_kv)     return true;
+        if (this->n_layer       != other.n_layer)       return true;
+        if (this->n_rot         != other.n_rot)         return true;
+        if (this->n_ff          != other.n_ff)          return true;
+        if (this->n_expert      != other.n_expert)      return true;
+        if (this->n_expert_used != other.n_expert_used) return true;
+
         if (this->rope_finetuned  != other.rope_finetuned)  return true;
         if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
@@ -1268,6 +1290,12 @@ struct llama_layer {
     struct ggml_tensor * ffn_down; // w2
     struct ggml_tensor * ffn_up;   // w3
 
+    // ff MoE
+    struct ggml_tensor * ffn_gate_inp;
+    struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS];
+    struct ggml_tensor * ffn_up_exp  [LLAMA_MAX_EXPERTS];
+
     // ff bias
     struct ggml_tensor * ffn_down_b; // b2
     struct ggml_tensor * ffn_up_b;   // b3
@@ -2440,6 +2468,16 @@ static void llm_load_hparams(
     ml.get_key  (LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
     ml.get_key  (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
     ml.get_key  (LLM_KV_BLOCK_COUNT,          hparams.n_layer);
+    ml.get_key  (LLM_KV_EXPERT_COUNT,         hparams.n_expert,      false);
+    ml.get_key  (LLM_KV_EXPERT_USED_COUNT,    hparams.n_expert_used, false);
+
+    GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
+    GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
+    if (hparams.n_expert > 0) {
+        GGML_ASSERT(hparams.n_expert_used > 0);
+    } else {
+        GGML_ASSERT(hparams.n_expert_used == 0);
+    }
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
@@ -2871,6 +2909,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
+    LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
     LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type.c_str());
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
@@ -3025,9 +3065,26 @@ static void llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
 
-                        layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, backend_split);
-                        layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
-                        layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+                        layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false);
+
+                        if (layer.ffn_gate_inp == nullptr) {
+                            GGML_ASSERT(hparams.n_expert      == 0);
+                            GGML_ASSERT(hparams.n_expert_used == 0);
+
+                            layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, backend_split);
+                            layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
+                            layer.ffn_up   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+                        } else {
+                            GGML_ASSERT(hparams.n_expert      > 0);
+                            GGML_ASSERT(hparams.n_expert_used > 0);
+
+                            // MoE branch
+                            for (uint32_t x = 0; x < hparams.n_expert; ++x) {
+                                layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd,   n_ff}, backend_split);
+                                layer.ffn_down_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), {  n_ff, n_embd}, backend_split);
+                                layer.ffn_up_exp[x]   = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), {n_embd,   n_ff}, backend_split);
+                            }
+                        }
 
                         if (backend == GGML_BACKEND_GPU) {
                             vram_weights +=
@@ -3037,8 +3094,18 @@ static void llm_load_tensors(
                                 (layer.bk ? ggml_nbytes(layer.bk) : 0) +
                                 (layer.bv ? ggml_nbytes(layer.bv) : 0) +
                                 (layer.bo ? ggml_nbytes(layer.bo) : 0) +
-                                ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) +
-                                ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+                                ggml_nbytes(layer.ffn_norm);
+
+                            if (layer.ffn_gate_inp == nullptr) {
+                                vram_weights +=
+                                    ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+                            } else {
+                                vram_weights += ggml_nbytes(layer.ffn_gate_inp);
+                                for (uint32_t x = 0; x < hparams.n_expert; ++x) {
+                                    vram_weights +=
+                                        ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]);
+                                }
+                            }
                         }
                     }
                 } break;
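The per-expert tensor names used in the MoE loading branch above are produced by the new two-argument LLM_TN operator, which formats the blk.%d.ffn_*.%d patterns with the layer and expert index. A small illustrative sketch of the resulting names (not part of the diff):

    #include <cstdio>
    #include <string>

    // e.g. expert_tensor_name("blk.%d.ffn_gate.%d", 0, 2) == "blk.0.ffn_gate.2.weight",
    // matching tn(LLM_TENSOR_FFN_GATE_EXP, "weight", /*bid=*/0, /*xid=*/2) above.
    static std::string expert_tensor_name(const char * pattern, int bid, int xid) {
        char buf[64];
        std::snprintf(buf, sizeof(buf), pattern, bid, xid);
        return std::string(buf) + ".weight";
    }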
@@ -4019,6 +4086,8 @@ struct llm_build_context {
     const int64_t n_head_kv;
     const int64_t n_embd_head;
     const int64_t n_embd_gqa;
+    const int64_t n_expert;
+    const int64_t n_expert_used;
 
     const float freq_base;
     const float freq_scale;
@@ -4060,6 +4129,8 @@ struct llm_build_context {
         n_head_kv     (hparams.n_head_kv),
         n_embd_head   (hparams.n_embd_head()),
         n_embd_gqa    (hparams.n_embd_gqa()),
+        n_expert      (hparams.n_expert),
+        n_expert_used (hparams.n_expert_used),
         freq_base     (cparams.rope_freq_base),
         freq_scale    (cparams.rope_freq_scale),
         ext_factor    (cparams.yarn_ext_factor),
@@ -4184,7 +4255,7 @@ struct llm_build_context {
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            {
+            if (model.layers[il].ffn_gate_inp == nullptr) {
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -4196,6 +4267,69 @@ struct llm_build_context {
                         model.layers[il].ffn_down, NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+                cb(logits, "ffn_moe_logits", il);
+
+                ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+                cb(probs, "ffn_moe_probs", il);
+
+                // select experts
+                ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+                cb(selected_experts->src[0], "ffn_moe_argsort", il);
+
+                ggml_tensor * weights = ggml_get_rows(ctx0,
+                        ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
+                cb(weights, "ffn_moe_weights", il);
+
+                weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+
+                ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+                cb(weights_sum, "ffn_moe_weights_sum", il);
+
+                weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+                cb(weights, "ffn_moe_weights_norm", il);
+
+                // compute expert outputs
+                ggml_tensor * moe_out = nullptr;
+
+                for (int i = 0; i < n_expert_used; ++i) {
+                    ggml_tensor * cur_expert;
+
+                    ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur);
+                    cb(cur_up, "ffn_moe_up", il);
+
+                    ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur);
+                    cb(cur_gate, "ffn_moe_gate", il);
+
+                    cur_gate = ggml_silu(ctx0, cur_gate);
+                    cb(cur_gate, "ffn_moe_silu", il);
+
+                    cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
+                    cb(cur_expert, "ffn_moe_gate_par", il);
+
+                    cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                    cb(cur_expert, "ffn_moe_down", il);
+
+                    cur_expert = ggml_mul(ctx0, cur_expert,
+                            ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+                    cb(cur_expert, "ffn_moe_weighted", il);
+
+                    if (i == 0) {
+                        moe_out = cur_expert;
+                    } else {
+                        moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                        cb(moe_out, "ffn_moe_out", il);
+                    }
+                }
+
+                cur = moe_out;
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
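For reference, the routing math built by this MoE block: softmax over the gate logits, keep the top n_expert_used experts, renormalize their weights to sum to 1, and accumulate the selected experts' SwiGLU outputs scaled by those weights. A stand-alone numeric sketch (illustrative values only, not model data):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
        const int n_expert = 8, n_expert_used = 2;
        // ffn_moe_logits for one token (illustrative values)
        std::vector<float> logits = {0.1f, 2.0f, -1.0f, 0.5f, 1.5f, 0.0f, -0.5f, 0.2f};

        // ffn_moe_probs: softmax over the experts
        std::vector<float> probs(n_expert);
        const float maxl = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (int e = 0; e < n_expert; ++e) { probs[e] = std::exp(logits[e] - maxl); sum += probs[e]; }
        for (float & p : probs) { p /= sum; }

        // ffn_moe_argsort / top_k: the n_expert_used most probable experts
        std::vector<int> idx(n_expert);
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + n_expert_used, idx.end(),
                [&](int a, int b) { return probs[a] > probs[b]; });

        // ffn_moe_weights_sum / ffn_moe_weights_norm: renormalize so the kept weights sum to 1
        float wsum = 0.0f;
        for (int i = 0; i < n_expert_used; ++i) { wsum += probs[idx[i]]; }
        for (int i = 0; i < n_expert_used; ++i) {
            std::printf("expert %d: weight %.3f\n", idx[i], probs[idx[i]] / wsum);
        }
        // ffn_moe_out = sum_i weight_i * ffn_down[e_i]( silu(ffn_gate[e_i](x)) * ffn_up[e_i](x) )
        return 0;
    }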
@@ -5450,6 +5584,20 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "ffn_relu",                   OFFLOAD_FUNC     },
     { "ffn_sqr(relu)",              OFFLOAD_FUNC     },
 
+    { "ffn_moe_logits",             OFFLOAD_FUNC     },
+    { "ffn_moe_probs",              OFFLOAD_FUNC     },
+    { "ffn_moe_argsort",            OFFLOAD_FUNC     },
+    { "ffn_moe_weights",            OFFLOAD_FUNC     },
+    { "ffn_moe_weights_sum",        OFFLOAD_FUNC     },
+    { "ffn_moe_weights_norm",       OFFLOAD_FUNC     },
+    { "ffn_moe_weighted",           OFFLOAD_FUNC     },
+    { "ffn_moe_up",                 OFFLOAD_FUNC     },
+    { "ffn_moe_gate",               OFFLOAD_FUNC     },
+    { "ffn_moe_silu",               OFFLOAD_FUNC     },
+    { "ffn_moe_gate_par",           OFFLOAD_FUNC     },
+    { "ffn_moe_down",               OFFLOAD_FUNC     },
+    { "ffn_moe_out",                OFFLOAD_FUNC     },
+
     { "l_out",                      OFFLOAD_FUNC     },
 
     { "result_norm",                OFFLOAD_FUNC_EMB },
@@ -8067,11 +8215,9 @@ static void llama_convert_tensor_internal(
     workers.clear();
 }
 
-static ggml_type get_k_quant_type(
-    quantize_state_internal & qs,
-    ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype
-) {
+static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
+
     // TODO: avoid hardcoded tensor names - use the TN_* constants
     const llm_arch arch = qs.model.arch;
     const auto       tn = LLM_TN(arch);
@@ -8105,7 +8251,18 @@ static ggml_type get_k_quant_type(
             // nearly negligible increase in model size by quantizing this tensor with more bits:
             if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
         }
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
         ++qs.i_attention_wv;
+    } else if (name.find("attn_k.weight") != std::string::npos) {
+        if (qs.model.hparams.n_expert == 8) {
+            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
+            // TODO: explore better strategies
+            new_type = GGML_TYPE_Q8_0;
+        }
     } else if (name.find("ffn_down.weight") != std::string::npos) {
         if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
@@ -8318,6 +8475,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= params->quantize_output_tensor || name != "output.weight";
         quantize &= !params->only_copy;
 
+        // do not quantize expert gating tensors
+        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
+
         enum ggml_type new_type;
         void * new_data;
         size_t new_size;
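A hedged sketch of the per-tensor quantization filter after this change; the MoE router (expert gating) weights are the one tensor class that is now always copied unquantized (the helper name is illustrative, only the checks visible in the hunk above are reproduced):

    #include <string>

    static bool should_quantize(const std::string & name, bool quantize_output_tensor, bool only_copy) {
        bool quantize = true;
        quantize = quantize && (quantize_output_tensor || name != "output.weight");
        quantize = quantize && !only_copy;
        // new in this commit: the expert gating tensors are copied as-is
        quantize = quantize && name.find("ffn_gate_inp.weight") == std::string::npos;
        return quantize;
    }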
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index e0155ac1c89139367a6a09e77b2f8a1cc3501ab0..44830b4d4da301a23bcf5b0fdeae89540a1a44f2 100644 (file)
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -20,8 +20,6 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
     size_t size = ggml_nelements(tensor);
     std::vector<float> data(size);
 
-    std::random_device rd;
-
 #if 0
     std::default_random_engine generator(rd());
     std::uniform_real_distribution<float> distribution(min, max);
@@ -31,6 +29,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
     }
 #endif
     auto init_thread = [&](size_t start, size_t end) {
+        std::random_device rd;
         std::default_random_engine generator(rd());
         std::uniform_real_distribution<float> distribution(min, max);
 
@@ -51,7 +50,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         t.join();
     }
 
-    if (tensor->type == GGML_TYPE_F32) {
+    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
     } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
         GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
@@ -71,23 +70,28 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
     std::vector<uint8_t> buf(ggml_nbytes(t));
     ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
 
+    ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
+    size_t bs = ggml_blck_size(t->type);
+
     // access elements by index to avoid gaps in views
     for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
         for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
             for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
-                for (int64_t i0 = 0; i0 < t->ne[0]; i0++) {
-                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0];
-                    float v;
+                for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
+                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
                     if (t->type == GGML_TYPE_F16) {
-                        v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]);
+                        tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
                     } else if (t->type == GGML_TYPE_F32) {
-                        v = *(float *) &buf[i];
+                        tv.push_back(*(float *) &buf[i]);
                     } else if (t->type == GGML_TYPE_I32) {
-                        v = *(int32_t *) &buf[i];
+                        tv.push_back((float)*(int32_t *) &buf[i]);
+                    } else if (ggml_is_quantized(t->type)) {
+                        std::vector<float> vq(ggml_blck_size(t->type));
+                        tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
+                        tv.insert(tv.end(), vq.begin(), vq.end());
                     } else {
                         GGML_ASSERT(false);
                     }
-                    tv.push_back(v);
                 }
             }
         }
@@ -233,6 +237,10 @@ static bool ggml_is_view_op(enum ggml_op op) {
 struct test_case {
     virtual ~test_case() {}
 
+    virtual std::string op_desc(ggml_tensor * t) {
+        return ggml_op_desc(t);
+    }
+
     virtual std::string vars() {
         return "";
     }
@@ -240,7 +248,7 @@ struct test_case {
     virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
 
     virtual double max_nmse_err() {
-        return 1e-6;
+        return 1e-7;
     }
 
     virtual void initialize_tensors(ggml_context * ctx) {
@@ -270,13 +278,13 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
-            //printf("  %s: skipping\n", ggml_op_desc(out));
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
         }
 
-        printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
         // check if backends support op
@@ -317,7 +325,7 @@ struct test_case {
             for (size_t i = 0; i < f1.size(); i++) {
                 // check for nans
                 if (std::isnan(f1[i]) || std::isnan(f2[i])) {
-                    printf("NaN at index %zu ", i);
+                    printf("[%s] NaN at index %zu (%f %f) ", ggml_op_desc(t1), i, f1[i], f2[i]);
                     ud->ok = false;
                     return true;
                 }
@@ -325,12 +333,12 @@ struct test_case {
                 if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
                     if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                         if (std::signbit(f1[i]) != std::signbit(f2[i])) {
-                            printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
+                            printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
                             ud->ok = false;
                             return true;
                         }
                     } else {
-                        printf("inf mismatch: %f %f ", f1[i], f2[i]);
+                        printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
                         ud->ok = false;
                         return true;
                     }
@@ -339,10 +347,16 @@ struct test_case {
 
             double err = nmse(f1.data(), f2.data(), f1.size());
             if (err > ud->max_err) {
-                printf("NMSE = %f ", err);
+                printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+                //for (int i = 0; i < f1.size(); i++) {
+                //    printf("(%f, %f) ", f1[i], f2[i]);
+                //}
+                //printf("\n");
                 ud->ok = false;
             }
             return true;
+
+            GGML_UNUSED(index);
         };
 
         ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
@@ -372,13 +386,13 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
-            //printf("  %s: skipping\n", ggml_op_desc(out));
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
         }
 
-        int len = printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        int len = printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
         // check if backends support op
@@ -430,8 +444,9 @@ struct test_case {
             return size;
         };
         for (int i = 0; i < gf->n_nodes; i++) {
-            if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out)
+            if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) {
                 continue;
+            }
             mem += tensor_op_size(gf->nodes[i]);
         }
 
@@ -486,17 +501,22 @@ struct test_get_rows : public test_case {
     const int n; // cols
     const int m; // rows
     const int r; // rows to get
+    const int b; // batch size
+    const bool v; // view (non-contiguous src1)
 
     std::string vars() override {
-        return VARS_TO_STR4(type, n, m, r);
+        return VARS_TO_STR6(type, n, m, r, b, v);
     }
 
-    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
-        : type(type), n(n), m(m), r(r) {}
+    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
+        : type(type), n(n), m(m), r(r), b(b), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
-        ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
+        ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
+        ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+        if (v) {
+            rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+        }
         ggml_tensor * out = ggml_get_rows(ctx, in, rows);
         return out;
     }
@@ -504,12 +524,13 @@ struct test_get_rows : public test_case {
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
                 // rows
-                std::vector<int> data(r);
-                for (int i = 0; i < r; i++) {
+                std::vector<int> data(r*b);
+                for (int i = 0; i < r*b; i++) {
                     data[i] = rand() % m;
                 }
-                ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
+                ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
             } else {
                 init_tensor_uniform(t);
             }
@@ -770,11 +791,10 @@ struct test_mul_mat_id : public test_case {
     const int64_t m;
     const int64_t n;
     const int64_t k;
-    const std::array<int64_t, 2> bs; // dims 3 and 4
-    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+    const bool v; // view (non-contiguous ids)
 
     std::string vars() override {
-        return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
+        return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
     }
 
     double max_nmse_err() override {
@@ -782,7 +802,7 @@ struct test_mul_mat_id : public test_case {
     }
 
     size_t op_size(ggml_tensor * t) override {
-        size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1];
+        size_t a = ggml_nbytes(t->src[2]) * n;
         size_t b = ggml_nbytes(t->src[1]) * m;
         size_t c  = ggml_nbytes(t);
         return a + b + c;
@@ -792,35 +812,41 @@ struct test_mul_mat_id : public test_case {
 
     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
             int n_mats = 2, int id = 0,
-            int64_t m = 32, int64_t n = 32, int64_t k = 32,
-            std::array<int64_t, 2> bs = {10, 10},
-            std::array<int64_t, 2> nr = {2, 2})
+            int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
         : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
-            m(m), n(n), k(k), bs(bs), nr(nr) {}
+            m(m), n(n), k(k), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         std::vector<ggml_tensor *> mats;
         for (int i = 0; i < n_mats; i++) {
-            ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m);
             mats.push_back(a);
         }
-        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
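+        // ids holds one row of n_mats candidate expert indices per column of b; the
+        // id argument selects which candidate is used. When v is set, a non-contiguous
+        // view keeps only the first n_mats/2 ids, so the requested id is halved below.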
+        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+        if (v) {
+            ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
+        }
+        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
+        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b);
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
                 // ids
-                std::vector<int> data(n_mats);
-                for (int i = 0; i < n_mats; i++) {
-                    data[i] = i;
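+                // fill every row of the id tensor with a shuffled permutation of 0..n_mats-1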
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i % n_mats;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
                 }
-                std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
-                ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
             } else {
                 init_tensor_uniform(t);
             }
@@ -1109,6 +1135,90 @@ struct test_sum_rows : public test_case {
     }
 };
 
+// Mixtral MOE
+struct test_moe : public test_case {
+    const int n_experts;
+    const int n_experts_per_tok;
+    const int n_tokens;
+    const int n_embd;
+    const int n_ff;
+
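+    // label the whole graph as a single "MOE" op so it can be selected or
+    // skipped with the op-name filter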
+    std::string op_desc(ggml_tensor * t) override {
+        return "MOE";
+
+        GGML_UNUSED(t);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
+    }
+
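+    // the default dimensions follow Mixtral 8x7B: 8 experts, 2 active per token,
+    // n_embd = 4096, n_ff = 14336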
+    test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
+        : n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
+
+        std::vector<ggml_tensor *> ffn_up_exp(n_experts);
+        std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
+        std::vector<ggml_tensor *> ffn_down_exp(n_experts);
+
+        for (int i = 0; i < n_experts; ++i) {
+            ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+            ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+            ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
+        }
+
+        ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+
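+        // router: project the input onto the gating matrix and turn the per-expert
+        // logits into probabilities with a scaled softmax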
+        ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
+        ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd));
+
+        // select experts
+        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
+
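+        // gather the probabilities of the selected experts; probs is viewed as
+        // [1, n_experts, n_tokens] so it can be indexed with the 2D ids from top_k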
+        ggml_tensor * weights = ggml_get_rows(ctx,
+                ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
+
+        weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens);
+
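+        // renormalize so the selected expert weights sum to 1 for every token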
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
+
+        weights = ggml_div(ctx, weights, weights_sum);
+
+        // compute expert outputs
+        ggml_tensor * moe_out = nullptr;
+
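+        // each selected expert runs a SiLU-gated FFN (ggml_mul_mat_id picks the
+        // expert's weights per token); the outputs are scaled by the routing
+        // weights and accumulated into moe_out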
+        for (int i = 0; i < n_experts_per_tok; ++i) {
+            ggml_tensor * cur_expert;
+
+            ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
+
+            ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
+
+            cur_gate = ggml_silu(ctx, cur_gate);
+
+            cur_expert = ggml_mul(ctx, cur_up, cur_gate);
+
+            cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert);
+
+            cur_expert = ggml_mul(ctx, cur_expert,
+                    ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+
+            if (i == 0) {
+                moe_out = cur_expert;
+            } else {
+                moe_out = ggml_add(ctx, moe_out, cur_expert);
+            }
+        }
+
+        cur = moe_out;
+
+        return cur;
+    }
+};
+
 enum test_mode {
     MODE_TEST,
     MODE_PERF,
@@ -1117,14 +1227,28 @@ enum test_mode {
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
 
+    const ggml_type all_types[] = {
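+        // all types exercised by the get_rows / cpy / mul_mat / mul_mat_id tests below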
+        GGML_TYPE_F32, GGML_TYPE_F16,
+        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+        GGML_TYPE_Q8_0,
+        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+        GGML_TYPE_Q6_K
+    };
+
     // unary ops
     for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
         test_cases.emplace_back(new test_unary((ggml_unary_op) op));
     }
 
-    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
-        test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
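+    // the first case checks 1-element rows; the loop uses n = 256 so that a full
+    // block of every quantized type fits in one row, with batched (b) and
+    // non-contiguous (v) variants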
+    test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
+    for (ggml_type type : all_types) {
+        for (int b : {1, 7}) {
+            for (bool v : {false, true}) {
+                test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
+            }
+        }
     }
 
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
@@ -1134,7 +1258,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
 
     test_cases.emplace_back(new test_dup());
-    test_cases.emplace_back(new test_cpy());
+
+    for (ggml_type type : all_types) {
+        test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1}));
+    }
+
     test_cases.emplace_back(new test_cont());
 
     auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
@@ -1144,6 +1272,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     };
 
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
@@ -1170,8 +1299,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
 
     test_cases.emplace_back(new test_scale());
 
@@ -1180,16 +1309,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
-    const ggml_type all_types[] = {
-        GGML_TYPE_F32, GGML_TYPE_F16,
-        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
-        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
-        GGML_TYPE_Q8_0,
-        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
-        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K
-    };
-
     for (ggml_type type_a : all_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             // FIXME: CPU crashes on f16xf16
@@ -1213,9 +1332,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     for (ggml_type type_a : all_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
-            for (int n_mats : {1, 2, 4}) {
+            for (int n_mats : {2, 4, 8}) {
                 for (int id = 0; id < n_mats; id++) {
-                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
+                    for (bool v : {false, true}) {
+                        test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
+                    }
                 }
             }
         }
@@ -1247,10 +1368,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_concat());
 
     for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
     }
 
-    test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {10, 10, 10, 10}));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {2, 1, 1, 1}));
+
+#if !defined(__SANITIZE_THREAD__)
+    // FIXME: these tests use too much memory with thread sanitizer
+    test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 14336));
+    //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
+#endif
 
     // run tests
     if (mode == MODE_TEST) {
@@ -1267,14 +1396,17 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         ggml_backend_free(backend_cpu);
 
         return n_ok == test_cases.size();
-    } else if (mode == MODE_PERF) {
+    }
+
+    if (mode == MODE_PERF) {
         for (auto & test : test_cases) {
             test->eval_perf(backend, op_name);
         }
         return true;
-    } else {
-        GGML_ASSERT(false);
     }
+
+    GGML_ASSERT(false);
+    return false;
 }
 
 static void usage(char ** argv) {
@@ -1347,11 +1479,12 @@ int main(int argc, char ** argv) {
     }
 
     printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
+
     if (n_ok != ggml_backend_reg_get_count()) {
         printf("\033[1;31mFAIL\033[0m\n");
         return 1;
-    } else {
-        printf("\033[1;32mOK\033[0m\n");
-        return 0;
     }
+
+    printf("\033[1;32mOK\033[0m\n");
+    return 0;
 }