]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : full broadcast in mul, add, div + ggml_mul_mat_id, ggml_argsort, ggml_top_k...
authorslaren <redacted>
Tue, 5 Dec 2023 12:56:07 +0000 (13:56 +0100)
committerGitHub <redacted>
Tue, 5 Dec 2023 12:56:07 +0000 (13:56 +0100)
* ggml : support broadcasting in dim 0 in add and mul

* add cuda add/mul broadcast impl
add configurable eps to cuda norm

* add metal impl
ggml-ci

* deduplicate code in cuda impl

* try to optimize cuda impl

* ggml : support broadcasting in ggml_div

* test-backend-ops : allow filtering by op and backend

* ggml-cuda : add ggml_div impl

* ggml : add ggml_mul_mat_id, ggml_sort, ggml_top_k (CPU only)

* fix ggml_div threads

* fix ggml_div with accelerate

* ggml_sort -> ggml_argsort

* whatever

* actually fix accelerate div

* disable opencl ci

* ci : disable ctest error check temporarily until we fix backend ops test

* cmake : propagete GGML_USE_xxx compile flags with ggml target

* whisper : utlize new ggml_add broadcast for dim 0

* cmake : adendum to ee666ae9

* ggml_backend_graph_copy : fix leak

* ggml_cuda : add ggml_sum_rows impl

* metal : add ggml_div

* metal : add ggml_sum_rows

* ggml_cuda : add ggml_argsort impl

* move kernel

* metal : add ggml_argsort

* mul_mat_id : fix missing init task

* cuda/metal: fix argsort synchronization

* metal : add ggml_mul_mat_id

* ggml-cuda : add mul_mat_id for f16 + tensor cores

* test-backend-ops : add tests for quants mat mul

* ggml : fix q5_0 and q5_1 hist stats

* test-backend-ops : use smaller matrices to avoid automatic offloading, add mat-vec tests

* metal : fix alibi to match the CPU behavior

* metal : check dimensions in supports_op

* test-backend-ops : reduce error threshold for mat muls

* ggml-cuda : simplify dequantize funs, add supports_op by type for mul_mat_id

* ggml-cuda : support quantized types in mul_mat_id with cublas

* ggml-cuda : add fallback over CPU for mul_mat_id

* test-backend-ops : increase mul mat error threshold

* cleanup
ggml-ci

* test-backend-ops : fix usage

* cleanup

* ci : re-enable tests

* metal : fix compile warnings

---------

Co-authored-by: Georgi Gerganov <redacted>
14 files changed:
.github/workflows/ci.yml
examples/starcoder/starcoder-mmap.cpp
examples/whisper/whisper.cpp
include/ggml/ggml.h
src/CMakeLists.txt
src/ggml-backend.c
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c
tests/test-backend-ops.cpp
tests/test-conv1d.cpp
tests/test-conv2d.cpp
tests/test-mul-mat.cpp

index e5543d92139fe2914dd2513c2b4c000342a03922..ba719f6540d146a5e567536dbd90efe9ddacd22e 100644 (file)
@@ -8,6 +8,7 @@ on:
 
 jobs:
   test-ubuntu-opencl:
+    if: false
     runs-on: ubuntu-latest
     env:
       GGML_NLOOP: 3
index 1ab039c310a0eb94b8982bbc9a2615ea440167d6..8d2c72ddb49a2bb6ea2c1b7b0360923ac52a41c6 100644 (file)
@@ -75,7 +75,7 @@ struct llama_buffer {
     void resize(size_t len) {
 #ifdef GGML_USE_METAL
         free(addr);
-        int result = posix_memalign((void **) &addr, getpagesize(), len);
+        int result = posix_memalign((void **) &addr, sysconf(_SC_PAGESIZE), len);
         if (result == 0) {
             memset(addr, 0, len);
         }
index f0a0a5a6d37916d8015c040cfd3f6db8ee9239c7..376712157a93dd194c04364ce02a07592b998285 100644 (file)
@@ -1341,10 +1341,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
 
             model.e_conv_1_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_mels,     n_audio_state);
-            model.e_conv_1_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
+            model.e_conv_1_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,         1,     n_audio_state);
 
             model.e_conv_2_w     = ggml_new_tensor_3d(ctx, vtype,         3, n_audio_state, n_audio_state);
-            model.e_conv_2_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,    n_audio_ctx,   n_audio_state);
+            model.e_conv_2_b     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,                1, n_audio_state);
 
             model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
             model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@@ -1574,29 +1574,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
             auto tensor = model.tensors[name.data()];
 
-            const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
-
-            if (!is_conv_bias) {
-                if (ggml_nelements(tensor) != nelements) {
-                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                    WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
-                            __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
-                    return false;
-                }
+            if (ggml_nelements(tensor) != nelements) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                return false;
+            }
 
-                if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
-                            __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
-                    return false;
-                }
+            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                return false;
+            }
 
-                const size_t bpe = ggml_type_size(ggml_type(ttype));
+            const size_t bpe = ggml_type_size(ggml_type(ttype));
 
-                if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                            __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                    return false;
-                }
+            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                return false;
             }
 
             ggml_backend_t backend = wctx.backend;
@@ -1607,7 +1603,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 #ifdef GGML_USE_METAL
                 || ggml_backend_is_metal(backend)
 #endif
-                ) && !is_conv_bias) {
+                )) {
                 // for the CPU and Metal backend, we can read directly into the tensor
                 loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
                 BYTESWAP_TENSOR(tensor);
@@ -1618,7 +1614,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 // we repeat the 2 bias tensors along dim 0:
                 // [1, 512] -> [3000, 512] (conv1.bias)
                 // [1, 512] -> [1500, 512] (conv2.bias)
-                if (is_conv_bias) {
+                if (false) {
                     loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
 
                     float * data_f32 = (float *) read_buf.data();
@@ -1733,21 +1729,11 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         {
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
             cur = ggml_add(ctx0, cur, model.e_conv_1_b);
-            //cur = ggml_add(ctx0,
-            //        ggml_repeat(ctx0,
-            //            model.e_conv_1_b,
-            //            cur),
-            //        cur);
 
             cur = ggml_gelu(ctx0, cur);
 
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
             cur = ggml_add(ctx0, cur, model.e_conv_2_b);
-            //cur = ggml_add(ctx0,
-            //        ggml_repeat(ctx0,
-            //            model.e_conv_2_b,
-            //            cur),
-            //        cur);
 
             cur = ggml_gelu(ctx0, cur);
         }
index b53abaa16a025c5d210db40e9f8b67e0809cb1e1..353d52e1987ce72f1aa7ce731f957a84366c057d 100644 (file)
     const type prefix##3 = (pointer)->array[3]; \
     GGML_UNUSED(prefix##3);
 
+#define GGML_TENSOR_UNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
 #ifdef  __cplusplus
 extern "C" {
 #endif
@@ -382,6 +396,7 @@ extern "C" {
         GGML_OP_GROUP_NORM,
 
         GGML_OP_MUL_MAT,
+        GGML_OP_MUL_MAT_ID,
         GGML_OP_OUT_PROD,
 
         GGML_OP_SCALE,
@@ -408,8 +423,8 @@ extern "C" {
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
-
         GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_ARGSORT,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -1033,6 +1048,15 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // indirect matrix multiplication
+    //  ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
+    GGML_API struct ggml_tensor * ggml_mul_mat_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * as[],
+            struct ggml_tensor  * ids,
+            int                   id,
+            struct ggml_tensor  * b);
+
     // A: m columns, n rows,
     // B: p columns, n rows,
     // result is m columns, p rows
@@ -1518,6 +1542,23 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   scale_factor);
 
+    // sort rows
+    enum ggml_sort_order {
+        GGML_SORT_ASC,
+        GGML_SORT_DESC,
+    };
+
+    GGML_API struct ggml_tensor * ggml_argsort(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_sort_order  order);
+
+    // top k elements per row
+    GGML_API struct ggml_tensor * ggml_top_k(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
@@ -1579,7 +1620,6 @@ extern "C" {
             int                   kh);
 
     // used in sam
-
     GGML_API struct ggml_tensor * ggml_add_rel_pos(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index 15a70041ecdb85e103e5ef9840fc1bf073ef3b5c..94ee0ce1480d080edff1233ad76982c31d5afb8a 100644 (file)
@@ -171,7 +171,7 @@ if (GGML_OPENBLAS)
 
         set(GGML_EXTRA_LIBS  ${GGML_EXTRA_LIBS}  ${OPENBLAS_LIB})
         set(GGML_EXTRA_INCS  ${GGML_EXTRA_INCS}  ${OPENBLAS_INC})
-       set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
+        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
     else()
         message(WARNING "OpenBLAS not found")
     endif()
@@ -213,7 +213,17 @@ if (GGML_CUBLAS)
 
         set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
 
-        add_compile_definitions(GGML_USE_CUBLAS)
+        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUBLAS)
+
+        if (GGML_CUDA_FORCE_DMMV)
+            add_compile_definitions(GGML_CUDA_FORCE_DMMV)
+        endif()
+        if (GGML_CUDA_FORCE_MMQ)
+            add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+        endif()
+
+        # required for dynamic parallelism
+        set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
 
         if (GGML_STATIC)
             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
@@ -245,7 +255,9 @@ if (GGML_HIPBLAS)
 
     if (${hipblas_FOUND} AND ${hip_FOUND})
         message(STATUS "HIP and hipBLAS found")
-        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+
+        set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUBLAS)
+
         add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
         if (BUILD_SHARED_LIBS)
             set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
@@ -280,7 +292,8 @@ if (GGML_METAL)
 
     set(GGML_METAL_SOURCES ggml-metal.m ggml-metal.h)
 
-    add_compile_definitions(GGML_USE_METAL)
+    set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_METAL)
+
     #add_compile_definitions(GGML_METAL_NDEBUG)
 
     # get full path to the file
index 5f3005b2b7f7e7945a7d3297e21cde3140b87c99..9f31b065db4a5c2f0bf6a95cc983fff7f8ebe468 100644 (file)
@@ -1064,8 +1064,6 @@ ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_bac
     struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
     memset(sched, 0, sizeof(struct ggml_backend_sched));
 
-    fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024);
-
     sched->n_backends = n_backends;
     for (int i = 0; i < n_backends; i++) {
         sched->backends[i] = backends[i];
@@ -1271,6 +1269,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
 
     free(hash_set.keys);
     free(node_copies);
+    free(node_init);
 
     return (struct ggml_backend_graph_copy) {
         /* .buffer           = */ buffer,
index 9a8e40eb80efa3f010553f52b620facb6b1157a9..dbe92d97eb395ddeab050845c55afc33c5ed0a8a 100644 (file)
@@ -69,6 +69,7 @@
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
 #define cudaSetDevice hipSetDevice
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamFireAndForget hipStreamFireAndForget
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamSynchronize hipStreamSynchronize
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
@@ -433,8 +434,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
-#define CUDA_ADD_BLOCK_SIZE 256
-#define CUDA_MUL_BLOCK_SIZE 256
+#define CUDA_ADDMUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
@@ -501,40 +501,43 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (i >= kx) {
-        return;
-    }
-    dst[i] = x[i] + y[i%ky];
+static __device__ __forceinline__ float op_add(const float a, const float b) {
+    return a + b;
 }
 
-static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __device__ __forceinline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
 
-    if (i >= k) {
-        return;
-    }
-    dst[i] = __hadd(x[i], __float2half(y[i]));
+static __device__ __forceinline__ float op_div(const float a, const float b) {
+    return a / b;
 }
 
-static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0,/* int ne1, int ne2, */int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13) {
+    const int i0 = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i1 = blockIdx.y;
+    const int i2 = blockIdx.z / ne3;
+    const int i3 = blockIdx.z % ne3;
 
-    if (i >= k) {
+    if (i0 >= ne0) {
         return;
     }
-    dst[i] = __half2float(x[i]) + y[i];
-}
 
-static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i10 = i0 % ne10;
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
 
-    if (i >= kx) {
-        return;
-    }
-    dst[i] = x[i] * y[i%ky];
+    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1 + i0;
+    const size_t i_src0 = i_dst;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11 + i10;
+
+    dst[i_dst] = (dst_t)bin_op((float)src0[i_src0], (float)src1[i_src1]);
 }
 
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
@@ -577,6 +580,14 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -587,12 +598,10 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-5f;
-
     float2 mean_var = make_float2(0.f, 0.f);
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -624,14 +633,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4676,6 +4677,65 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
     dst[i] = col * m_k + x[i];
 }
 
+static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.y;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col == 0) {
+        dst[row] = sum;
+    }
+}
+
+template<typename T>
+static inline __device__ void swap(T & a, T & b) {
+    T tmp = a;
+    a = b;
+    b = tmp;
+}
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols) return;
+
+    const float * x_row = x + row * ncols;
+    int * dst_row = dst + row * ncols;
+
+    // initialize indices
+    if (col < ncols) {
+        dst_row[col] = col;
+    }
+    __syncthreads();
+
+    for (int k = 2; k <= ncols; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                        swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                        swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.y*blockIdx.y + threadIdx.y;
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
@@ -4780,25 +4840,35 @@ static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const
     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
-    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
-}
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_cuda {
+    template<typename src0_t, typename src1_t, typename dst_t>
+    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
+            const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+            cudaStream_t stream) {
 
-static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
-}
+        GGML_TENSOR_BINARY_OP_LOCALS
 
-static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
-}
+        //size_t s0 = nb0 / sizeof(src1_t);
+        size_t s1 = nb1 / sizeof(src1_t);
+        size_t s2 = nb2 / sizeof(src1_t);
+        size_t s3 = nb3 / sizeof(src1_t);
 
-static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
-    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
-    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
-}
+        //size_t s10 = nb10 / sizeof(src1_t);
+        size_t s11 = nb11 / sizeof(src1_t);
+        size_t s12 = nb12 / sizeof(src1_t);
+        size_t s13 = nb13 / sizeof(src1_t);
+
+        const int num_blocks_x = (ne0 + CUDA_ADDMUL_BLOCK_SIZE - 1) / CUDA_ADDMUL_BLOCK_SIZE;
+        dim3 num_blocks(num_blocks_x, ne1, ne2*ne3);
+
+        k_bin_bcast<bin_op><<<num_blocks, CUDA_ADDMUL_BLOCK_SIZE, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+            ne0,/* ne1, ne2, */ne3,
+            ne10, ne11, ne12, ne13,
+            /* s0, */s1, s2, s3,
+            /* s10,*/ s11, s12, s13);
+    }
+};
 
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
@@ -4820,14 +4890,14 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     }
 }
 
@@ -4849,38 +4919,14 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
 }
 
-template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-template<typename dst_t>
-static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __host__ __device__ void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 template<typename dst_t>
-static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static __host__ __device__ void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4890,7 +4936,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static __host__ __device__ void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4900,13 +4946,13 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static __host__ __device__ void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static __host__ __device__ void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4916,7 +4962,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static __host__ __device__ void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -4925,6 +4971,64 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }
 
+static to_fp16_cuda_t __host__ __device__ ggml_get_to_fp16_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_F32:
+            return dequantize_block_cuda<1, 1, convert_f32>;
+        default:
+            return nullptr;
+    }
+}
+
+static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_F16:
+            return dequantize_block_cuda<1, 1, convert_f16>;
+        default:
+            return nullptr;
+    }
+}
+
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -5013,6 +5117,15 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
     dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
 static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
@@ -5103,83 +5216,6 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
-static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
-}
-
-static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
-    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    dequantize_mul_mat_vec<1, 1, convert_f16>
-        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
-}
-
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_row_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_row_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_row_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_row_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_row_q8_0_cuda;
-        case GGML_TYPE_Q2_K:
-            return dequantize_row_q2_K_cuda;
-        case GGML_TYPE_Q3_K:
-            return dequantize_row_q3_K_cuda;
-        case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_cuda;
-        case GGML_TYPE_Q5_K:
-            return dequantize_row_q5_K_cuda;
-        case GGML_TYPE_Q6_K:
-            return dequantize_row_q6_K_cuda;
-        case GGML_TYPE_F32:
-            return convert_fp32_to_fp16_cuda;
-        default:
-            return nullptr;
-    }
-}
-
-static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_row_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_row_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_row_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_row_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_row_q8_0_cuda;
-        case GGML_TYPE_Q2_K:
-            return dequantize_row_q2_K_cuda;
-        case GGML_TYPE_Q3_K:
-            return dequantize_row_q3_K_cuda;
-        case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_cuda;
-        case GGML_TYPE_Q5_K:
-            return dequantize_row_q5_K_cuda;
-        case GGML_TYPE_Q6_K:
-            return dequantize_row_q6_K_cuda;
-        case GGML_TYPE_F16:
-            return convert_fp16_to_fp32_cuda;
-        default:
-            return nullptr;
-    }
-}
-
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -5752,6 +5788,27 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
     alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
 }
 
+static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+    // bitonic sort requires ncols to be power of 2
+    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+
+    const dim3 block_dims(ncols, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    if (order == GGML_SORT_ASC) {
+        k_argsort_f32_i32<GGML_SORT_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else if (order == GGML_SORT_DESC) {
+        k_argsort_f32_i32<GGML_SORT_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        GGML_ASSERT(false);
+    }
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -6140,44 +6197,46 @@ static void ggml_cuda_op_get_rows(
     }
 }
 
-inline void ggml_cuda_op_add(
+template<class op>
+inline void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
+        op()(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
+        op()(src0, src1, dst, (const half *) src0_dd, src1_dd, (half *) dst_dd, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
+        op()(src0, src1, dst, (const half *) src0_dd, src1_dd, dst_dd, main_stream);
     } else {
-        fprintf(stderr, "src0->type: %d  dst->type: %d\n", src0->type, dst->type);
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
         GGML_ASSERT(false);
     }
 
-    (void) src1;
-    (void) dst;
 }
 
-inline void ggml_cuda_op_mul(
+inline void ggml_cuda_op_add(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
 
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
+inline void ggml_cuda_op_mul(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
-    mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
 
-    (void) dst;
+inline void ggml_cuda_op_div(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
 inline void ggml_cuda_op_gelu(
@@ -6246,7 +6305,10 @@ inline void ggml_cuda_op_norm(
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
 
     (void) src1;
     (void) dst;
@@ -6785,6 +6847,42 @@ inline void ggml_cuda_op_im2col(
     (void) src0_dd;
 }
 
+inline void ggml_cuda_op_sum_rows(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    sum_rows_f32_cuda(src0_dd, dst_dd, ncols, nrows, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_argsort(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+    argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7298,6 +7396,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
 
+static void ggml_cuda_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_div);
+}
+
 static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }
@@ -7401,7 +7503,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
-__global__ void k_compute_batched_ptrs(
+static __global__ void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
         const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
@@ -7457,9 +7559,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
 
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -7516,7 +7616,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
-        cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+        cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
                 &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
                             (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
@@ -7550,7 +7650,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
-        cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+        cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
                 &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
                             (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
@@ -7648,6 +7748,219 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     }
 }
 
+#if 0
+template<typename ... Srcs>
+static __global__ void k_compute_batched_ptrs_id(
+        const void ** ptrs_src, void ** ptrs_dst,
+        int ne12, int ne13,
+        int ne23,
+        int nb02, int nb03,
+        int nb12, int nb13,
+        int nb2, int nb3,
+        int r2, int r3,
+        ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
+        const half * src1_f16, half * dst_f16,
+        const int32_t * ids, const int id,
+        Srcs... src0s) {
+
+    int i = ids[id];
+
+    half * src0_f16;
+    const void * srcs_ar[] = { (const half *) src0s... };
+    if (src0_type == GGML_TYPE_F16) {
+        src0_f16 = (half *) srcs_ar[i];
+    } else {
+        src0_f16 = src0_as_f16;
+        if (threadIdx.x == 0 && threadIdx.y == 0) {
+            const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
+            to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
+        }
+    }
+
+    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int i03 = i13 / r3;
+    int i02 = i12 / r2;
+
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02   + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)  dst_f16 + i12* nb2/2 + i13* nb3/2;
+}
+
+static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
+    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src00 = dst->src[2];
+
+    const int id = dst->op_params[0];
+
+    GGML_ASSERT(!ggml_is_transposed(src00));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+
+    GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src00->ne[1];
+    const int64_t ne02 = src00->ne[2];
+    const int64_t ne03 = src00->ne[3];
+
+    //const int64_t nb01 = src00->nb[1];
+    const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    //const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
+
+    //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    //void * src0_ddq = src0_extra->data_device[g_main_device];
+    //half * src0_as_f16 = (half *) src0_ddq;
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    // convert src1 to fp16
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+    GGML_ASSERT(to_fp16_cuda != nullptr);
+
+    size_t src1_as = 0;
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const half alpha_f16 = 1.0f;
+    const half beta_f16  = 0.0f;
+
+    // use cublasGemmBatchedEx
+    const int ne23 = ne12*ne13;
+
+    const void ** ptrs_src = nullptr;
+          void ** ptrs_dst = nullptr;
+
+    size_t ptrs_src_s = 0;
+    size_t ptrs_dst_s = 0;
+
+    ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+    ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
+
+    int64_t src0_ne = ggml_nelements(src00);
+    half * src0_as_f16 = nullptr;
+    size_t src0_as = 0;
+    if (src00->type != GGML_TYPE_F16) {
+        src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
+    }
+
+    static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
+    dim3 block_dims(ne13, ne12);
+    k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
+            ptrs_src, ptrs_dst,
+            ne12, ne13,
+            ne23,
+            ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
+            nb12, nb13,
+            dst->nb[2], dst->nb[3],
+            r2, r3,
+            src00->type, src0_as_f16, src0_ne,
+            src1_as_f16, dst_f16,
+            (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
+            dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
+            dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
+            dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
+            dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    CUBLAS_CHECK(
+    cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+            ne01, ne11, ne10,
+            &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
+                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
+            &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
+            ne23,
+            CUBLAS_COMPUTE_16F,
+            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+    if (src0_as != 0) {
+        ggml_cuda_pool_free(src0_as_f16, src0_as);
+    }
+    if (ptrs_src_s != 0) {
+        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
+    }
+    if (ptrs_dst_s != 0) {
+        ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
+    }
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
+}
+#endif
+
+static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+#if 0
+//#ifdef CUDA_USE_TENSOR_CORES
+//    const bool use_tensor_cores = true;
+//#else
+//    const bool use_tensor_cores = false;
+//#endif
+
+    ggml_cuda_mul_mat_id_cublas(dst);
+
+    // TODO: mmq/mmv support
+#else
+    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const int id = dst->op_params[0];
+
+    int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+
+    int32_t a_id;
+    CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+
+    ggml_cuda_mul_mat(src0, src1, dst);
+#endif
+
+    (void) _src0;
+    (void) _src1;
+}
+
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
@@ -7735,6 +8048,16 @@ static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
 }
 
+static void ggml_cuda_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sum_rows);
+}
+
+static void ggml_cuda_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_argsort);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -8054,6 +8377,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_MUL:
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_DIV:
+            func = ggml_cuda_div;
+            break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
                 case GGML_UNARY_OP_GELU:
@@ -8080,6 +8406,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul_mat;
             break;
+        case GGML_OP_MUL_MAT_ID:
+            if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
+                return false;
+            }
+            func = ggml_cuda_mul_mat_id;
+            break;
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
@@ -8119,6 +8451,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_IM2COL:
             func = ggml_cuda_im2col;
             break;
+        case GGML_OP_SUM_ROWS:
+            func = ggml_cuda_sum_rows;
+            break;
+        case GGML_OP_ARGSORT:
+            func = ggml_cuda_argsort;
+            break;
         default:
             return false;
     }
@@ -8343,6 +8681,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
 
     // FIXME: this is a hack to avoid having to implement a new buffer type
     ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->buft = buft;
     buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
 
     return buffer;
@@ -8515,6 +8854,7 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                     return false;
             }
             break;
+        case GGML_OP_MUL_MAT_ID:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -8526,6 +8866,7 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_DUP:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
+        case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
         case GGML_OP_MUL_MAT:
         case GGML_OP_SCALE:
@@ -8538,6 +8879,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_ROPE:
         case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_ARGSORT:
             return true;
         default:
             return false;
@@ -8595,6 +8938,7 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
 
 static int ggml_backend_cuda_reg_devices() {
     int device_count = ggml_cuda_get_device_count();
+    //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
     for (int i = 0; i < device_count; i++) {
         char name[128];
         snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i);
index 37b291a9e31756e22de65ae1df1a837d9824b00e..f2267356cc0a0466137ddb42c8a246cc619cb6ca 100644 (file)
@@ -62,6 +62,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+    GGML_METAL_DECL_KERNEL(div);
+    GGML_METAL_DECL_KERNEL(div_row);
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale_4);
     GGML_METAL_DECL_KERNEL(silu);
@@ -112,15 +114,30 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+    GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
+    GGML_METAL_DECL_KERNEL(sum_rows);
 
 #undef GGML_METAL_DECL_KERNEL
 };
@@ -289,6 +306,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
+        GGML_METAL_ADD_KERNEL(div);
+        GGML_METAL_ADD_KERNEL(div_row);
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(scale_4);
         GGML_METAL_ADD_KERNEL(silu);
@@ -340,16 +359,31 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
         }
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(im2col_f16);
+        GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+        GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
+        GGML_METAL_ADD_KERNEL(sum_rows);
 
 #undef GGML_METAL_ADD_KERNEL
     }
@@ -367,6 +401,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(add_row);
     GGML_METAL_DEL_KERNEL(mul);
     GGML_METAL_DEL_KERNEL(mul_row);
+    GGML_METAL_DEL_KERNEL(div);
+    GGML_METAL_DEL_KERNEL(div_row);
     GGML_METAL_DEL_KERNEL(scale);
     GGML_METAL_DEL_KERNEL(scale_4);
     GGML_METAL_DEL_KERNEL(silu);
@@ -418,16 +454,31 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
         GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
     }
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+    GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
+    GGML_METAL_DEL_KERNEL(sum_rows);
 
 #undef GGML_METAL_DEL_KERNEL
 
@@ -884,6 +935,8 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ADD:
+                    case GGML_OP_MUL:
+                    case GGML_OP_DIV:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
                             GGML_ASSERT(ggml_is_contiguous(src1));
@@ -897,11 +950,21 @@ void ggml_metal_graph_compute(
                                 GGML_ASSERT(ne11 == 1);
 
                                 nb = ne00 / 4;
-                                [encoder setComputePipelineState:ctx->pipeline_add_row];
+                                switch (dst->op) {
+                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
+                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
+                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+                                    default: GGML_ASSERT(false);
+                                }
 
                                 bcast_row = true;
                             } else {
-                                [encoder setComputePipelineState:ctx->pipeline_add];
+                                switch (dst->op) {
+                                    case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
+                                    case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
+                                    case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+                                    default: GGML_ASSERT(false);
+                                }
                             }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -942,31 +1005,6 @@ void ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
                         } break;
-                    case GGML_OP_MUL:
-                        {
-                            GGML_ASSERT(ggml_is_contiguous(src0));
-                            GGML_ASSERT(ggml_is_contiguous(src1));
-
-                            // utilize float4
-                            GGML_ASSERT(ne00 % 4 == 0);
-                            const int64_t nb = ne00/4;
-
-                            if (ggml_nelements(src1) == ne10) {
-                                // src1 is a row
-                                GGML_ASSERT(ne11 == 1);
-                                [encoder setComputePipelineState:ctx->pipeline_mul_row];
-                            } else {
-                                [encoder setComputePipelineState:ctx->pipeline_mul];
-                            }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
-
-                            const int64_t n = ggml_nelements(dst)/4;
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
                     case GGML_OP_SCALE:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1039,6 +1077,40 @@ void ggml_metal_graph_compute(
                             const int64_t n = ggml_nelements(dst);
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
+                    case GGML_OP_SUM_ROWS:
+                        {
+                            GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+                            [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
                     case GGML_OP_SOFT_MAX:
                         {
                             int nth = 32; // SIMD width
@@ -1331,6 +1403,96 @@ void ggml_metal_graph_compute(
                                 }
                             }
                         } break;
+                    case GGML_OP_MUL_MAT_ID:
+                        {
+                            //GGML_ASSERT(ne00 == ne10);
+                            //GGML_ASSERT(ne03 == ne13);
+
+                            GGML_ASSERT(src0t == GGML_TYPE_I32);
+
+                            const int n_as = ne00;
+
+                            // TODO: make this more general
+                            GGML_ASSERT(n_as <= 8);
+
+                            struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+                            const int64_t  ne20 = src2 ? src2->ne[0] : 0;
+                            const int64_t  ne21 = src2 ? src2->ne[1] : 0;
+                            const int64_t  ne22 = src2 ? src2->ne[2] : 0;
+                            const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+                            const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+                            const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+                            const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+                            const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+                            const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+                            GGML_ASSERT(!ggml_is_transposed(src2));
+                            GGML_ASSERT(!ggml_is_transposed(src1));
+
+                            GGML_ASSERT(ne20 % 32 == 0);
+                            // !!!!!!!!! TODO: this assert is probably required but not sure!
+                            //GGML_ASSERT(ne20 >= 64);
+                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                            const uint gqa = ne12/ne22;
+
+                            // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                            // to the matrix-vector kernel
+                            int ne11_mm_min = 0;
+
+                            const int idx = ((int32_t *) dst->op_params)[0];
+
+                            // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                            // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                                ne11 > ne11_mm_min) {
+                                switch (src2->type) {
+                                    case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
+                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
+                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+                                    case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+                                    case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+                                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+                                    default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+                                }
+                                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                                [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:3];
+                                [encoder setBytes:&ne22    length:sizeof(ne22) atIndex:4];
+                                [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:5];
+                                [encoder setBytes:&nb22    length:sizeof(nb22) atIndex:6];
+                                [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
+                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
+                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
+                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
+                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
+                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:14];
+                                // TODO: how to make this an array? read Metal docs
+                                for (int j = 0; j < n_as; ++j) {
+                                    struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                                    size_t offs_src_cur = 0;
+                                    id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:15 + j];
+                                }
+
+                                [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                            }
+                        } break;
                     case GGML_OP_GET_ROWS:
                         {
                             switch (src0->type) {
@@ -1549,6 +1711,27 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                         } break;
+                    case GGML_OP_ARGSORT:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                            GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+                            const int nrows = ggml_nrows(src0);
+
+                            enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+                            switch (order) {
+                                case GGML_SORT_ASC:  [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc];  break;
+                                case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+                                default: GGML_ASSERT(false);
+                            };
+
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
@@ -1809,21 +1992,48 @@ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
+        case GGML_OP_DIV:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
-        case GGML_OP_DIAG_MASK_INF:
-        case GGML_OP_MUL_MAT:
-        case GGML_OP_GET_ROWS:
         case GGML_OP_RMS_NORM:
         case GGML_OP_NORM:
         case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
+        case GGML_OP_ARGSORT:
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             return true;
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_GET_ROWS:
+            {
+                // TODO: also check during graph_compute
+                return op->ne[0] % 4 == 0;
+            } break;
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+            {
+                // TODO: also check during graph_compute
+                struct ggml_tensor * a;
+                struct ggml_tensor * b; UNUSED(b);
+                if (op->op == GGML_OP_MUL_MAT) {
+                    a = op->src[0];
+                    b = op->src[1];
+                } else {
+                    a = op->src[2];
+                    b = op->src[1];
+                }
+                if (a->ne[3] != 1) {
+                    return false;
+                }
+                if (ggml_is_quantized(a->type) && a->ne[2] != 1) {
+                    return false;
+                }
+                return true;
+            } break;
         default:
             return false;
     }
index 5d1357cd72d4592782802a60222e81d2cacb8d8f..4499b1bfc432f0fb0406a5c550a60de78a07a238 100644 (file)
@@ -3,6 +3,7 @@
 using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
 
 #define QK4_0 32
 #define QR4_0 2
@@ -39,8 +40,13 @@ typedef struct {
     int8_t  qs[QK8_0]; // quants
 } block_q8_0;
 
-// general-purpose kernel for addition of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+enum ggml_sort_order {
+    GGML_SORT_ASC,
+    GGML_SORT_DESC,
+};
+
+// general-purpose kernel for addition, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
         device const char * src0,
@@ -81,16 +87,111 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_mul(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant  int64_t & nb00,
+        constant  int64_t & nb01,
+        constant  int64_t & nb02,
+        constant  int64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant  int64_t & nb10,
+        constant  int64_t & nb11,
+        constant  int64_t & nb12,
+        constant  int64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant  int64_t & nb0,
+        constant  int64_t & nb1,
+        constant  int64_t & nb2,
+        constant  int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0];
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+    }
+}
 
-        src0_ptr += ntg.x*nb00;
-        src1_ptr += ntg.x*nb10;
-        dst_ptr  += ntg.x*nb0;
+kernel void kernel_div(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant  int64_t & nb00,
+        constant  int64_t & nb01,
+        constant  int64_t & nb02,
+        constant  int64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant  int64_t & nb10,
+        constant  int64_t & nb11,
+        constant  int64_t & nb12,
+        constant  int64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant  int64_t & nb0,
+        constant  int64_t & nb1,
+        constant  int64_t & nb2,
+        constant  int64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
     }
 }
 
@@ -105,23 +206,22 @@ kernel void kernel_add_row(
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
-kernel void kernel_mul(
+kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
+        constant    int64_t & nb  [[buffer(27)]],
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
-// assumption: src1 is a row
-// broadcast src1 into src0
-kernel void kernel_mul_row(
+kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb,
+        constant    int64_t & nb  [[buffer(27)]],
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
+    dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
 
 kernel void kernel_scale(
@@ -162,6 +262,54 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sum_rows(
+        device const float * src0,
+        device       float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant  int64_t & nb00,
+        constant  int64_t & nb01,
+        constant  int64_t & nb02,
+        constant  int64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant  int64_t & nb10,
+        constant  int64_t & nb11,
+        constant  int64_t & nb12,
+        constant  int64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant  int64_t & nb0,
+        constant  int64_t & nb1,
+        constant  int64_t & nb2,
+        constant  int64_t & nb3,
+        uint3 tpig[[thread_position_in_grid]]) {
+    int64_t i3 = tpig.z;
+    int64_t i2 = tpig.y;
+    int64_t i1 = tpig.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float row_sum = 0;
+
+    for (int64_t i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
+    }
+
+    dst_row[0] = row_sum;
+}
+
 constant float GELU_COEF_A    = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
@@ -1120,17 +1268,21 @@ kernel void kernel_alibi_f32(
     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+    const int64_t k = i3*ne3 + i2;
 
-    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
     float m_k;
-    if (i2 < n_heads_log2_floor) {
-        m_k = pow(m0, i2 + 1);
+    if (k < n_heads_log2_floor) {
+        m_k = pow(m0, k + 1);
     } else {
-        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+        m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
     }
+    
+    device       char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
+    device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+        const  float   src_v = *(device float *)(src_row + i00*nb00);
+        device float * dst_v =  (device float *)(dst_row + i00*nb0);
+        *dst_v = i00 * m_k + src_v;
     }
 }
 
@@ -1335,6 +1487,58 @@ kernel void kernel_im2col_f16(
     }
 }
 
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+        device const float * x,
+        device     int32_t * dst,
+        constant   int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols) return;
+
+    device const float   * x_row   = x   + row * ncols;
+    device       int32_t * dst_row = dst + row * ncols;
+
+    // initialize indices
+    if (col < ncols) {
+        dst_row[col] = col;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
@@ -2749,7 +2953,7 @@ kernel void kernel_get_rows(
 
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm(device const  uchar * src0,
+void kernel_mul_mm_impl(device const  uchar * src0,
                           device const  uchar * src1,
                           device        float * dst,
                           constant    int64_t & ne00,
@@ -2876,14 +3080,112 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     }
 }
 
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const  uchar * src0,
+                          device const  uchar * src1,
+                          device        float * dst,
+                          constant    int64_t & ne00,
+                          constant    int64_t & ne02,
+                          constant    int64_t & nb01,
+                          constant    int64_t & nb02,
+                          constant    int64_t & ne12,
+                          constant    int64_t & nb10,
+                          constant    int64_t & nb11,
+                          constant    int64_t & nb12,
+                          constant    int64_t & ne0,
+                          constant    int64_t & ne1,
+                          constant       uint & gqa,
+                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                          uint3                 tgpig[[threadgroup_position_in_grid]],
+                          uint                  tiitg[[thread_index_in_threadgroup]],
+                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+        src0,
+        src1,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        gqa,
+        shared_memory,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+        device const int32_t * ids,
+        device const   uchar * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne02,
+        constant     int64_t & nb01,
+        constant     int64_t & nb02,
+        constant     int64_t & ne12,
+        constant     int64_t & nb10,
+        constant     int64_t & nb11,
+        constant     int64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant        uint & gqa,
+        constant         int & idx,
+        device const   uchar * src00,
+        device const   uchar * src01,
+        device const   uchar * src02,
+        device const   uchar * src03,
+        device const   uchar * src04,
+        device const   uchar * src05,
+        device const   uchar * src06,
+        device const   uchar * src07,
+        threadgroup    uchar * shared_memory [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
+        src0[ids[idx]],
+        src1,
+        dst,
+        ne00,
+        ne02,
+        nb01,
+        nb02,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        gqa,
+        shared_memory,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
 #if QK_K == 256
 #define QK_NL 16
 #else
 #define QK_NL 4
 #endif
 
-typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
-                          constant uint64_t &, constant uint64_t &, uint, uint, uint);
+typedef void (get_rows_t)(
+        device const void * src0,
+        device const  int * src1,
+        device      float * dst,
+        constant  int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint, uint, uint);
 
 template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
@@ -2913,7 +3215,8 @@ typedef void (mat_mm_t)(
         constant    int64_t & ne0,
         constant    int64_t & ne1,
         constant       uint & gqa,
-        threadgroup uchar *, uint3, uint, uint);
+        threadgroup   uchar *,
+        uint3, uint, uint);
 
 template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1,     dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
@@ -2927,3 +3230,43 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
 template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+
+typedef void (mat_mm_id_t)(
+        device const int32_t * ids,
+        device const   uchar * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne02,
+        constant     int64_t & nb01,
+        constant     int64_t & nb02,
+        constant     int64_t & ne12,
+        constant     int64_t & nb10,
+        constant     int64_t & nb11,
+        constant     int64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant        uint & gqa,
+        constant         int & idx,
+        device const   uchar * src00,
+        device const   uchar * src01,
+        device const   uchar * src02,
+        device const   uchar * src03,
+        device const   uchar * src04,
+        device const   uchar * src05,
+        device const   uchar * src06,
+        device const   uchar * src07,
+        threadgroup    uchar *,
+        uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<float4x4,   1,     dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<half4x4,    1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2,     dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
index 1b192b765f80a764c3cbaa767edef603f30d2ecf..26c86c42f58ffaf8736689a3940b9c37eeb79f07 100644 (file)
@@ -233,24 +233,6 @@ inline static void * ggml_aligned_malloc(size_t size) {
 #define UNUSED GGML_UNUSED
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
-//
-// tensor access macros
-//
-
-#define GGML_TENSOR_UNARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
-#define GGML_TENSOR_BINARY_OP_LOCALS \
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
-    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
-    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
-
 #if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
 #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions
@@ -1613,6 +1595,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GROUP_NORM",
 
     "MUL_MAT",
+    "MUL_MAT_ID",
     "OUT_PROD",
 
     "SCALE",
@@ -1640,6 +1623,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
+    "ARGSORT",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -1666,7 +1650,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1695,6 +1679,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "group_norm(x)",
 
     "X*Y",
+    "X[i]*Y",
     "X*Y",
 
     "x*v",
@@ -1722,6 +1707,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
+    "argsort(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -1748,7 +1734,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1789,6 +1775,7 @@ static void ggml_setup_op_has_task_pass(void) {
 
         p[GGML_OP_ACC                    ] = true;
         p[GGML_OP_MUL_MAT                ] = true;
+        p[GGML_OP_MUL_MAT_ID             ] = true;
         p[GGML_OP_OUT_PROD               ] = true;
         p[GGML_OP_SET                    ] = true;
         p[GGML_OP_GET_ROWS_BACK          ] = true;
@@ -3186,9 +3173,7 @@ static struct ggml_tensor * ggml_add_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    // TODO: support less-strict constraint
-    //       GGML_ASSERT(ggml_can_repeat(b, a));
-    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -3403,9 +3388,7 @@ static struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    // TODO: support less-strict constraint
-    //       GGML_ASSERT(ggml_can_repeat(b, a));
-    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -3450,7 +3433,7 @@ static struct ggml_tensor * ggml_div_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
@@ -4088,6 +4071,49 @@ struct ggml_tensor * ggml_mul_mat(
     return result;
 }
 
+// ggml_mul_mat_id
+
+struct ggml_tensor * ggml_mul_mat_id(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * as[],
+        struct ggml_tensor  * ids,
+        int                   id,
+        struct ggml_tensor  * b) {
+
+    int64_t n_as = ids->ne[0];
+
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_vector(ids));
+    GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
+    GGML_ASSERT(id >= 0 && id < n_as);
+
+    bool is_node = false;
+
+    if (as[0]->grad || b->grad) {
+        is_node = true;
+    }
+
+    const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
+
+    ggml_set_op_params_i32(result, 0, id);
+
+    result->op   = GGML_OP_MUL_MAT_ID;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = ids;
+    result->src[1] = b;
+
+    for (int64_t i = 0; i < n_as; i++) {
+        struct ggml_tensor * a = as[i];
+        GGML_ASSERT(ggml_are_same_shape(as[0], a));
+        GGML_ASSERT(ggml_can_mul_mat(a, b));
+        GGML_ASSERT(!ggml_is_transposed(a));
+        result->src[i + 2] = a;
+    }
+
+    return result;
+}
+
 // ggml_out_prod
 
 struct ggml_tensor * ggml_out_prod(
@@ -5478,6 +5504,43 @@ struct ggml_tensor * ggml_upscale(
     return ggml_upscale_impl(ctx, a, scale_factor);
 }
 
+// ggml_argsort
+
+struct ggml_tensor * ggml_argsort(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_sort_order  order) {
+    bool is_node = false;
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, a->ne);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) order);
+
+    result->op   = GGML_OP_ARGSORT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_top_k
+
+struct ggml_tensor * ggml_top_k(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   k) {
+    GGML_ASSERT(a->ne[0] >= k);
+
+    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_DESC);
+
+    result = ggml_view_4d(ctx, result,
+                k, result->ne[1], result->ne[2], result->ne[3],
+                   result->nb[1], result->nb[2], result->nb[3],
+                0);
+
+    return result;
+}
+
 // ggml_flash_attn
 
 struct ggml_tensor * ggml_flash_attn(
@@ -6837,7 +6900,7 @@ static void ggml_compute_forward_add_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -6870,16 +6933,19 @@ static void ggml_compute_forward_add_f32(
             const int64_t i13 = i03 % ne13;
             const int64_t i12 = i02 % ne12;
             const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
+                vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
+                ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
+            }
         }
     } else {
         // src1 is not contiguous
@@ -6896,8 +6962,9 @@ static void ggml_compute_forward_add_f32(
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
             }
@@ -7617,7 +7684,7 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -7640,7 +7707,6 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(ne00 == ne10);
 
     if (nb10 == sizeof(float)) {
         for (int64_t ir = ith; ir < nr; ir += nth) {
@@ -7652,20 +7718,21 @@ static void ggml_compute_forward_mul_f32(
             const int64_t i13 = i03 % ne13;
             const int64_t i12 = i02 % ne12;
             const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
+            for (int64_t r = 0 ; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            UNUSED(ggml_vec_mul_f32);
+                UNUSED(ggml_vec_mul_f32);
 
-            vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr,  1, ne00);
+                vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
+                ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-                // }
-            // }
+            }
         }
     } else {
         // src1 is not contiguous
@@ -7683,8 +7750,9 @@ static void ggml_compute_forward_mul_f32(
             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            for (int64_t i0 = 0; i0 < ne00; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
@@ -7718,14 +7786,16 @@ static void ggml_compute_forward_div_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int nr  = ggml_nrows(src0);
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -7733,41 +7803,50 @@ static void ggml_compute_forward_div_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            UNUSED(ggml_vec_div_f32);
+                UNUSED(ggml_vec_div_f32);
 
-            vDSP_vdiv(
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+                vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_div_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+                ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-                // }
-            // }
+            }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
             }
@@ -8213,7 +8292,7 @@ static void ggml_compute_forward_repeat_f16(
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
+    GGML_TENSOR_UNARY_OP_LOCALS
 
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const int nr0 = (int)(ne0/ne00);
@@ -9526,6 +9605,8 @@ static void ggml_compute_forward_mul_mat(
             char * wdata = params->wdata;
             const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
+            assert(params->wsize >= ne11*ne12*ne13*row_size);
+
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
                     for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -9627,6 +9708,26 @@ static void ggml_compute_forward_mul_mat(
     }
 }
 
+// ggml_compute_forward_mul_mat_id
+
+static void ggml_compute_forward_mul_mat_id(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * ids = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    const int id = ggml_get_op_params_i32(dst, 0);
+
+    const int a_id = ((int32_t *)ids->data)[id];
+
+    GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+
+    const struct ggml_tensor * src0 = dst->src[a_id + 2];
+
+    ggml_compute_forward_mul_mat(params, src0, src1, dst);
+}
+
 // ggml_compute_forward_out_prod
 
 static void ggml_compute_forward_out_prod_f32(
@@ -11960,6 +12061,67 @@ static void ggml_compute_forward_upscale(
     }
 }
 
+// ggml_compute_forward_argsort
+
+static void ggml_compute_forward_argsort_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne0; j++) {
+            dst_data[j] = j;
+        }
+
+        // C doesn't have a functional sort, so we do a bubble sort instead
+        for (int64_t j = 0; j < ne0; j++) {
+            for (int64_t k = j + 1; k < ne0; k++) {
+                if ((order == GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+                    (order == GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+                    int32_t tmp = dst_data[j];
+                    dst_data[j] = dst_data[k];
+                    dst_data[k] = tmp;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_argsort(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argsort_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_flash_attn
 
 static void ggml_compute_forward_flash_attn_f32(
@@ -13783,6 +13945,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                ggml_compute_forward_mul_mat_id(params, tensor);
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
@@ -13887,6 +14053,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_upscale(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                ggml_compute_forward_argsort(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -14537,6 +14707,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 zero_table);
                 }
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -14875,6 +15049,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -15471,7 +15649,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_SUB:
-        case GGML_OP_DIV:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_LOG:
@@ -15510,6 +15687,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             break;
         case GGML_OP_SILU_BACK:
         case GGML_OP_MUL:
+        case GGML_OP_DIV:
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_RMS_NORM_BACK:
@@ -15547,6 +15725,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 }
 #endif
             } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                // FIXME: blas
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_OUT_PROD:
             {
                 n_tasks = n_threads;
@@ -15603,6 +15786,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_ARGSORT:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 n_tasks = n_threads;
@@ -15866,6 +16053,23 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
                     }
                 } break;
+            case GGML_OP_MUL_MAT_ID:
+                {
+                    const struct ggml_tensor * a = node->src[2];
+                    const struct ggml_tensor * b = node->src[1];
+                    const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+                    if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
+                        if (a->type != GGML_TYPE_F32) {
+                            // here we need memory just for single 2D matrix from src0
+                            cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
+                        }
+                    } else
+#endif
+                    if (b->type != vec_dot_type) {
+                        cur = ggml_type_size(vec_dot_type)*ggml_nelements(b)/ggml_blck_size(vec_dot_type);
+                    }
+                } break;
             case GGML_OP_OUT_PROD:
                 {
                     n_tasks = n_threads;
@@ -17749,8 +17953,8 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t *
             memcpy(&qh, &y[i].qh, sizeof(qh));
 
             for (int j = 0; j < QK5_0; j += 2) {
-                const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-                const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
+                const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
+                const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
 
                 // cast to 16 bins
                 const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
@@ -17779,8 +17983,8 @@ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t *
             memcpy(&qh, &y[i].qh, sizeof(qh));
 
             for (int j = 0; j < QK5_1; j += 2) {
-                const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-                const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12));
+                const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4;
+                const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12));
 
                 // cast to 16 bins
                 const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2;
index b42b0fa9febbd3d031f93962f1897dbd66cf68c7..d30523a285835c6c4e61335cc4e33767e27bde7c 100644 (file)
@@ -2,9 +2,10 @@
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
 #include <ggml-backend-impl.h>
+#include <algorithm>
 #include <array>
-#include <cstring>
 #include <cfloat>
+#include <cstring>
 #include <functional>
 #include <memory>
 #include <random>
@@ -28,10 +29,12 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
 
     if (tensor->type == GGML_TYPE_F32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
-    } else if (tensor->type == GGML_TYPE_F16) {
-        std::vector<ggml_fp16_t> data16(size);
-        ggml_fp32_to_fp16_row(data.data(), data16.data(), size);
-        ggml_backend_tensor_set(tensor, data16.data(), 0, size * sizeof(ggml_fp16_t));
+    } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
+        GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
+        std::vector<uint8_t> dataq(ggml_type_size(tensor->type)*size/ggml_blck_size(tensor->type));
+        int64_t hist[16];
+        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size, hist);
+        ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else {
         GGML_ASSERT(false);
     }
@@ -55,6 +58,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
                         v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]);
                     } else if (t->type == GGML_TYPE_F32) {
                         v = *(float *) &buf[i];
+                    } else if (t->type == GGML_TYPE_I32) {
+                        v = *(int32_t *) &buf[i];
                     } else {
                         GGML_ASSERT(false);
                     }
@@ -206,13 +211,17 @@ struct test_case {
 
     virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
 
+    virtual double max_nmse_err() {
+        return 1e-6;
+    }
+
     virtual void initialize_tensors(ggml_context * ctx) {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
             init_tensor_uniform(t);
         }
     }
 
-    bool eval(ggml_backend_t backend1, ggml_backend_t backend2) {
+    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) {
         ggml_init_params params = {
             /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
             /* .mem_base = */ NULL,
@@ -222,6 +231,12 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
+        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
+            //printf("  %s: skipping\n", ggml_op_desc(out));
+            ggml_free(ctx);
+            return true;
+        }
+
         // check if backends support op
         for (ggml_backend_t backend : {backend1, backend2}) {
             if (!ggml_backend_supports_op(backend, out)) {
@@ -242,18 +257,26 @@ struct test_case {
         initialize_tensors(ctx);
 
         // compare
-        bool ok = true;
+        struct callback_userdata {
+            bool   ok;
+            double max_err;
+        };
+
+        callback_userdata ud {
+            true,
+            max_nmse_err(),
+        };
 
         auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
             std::vector<float> f1 = tensor_to_float(t1);
             std::vector<float> f2 = tensor_to_float(t2);
-            bool * ok = (bool *) user_data;
+            callback_userdata * ud = (callback_userdata *) user_data;
 
             for (size_t i = 0; i < f1.size(); i++) {
                 // check for nans
                 if (std::isnan(f1[i]) || std::isnan(f2[i])) {
-                    printf("    Error: %s: NaN\n", ggml_op_desc(t1));
-                    *ok = false;
+                    printf("    Error: %s: NaN at index %zu\n", ggml_op_desc(t1), i);
+                    ud->ok = false;
                     return true;
                 }
                 // check for infs: both must be inf of the same sign, or both must be finite
@@ -261,29 +284,29 @@ struct test_case {
                     if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                         if (std::signbit(f1[i]) != std::signbit(f2[i])) {
                             printf("    Error: %s: inf sign mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]);
-                            *ok = false;
+                            ud->ok = false;
                             return true;
                         }
                     } else {
                         printf("    Error: %s: inf mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]);
-                        *ok = false;
+                        ud->ok = false;
                         return true;
                     }
                 }
             }
 
             double err = nmse(f1.data(), f2.data(), f1.size());
-            if (err > 1e-6) {
+            if (err > ud->max_err) {
                 printf("    Error: %s: NMSE = %f\n", ggml_op_desc(t1), err);
-                *ok = false;
+                ud->ok = false;
             }
             return true;
        };
 
-        ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ok);
+        ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
 
         printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
-        if (ok) {
+        if (ud.ok) {
             printf("\033[1;32mOK\033[0m\n");
         } else {
             printf("\033[1;31mFAIL\033[0m\n");
@@ -293,7 +316,7 @@ struct test_case {
 
         ggml_free(ctx);
 
-        return ok;
+        return ud.ok;
     }
 };
 
@@ -444,30 +467,11 @@ struct test_cont : public test_case {
 };
 
 // GGML_OP_ADD
-struct test_add : public test_case {
-    const ggml_type type;
-    const std::array<int64_t, 4> ne;
-    const std::array<int,4> nr;
-
-    std::string vars() override {
-        return VARS_TO_STR3(type, ne, nr);
-    }
-
-    test_add(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 1, 1},
-            std::array<int, 4> nr = {1, 2, 1, 1})
-        : type(type), ne(ne), nr(nr) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
-        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * out = ggml_add(ctx, a, b);
-        return out;
-    }
-};
-
 // GGML_OP_MUL
-struct test_mul : public test_case {
+// GGML_OP_DIV
+struct test_bin_bcast : public test_case {
+    using op_t = std::function<ggml_tensor * (ggml_context *, ggml_tensor *, ggml_tensor *)>;
+    op_t op;
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int,4> nr;
@@ -476,15 +480,15 @@ struct test_mul : public test_case {
         return VARS_TO_STR3(type, ne, nr);
     }
 
-    test_mul(ggml_type type = GGML_TYPE_F32,
+    test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 1, 1},
             std::array<int, 4> nr = {1, 2, 1, 1})
-        : type(type), ne(ne), nr(nr) {}
+        : op(op), type(type), ne(ne), nr(nr) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * out = ggml_mul(ctx, a, b);
+        ggml_tensor * out = op(ctx, a, b);
         return out;
     }
 };
@@ -568,6 +572,10 @@ struct test_mul_mat : public test_case {
         return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
     }
 
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
     test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
             int64_t m = 32, int64_t n = 32, int64_t k = 32,
             std::array<int64_t, 2> bs = {10, 10},
@@ -794,7 +802,128 @@ struct test_concat : public test_case {
     }
 };
 
-static bool test_backend(ggml_backend_t backend) {
+// GGML_OP_ARGSORT
+struct test_argsort : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    ggml_sort_order order;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, order);
+    }
+
+    test_argsort(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {16, 10, 10, 10},
+            ggml_sort_order order = GGML_SORT_ASC)
+        : type(type), ne(ne), order(order) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_argsort(ctx, a, order);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                std::vector<int> data(ggml_nelements(t));
+                for (int i = 0; i < ggml_nelements(t); i++) {
+                    data[i] = rand();
+                }
+                std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
+                ggml_backend_tensor_set(t, data.data(), 0, ne[0]*ne[1]*ne[2]*ne[3] * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+
+// GGML_OP_MUL_MAT_ID
+struct test_mul_mat_id : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int n_mats;
+    const int id;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array<int64_t, 2> bs; // dims 3 and 4
+    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+
+    std::string vars() override {
+        return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int n_mats = 2, int id = 0,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array<int64_t, 2> bs = {10, 10},
+            std::array<int64_t, 2> nr = {2, 2})
+        : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
+            m(m), n(n), k(k), bs(bs), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+        std::vector<ggml_tensor *> mats;
+        for (int i = 0; i < n_mats; i++) {
+            ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]*nr[0], bs[1]*nr[1]);
+            mats.push_back(a);
+        }
+        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // ids
+                std::vector<int> data(n_mats);
+                for (int i = 0; i < n_mats; i++) {
+                    data[i] = i;
+                }
+                std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
+                ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_SUM_ROWS
+struct test_sum_rows : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sum_rows(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sum_rows(ctx, a);
+        return out;
+    }
+};
+
+enum test_mode {
+    MODE_TEST,
+    MODE_PERF,
+};
+
+static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     ggml_backend_t backend_cpu = ggml_backend_cpu_init();
 
     std::vector<std::unique_ptr<test_case>> test_cases;
@@ -814,27 +943,22 @@ static bool test_backend(ggml_backend_t backend) {
     test_cases.emplace_back(new test_cpy());
     test_cases.emplace_back(new test_cont());
 
-    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1}));
-    //test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1})); // broadcasting dim 0 is not supported
-    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1}));
-    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1}));
-    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2}));
-    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2}));
-    test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2}));
-    //test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2}));
-
-    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1}));
-    //test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1})); // broadcasting dim 0 is not supported
-    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1}));
-    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1}));
-    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2}));
-    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2}));
-    test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2}));
-    //test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2}));
+    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
+        for (auto op : {ggml_add, ggml_mul, ggml_div}) {
+            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
+        }
+    };
+
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2});
 
     test_cases.emplace_back(new test_scale());
 
@@ -843,16 +967,34 @@ static bool test_backend(ggml_backend_t backend) {
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
-    for (ggml_type t0 : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        for (ggml_type t1 : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+    ggml_type all_types[] = {
+        GGML_TYPE_F32, GGML_TYPE_F16,
+        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+        GGML_TYPE_Q8_0,
+        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+        GGML_TYPE_Q6_K
+    };
+
+    for (ggml_type type_a : all_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             // FIXME: CPU crashes on f16xf16
-            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, { 1,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {2, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
         }
     }
 
@@ -881,9 +1023,26 @@ static bool test_backend(ggml_backend_t backend) {
     test_cases.emplace_back(new test_im2col());
     test_cases.emplace_back(new test_concat());
 
+    for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
+    }
+
+    for (ggml_type type_a : all_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+            for (int n_mats : {1, 2, 4}) {
+                for (int id = 0; id < n_mats; id++) {
+                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
+                }
+            }
+        }
+    }
+
+    test_cases.emplace_back(new test_sum_rows());
+
+    // run tests
     size_t n_ok = 0;
     for (auto & test : test_cases) {
-        if (test->eval(backend, backend_cpu)) {
+        if (test->eval(backend, backend_cpu, op_name)) {
             n_ok++;
         }
     }
@@ -895,7 +1054,44 @@ static bool test_backend(ggml_backend_t backend) {
     return n_ok == test_cases.size();
 }
 
-int main() {
+static void usage(char ** argv) {
+    // command line: test-backend-ops [mode] [-o op] [-b backend]
+    // modes are correctness (compare with CPU) or performance
+    printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
+    printf("  valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation) [not implemented]\n");
+    printf("  op names are as given ggml_op_desc()\n");
+}
+
+int main(int argc, char ** argv) {
+    test_mode mode = MODE_TEST;
+    const char * op_name = NULL;
+    const char * backend = NULL;
+
+    for (int i = 1; i < argc; i++) {
+        if (strcmp(argv[i], "test") == 0) {
+            mode = MODE_TEST;
+        } else if (strcmp(argv[i], "perf") == 0) {
+            mode = MODE_PERF;
+        } else if (strcmp(argv[i], "-o") == 0) {
+            if (i + 1 < argc) {
+                op_name = argv[++i];
+            } else {
+                usage(argv);
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-b") == 0) {
+            if (i + 1 < argc) {
+                backend = argv[++i];
+            } else {
+                usage(argv);
+                return 1;
+            }
+        } else {
+            usage(argv);
+            return 1;
+        }
+    }
+
     // enumerate backends
     printf("Testing %zu backends\n\n", ggml_backend_reg_get_count());
 
@@ -904,11 +1100,17 @@ int main() {
     for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
         printf("Backend %zu/%zu (%s)\n", i + 1, ggml_backend_reg_get_count(), ggml_backend_reg_get_name(i));
 
+        if (backend != NULL && strcmp(backend, ggml_backend_reg_get_name(i)) != 0) {
+            printf("  Skipping\n");
+            n_ok++;
+            continue;
+        }
+
         ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL);
         GGML_ASSERT(backend != NULL);
         printf("  Backend name: %s\n", ggml_backend_name(backend));
 
-        bool ok = test_backend(backend);
+        bool ok = test_backend(backend, mode, op_name);
 
         printf("  Backend %s: ", ggml_backend_name(backend));
         if (ok) {
index 0e5c75f7e7598447d9c9c0c45901c4487dd98320..5b74156853ed58566f9c850a1f2ab4b7077b45e6 100644 (file)
 #include <string>
 #include <vector>
 
+static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
+}
+
 struct test_model {
     struct ggml_tensor * a;
     struct ggml_tensor * b;
index 64eee784809bfa9f8007b691126e073592f096ed..f50a53afa244f3c55828fd0a714cee054f1d35ac 100644 (file)
 #include <string>
 #include <vector>
 
+static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
+}
+
 struct test_model {
     struct ggml_tensor * a;
     struct ggml_tensor * b;
index 1811492c2231f2c64d0e9279e12bd7556ec7783d..2bee73393f127b61c7da8902489732082e09d42b 100644 (file)
 #include <string>
 #include <vector>
 
+static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
+}
+
 struct test_model {
     struct ggml_tensor * a;
     struct ggml_tensor * b;