]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (#409)
authorGeorgi Gerganov <redacted>
Sun, 23 Jul 2023 14:51:29 +0000 (17:51 +0300)
committerGitHub <redacted>
Sun, 23 Jul 2023 14:51:29 +0000 (17:51 +0300)
* ggml : sync llama.cpp

ggml-ci

* ggml : fix nullptr derefs in backward

* ci : add mnist test, import/export graph

* add op_params to ggml_graph_export/import

ggml-ci

* mnist : export/import op_params for testing purposes

* mnist : fix f32 model generation test + instructions

ggml-ci

* ci : install python deps even for low-perf builds

ggml-ci

---------

Co-authored-by: Diego Devesa <redacted>
ci/run.sh
examples/mnist/README.md
examples/mnist/main-cpu.cpp
examples/mnist/main.cpp
include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c

index 973a0fe3d6c20501e9cbb0d431aec8d40212ded9..abb43e76735872a511790d08ab8d927f7f84f449 100644 (file)
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -1,4 +1,15 @@
 #/bin/bash
+#
+# sample usage:
+#
+# mkdir tmp
+#
+# # CPU-only build
+# bash ./ci/run.sh ./tmp/results ./tmp/mnt
+#
+# # with CUDA support
+# GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
+#
 
 if [ -z "$2" ]; then
     echo "usage: $0 <output-dir> <mnt-dir>"
@@ -190,23 +201,57 @@ function gg_sum_mpt {
     gg_printf '```\n'
 }
 
+# mnist
+
+function gg_run_mnist {
+    cd ${SRC}
+
+    cd build-ci-release
+
+    set -e
+
+    mkdir -p models/mnist
+    python3 ../examples/mnist/convert-h5-to-ggml.py ../examples/mnist/models/mnist/mnist_model.state_dict
+
+    model_f32="./models/mnist/ggml-model-f32.bin"
+    samples="../examples/mnist/models/mnist/t10k-images.idx3-ubyte"
+
+    # first command runs and exports "mnist.ggml", the second command runs the exported model
+
+    (time ./bin/mnist     ${model_f32} ${samples} ) 2>&1 | tee -a $OUT/${ci}-mnist.log
+    (time ./bin/mnist-cpu ./mnist.ggml ${samples} ) 2>&1 | tee -a $OUT/${ci}-mnist.log
+
+    set +e
+}
+
+function gg_sum_mnist {
+    gg_printf '### %s\n\n' "${ci}"
+
+    gg_printf 'MNIST\n'
+    gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+    gg_printf '```\n'
+    gg_printf '%s\n' "$(cat $OUT/${ci}-mnist.log)"
+    gg_printf '```\n'
+}
+
 ## main
 
 if [ -z $GG_BUILD_LOW_PERF ]; then
     rm -rf ${SRC}/models-mnt
 
-    mnt_models=$(realpath ${MNT}/models)
+    mnt_models=${MNT}/models
     mkdir -p ${mnt_models}
     ln -sfn ${mnt_models} ${SRC}/models-mnt
-
-    python3 -m pip install -r ${SRC}/requirements.txt
 fi
 
+python3 -m pip install -r ${SRC}/requirements.txt
+
 ret=0
 
 test $ret -eq 0 && gg_run ctest_debug
 test $ret -eq 0 && gg_run ctest_release
 test $ret -eq 0 && gg_run gpt_2
+test $ret -eq 0 && gg_run mnist
 
 if [ -z $GG_BUILD_LOW_PERF ]; then
     test $ret -eq 0 && gg_run mpt
index 0f2ed8c2e23b262ab6f49a02aeffd3cc95f9bf88..3bb436cb5c144141d28ca6b7763940cf37f0afef 100644 (file)
@@ -41,8 +41,13 @@ mkdir build && cd build
 cmake ..
 make -j4 mnist
 
+# Generate ggml model
+mkdir -p models/mnist
+python3 ../examples/mnist/convert-h5-to-ggml.py ../examples/mnist/models/mnist/mnist_model.state_dict
+
 # Run the MNIST model
-./bin/mnist ../examples/mnist/models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
+
+./bin/mnist ./models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte
 ```
 
 For more information, checkout the corresponding programs in the [examples](examples) folder.
index 2000c9aac6274b44fe9bcde4902f6f431902291e..3e8bfe674f05e54da6b2e70d9b80c58dfaa3af0e 100644 (file)
@@ -42,6 +42,9 @@ int mnist_eval(
 
     struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
 
+    // param export/import test
+    GGML_ASSERT(ggml_graph_get_tensor(&gfi, "fc1_bias")->op_params[0] == 0xdeadbeef);
+
     // allocate work context
     // needed during ggml_graph_compute() to allocate a work tensor
     static size_t buf_size = 128ull*1024*1024; // TODO
index 5ff4ac20fc93b9826c49ce36912ad60e23fcd077..33986e4ef573205d7e6f912f632461a0fa16c620 100644 (file)
@@ -119,6 +119,9 @@ bool mnist_model_load(const std::string & fname, mnist_model & model) {
             model.fc1_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_hidden);
             fin.read(reinterpret_cast<char *>(model.fc1_bias->data), ggml_nbytes(model.fc1_bias));
             ggml_set_name(model.fc1_bias, "fc1_bias");
+
+            // just for testing purposes, set some parameters to non-zero
+            model.fc1_bias->op_params[0] = 0xdeadbeef;
         }
     }
 
index 24856a255c9a6d57683b024c979f159e7f78f64e..871c85a89aae796f841fce9bd9bee0284929f2c6 100644 (file)
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_SRC           6
 #define GGML_MAX_NAME          48
+#define GGML_MAX_OP_PARAMS     32
 #define GGML_DEFAULT_N_THREADS 4
 
 
@@ -418,6 +419,9 @@ extern "C" {
         // compute data
         enum ggml_op op;
 
+        // op params - allocated as int32_t for alignment
+        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)];
+
         bool is_param;
 
         struct ggml_tensor * grad;
@@ -1128,9 +1132,9 @@ extern "C" {
             int                   n_past,
             int                   n_dims,
             int                   mode,
+            int                   n_ctx,
             float                 freq_base,
-            float                 freq_scale,
-            int                   n_ctx);
+            float                 freq_scale);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1139,7 +1143,8 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_dims,
-            int                   mode);
+            int                   mode,
+            int                   n_ctx);
 
     // alibi position embedding
     // in-place, returns view(a)
index d3054a7fac71d344918303480240bcd24d4d2b9b..6fb55d838dfb3e4a7ec505f17a3adf334e085ffe 100644 (file)
@@ -220,7 +220,7 @@ typedef struct {
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
 
 #define WARP_SIZE 32
-#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
+#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
@@ -935,12 +935,18 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
+#if K_QUANTS_PER_ITERATION == 2
+    uint32_t q32[4];
+    const uint8_t * q4 = (const uint8_t *)q32;
+#else
+    uint16_t q16[4];
+    const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
     float tmp = 0; // partial sum for thread in warp
 
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
-        const uint8_t * q1 = x[i].qs + q_offset;
-        const uint8_t * q2 = q1 + 64;
         const float   * y1 = yy + i*QK_K + y_offset;
         const float   * y2 = y1 + 128;
 
@@ -953,14 +959,41 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
 
+#if K_QUANTS_PER_ITERATION == 2
+        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+        const uint32_t * q2 = q1 + 16;
+
+        q32[0] = q1[0] & 0x0f0f0f0f;
+        q32[1] = q1[0] & 0xf0f0f0f0;
+        q32[2] = q2[0] & 0x0f0f0f0f;
+        q32[3] = q2[0] & 0xf0f0f0f0;
+
         float4 s = {0.f, 0.f, 0.f, 0.f};
         float smin = 0;
-        for (int l = 0; l < n; ++l) {
-            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
-            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+        for (int l = 0; l < 4; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
+            s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
             smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
         }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#else
+        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+        const uint16_t * q2 = q1 + 32;
+
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[0] & 0xf0f0;
+        q16[2] = q2[0] & 0x0f0f;
+        q16[3] = q2[0] & 0xf0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 2; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
 
     }
 #else
@@ -1521,7 +1554,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
-    const int bq8_offset = QR4_K * (iqs / QI8_1);
+    const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6
 
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;
@@ -1531,11 +1564,20 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 
     const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
 
-    for (int i = 0; i < QR4_K; ++i) {
-        const int isc = bq8_offset + i;
+    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
 
-        uint8_t sc, m;
-        get_scale_min_k4(isc, bq4_K->scales, sc, m);
+    for (int i = 0; i < QR4_K; ++i) {
 
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
         const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
@@ -1543,8 +1585,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 
         const int vi = (v >> (4*i)) & 0x0F0F0F0F;
 
-        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc); // SIMD dot product
-        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m);  // multiply constant part of q4_K with sum of q8_1 values
+        sumf_d += d8i * (__dp4a(vi,         ui, 0) * sc[i]); // SIMD dot product
+        sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
     }
 
     return d*sumf_d - dmin*sumf_m;
@@ -1745,11 +1787,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 }
 
-static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
     const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / (nchannels_y / nchannels_x);
 
     const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
@@ -1765,7 +1811,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
         }
 
         // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
         const float xi = __half2float(x[ix]);
 
         const int row_y = col_x;
@@ -1793,12 +1839,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x) {
+    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
 
     const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int channel_x = channel / channel_x_divisor;
 
     const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
@@ -1815,7 +1862,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
             break;
         }
 
-        const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         const float xi = __half2float(x[ix]);
 
         const int row_y = col_x;
@@ -2324,20 +2371,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
 }
 
-static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
-    const dim3 block_nums(1, nrows_x, nchannels_x);
+static void ggml_mul_mat_p021_f16_f32_cuda(
+    const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+    const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+    const dim3 block_nums(1, nrows_x, nchannels_y);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+    mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
 }
 
 static void ggml_mul_mat_vec_nc_f16_f32_cuda(
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
-    const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+    const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
 
-    const dim3 block_nums(1, nrows_x, nchannels_x);
+    const dim3 block_nums(1, nrows_x, nchannels_y);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
 }
 
 static void ggml_cpy_f32_f32_cuda(
@@ -2423,20 +2473,53 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
-
+#ifdef DEBUG_CUDA_MALLOC
+    int nnz = 0;
+    size_t max_size = 0, tot_size = 0;
+#endif
+    size_t best_diff = 1ull << 36;
+    int ibest = -1;
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
         cuda_buffer& b = g_cuda_buffer_pool[id][i];
-        if (b.size >= size && b.ptr != nullptr) {
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
+        if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+            ++nnz;
+            tot_size += b.size;
+            if (b.size > max_size) max_size = b.size;
+#endif
+            if (b.size >= size) {
+                size_t diff = b.size - size;
+                if (diff < best_diff) {
+                    best_diff = diff;
+                    ibest = i;
+                    if (!best_diff) {
+                        void * ptr = b.ptr;
+                        *actual_size = b.size;
+                        b.ptr = nullptr;
+                        b.size = 0;
+                        return ptr;
+                    }
+                }
+            }
         }
     }
+    if (ibest >= 0) {
+        cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
+        void * ptr = b.ptr;
+        *actual_size = b.size;
+        b.ptr = nullptr;
+        b.size = 0;
+        return ptr;
+    }
+#ifdef DEBUG_CUDA_MALLOC
+    fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
+            (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
+#endif
     void * ptr;
-    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
-    *actual_size = size;
+    size_t look_ahead_size = (size_t) (1.05 * size);
+    look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+    CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+    *actual_size = look_ahead_size;
     return ptr;
 }
 
@@ -2464,7 +2547,9 @@ static size_t g_scratch_offset = 0;
 
 static int g_device_count = -1;
 static int g_main_device = 0;
+#ifndef GGML_CUDA_FORCE_DMMV
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
+#endif
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -2487,7 +2572,9 @@ void ggml_init_cublas() {
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
 
+#ifndef GGML_CUDA_FORCE_DMMV
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
+#endif
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
@@ -2512,6 +2599,9 @@ void ggml_init_cublas() {
 }
 
 void ggml_cuda_set_tensor_split(const float * tensor_split) {
+    if (tensor_split == nullptr) {
+        return;
+    }
     bool all_zero = true;
     for (int i = 0; i < g_device_count; ++i) {
         if (tensor_split[i] != 0.0f) {
@@ -2652,6 +2742,7 @@ inline void ggml_cuda_op_mul(
     (void) dst;
     (void) src0_ddq_i;
     (void) i02;
+    (void) i1;
 }
 
 inline void ggml_cuda_op_gelu(
@@ -2779,8 +2870,8 @@ inline void ggml_cuda_op_mul_mat_vec(
 #endif
 
     if (use_mul_mat_vec_q) {
-        int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;
-        padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
+        const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
+            ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
         size_t as;
         void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
         quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
@@ -2947,13 +3038,18 @@ inline void ggml_cuda_op_rope(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-    const int n_ctx  = ((int32_t *) src1->data)[3];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    // RoPE alteration for extended context
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
-    const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
+    float freq_base, freq_scale;
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
 
     bool is_glm = mode & 4;
 
@@ -2966,6 +3062,7 @@ inline void ggml_cuda_op_rope(
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
     }
 
+    (void) src1;
     (void) dst;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
@@ -2984,11 +3081,12 @@ inline void ggml_cuda_op_diag_mask_inf(
     const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;
 
-    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_past = ((int32_t *) dst->op_params)[0];
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
 
+    (void) src1;
     (void) dst;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
@@ -3056,6 +3154,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
     const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
     const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(ne03 == ne13);
 
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
@@ -3067,12 +3168,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
 
     // strides for iteration over dims 3 and 2
-    const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
-    const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+    const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
+    const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
+    const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
     const int64_t src0_stride = ne00 * ne01 * stride_mod;
     const int64_t src1_stride = ne10 * ne11 * stride_mod;
     const int64_t dst_stride = ne0 * ne1 * stride_mod;
 
+    const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+    const int64_t i03_max = flatten_rows ? 1 : ne03;
+    const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
+    const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
+    GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
 
@@ -3089,6 +3197,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
 
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+    GGML_ASSERT(!(split && ne02 < ne12));
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
 
@@ -3125,7 +3234,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
         } else {
             row_low = 0;
-            row_high = nrows0;
+            row_high = nrows0*i02_divisor;
         }
         if (row_low == row_high) {
             continue;
@@ -3173,16 +3282,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
         }
 
-        const int64_t i03_max = flatten_rows ? 1 : ne03;
-        const int64_t i02_max = flatten_rows ? 1 : ne02;
-        const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
-
         for (int64_t i03 = 0; i03 < i03_max; i03++) {
             const int64_t i13 = i03 % ne13;
             for (int64_t i02 = 0; i02 < i02_max; i02++) {
                 const int64_t i12 = i02 % ne12;
 
-                const int64_t i0 = i03*ne02 + i02;
+                const int64_t i0 = i03*i02_max + i02;
 
                 // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
                 const int64_t i0_offset_low = row_low/rows_per_iter;
@@ -3216,10 +3321,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 const int64_t i11 = i13*ne12 + i12;
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
-                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+                char  * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+                float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
                 float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
-                float * dst_ddf_i  =  dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+                float * dst_ddf_i  =  dst_ddf[id] + (i0             - i0_offset_low)*dst_stride;
 
                 // for split tensors the data pointer needs to be rounded down
                 // to the bin edge for i03, i02 bins beyond the first
@@ -3258,11 +3363,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     }
                 }
 
-                if (!src0_on_device || !src0_is_contiguous) {
+                if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
                     if (src0_is_f32) {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
                     } else {
-                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
                     }
                 }
 
@@ -3416,6 +3521,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
 
+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(cudaSetDevice(g_main_device));
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
@@ -3428,7 +3535,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
+    ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
 }
 
 void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -3442,6 +3549,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
 
+    const int64_t ne12 = src1->ne[2];
+
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
@@ -3460,7 +3569,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     const int row_stride_x = nb01 / sizeof(half);
     const int channel_stride_x = nb02 / sizeof(half);
 
-    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+    ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
 }
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3601,7 +3710,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         size_t size = ggml_nbytes_split(tensor, nrows_split);
         const size_t original_size = size;
 
-        // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
         if (ne0 % MATRIX_ROW_PADDING != 0) {
             size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
                 * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
@@ -3617,7 +3726,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         }
 
 
-        CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
+        CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice));
 
         extra->data_device[id] = buf;
 
@@ -3697,7 +3806,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
         if (tensor->op == GGML_OP_VIEW) {
-            memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
+            memcpy(&offset, tensor->op_params, sizeof(size_t));
         }
         extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src0_ddc + offset;
index ee205bcdf773ca1666016ac63aae16a7aafdfb96..bf3f68fe45726ca93a10b9293200d11c717e7789 100644 (file)
@@ -42,6 +42,7 @@ struct ggml_metal_context {
     id<MTLComputePipelineState> pipeline_##name
 
     GGML_METAL_DECL_KERNEL(add);
+    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
         GGML_METAL_ADD_KERNEL(add);
+        GGML_METAL_ADD_KERNEL(add_row);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
                                 encoder = [command_buffer computeCommandEncoder];
                             }
 
-                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            if (ggml_nelements(src1) == ne10) {
+                                // src1 is a row
+                                [encoder setComputePipelineState:ctx->pipeline_add_row];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_add];
+                            }
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
                             const int64_t n = ggml_nelements(dst);
 
@@ -577,7 +585,7 @@ void ggml_metal_graph_compute(
                                 encoder = [command_buffer computeCommandEncoder];
                             }
 
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *)(dst->op_params))[0];
 
                             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -676,8 +684,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
@@ -685,8 +693,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
@@ -694,8 +702,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
@@ -703,8 +711,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
@@ -712,8 +720,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 4;
-                                            nth1 = 16;
+                                            nth0 = 2;
+                                            nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                         } break;
                                     default:
@@ -739,16 +747,22 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
 
-                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
-                                else if (src0t == GGML_TYPE_Q2_K ||
-                                         src0t == GGML_TYPE_Q3_K ||
-                                         src0t == GGML_TYPE_Q4_K ||
-                                         src0t == GGML_TYPE_Q5_K ||
-                                         src0t == GGML_TYPE_Q6_K) {
-                                    [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                else if (src0t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                                }
+                                else if (src0t == GGML_TYPE_Q5_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
+                                else if (src0t == GGML_TYPE_Q6_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -792,7 +806,7 @@ void ggml_metal_graph_compute(
 
                             const float eps = 1e-6f;
 
-                            const int nth = 256;
+                            const int nth = 512;
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -800,7 +814,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -836,9 +850,10 @@ void ggml_metal_graph_compute(
 
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
 
-                            const int   n_past   = ((int32_t *) src1->data)[0]; UNUSED(n_past);
-                            const int   n_head   = ((int32_t *) src1->data)[1];
-                            const float max_bias = ((float *)   src1->data)[2];
+                            const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                            const int n_head = ((int32_t *) dst->op_params)[1];
+                            float max_bias;
+                            memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
                             if (__builtin_popcount(n_head) != 1) {
                                 GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -876,15 +891,14 @@ void ggml_metal_graph_compute(
                                 encoder = [command_buffer computeCommandEncoder];
                             }
 
-                            const int n_dims = ((int32_t *) src1->data)[1];
-                            const int mode   = ((int32_t *) src1->data)[2];
-
-                            const int n_past = ((int32_t *)(src1->data))[0];
+                            const int n_past = ((int32_t *) dst->op_params)[0];
+                            const int n_dims = ((int32_t *) dst->op_params)[1];
+                            const int mode   = ((int32_t *) dst->op_params)[2];
 
                             float freq_base;
                             float freq_scale;
-                            memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-                            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+                            memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -913,7 +927,9 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
+                    case GGML_OP_DUP:
                     case GGML_OP_CPY:
+                    case GGML_OP_CONT:
                         {
                             if (encoder == nil) {
                                 encoder = [command_buffer computeCommandEncoder];
index 9f9a4fbd74446e2c203fe2f8c65653ee6f5360f7..987376d560879b97af68f85b066b6e5e4fb5ad6d 100644 (file)
@@ -67,6 +67,17 @@ kernel void kernel_add(
     dst[tpig] = src0[tpig] + src1[tpig];
 }
 
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+}
+
 kernel void kernel_mul(
         device const float * src0,
         device const float * src1,
@@ -331,26 +342,33 @@ kernel void kernel_rms_norm(
         threadgroup float  * sum [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+    float4 sumf=0;
+    float all_sum=0;
 
     // parallel sum
-    sum[tpitg] = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        sum[tpitg] += x[i00] * x[i00];
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+    all_sum = simd_sum(all_sum);
+    if (tiisg == 0) {
+        sum[sgitg] = all_sum;
     }
 
-    // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg/2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+       if (tpitg < i) {
+           sum[tpitg] += sum[tpitg + i];
+       }
     }
-
-    // broadcast
     if (tpitg == 0) {
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
         sum[0] /= ne00;
     }
 
@@ -359,104 +377,102 @@ kernel void kernel_rms_norm(
     const float mean  = sum[0];
     const float scale = 1.0f/sqrt(mean + eps);
 
-    device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    device float4 * y = (device float4 *) (dst + tgpig*ne00);
+    device float * y_scalar = (device float *) y;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
+    if (tpitg == 0) {
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+    }
+}
+
+// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
+float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+    float d = qb_curr->d;
+    float4 acc = 0.f;
+    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
+    for (int i = 0; i < 16; i+=2) {
+        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
+        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+}
+
+// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
+float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+    float4 acc = 0.f;
+    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
+    for (int i = 0; i < 16; i+=2) {
+        acc[0] += yl[i]      * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
+        acc[2] += yl[i +  1] * (qs[i / 2] & 0x0F00);
+        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    }
+    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
 }
 
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-kernel void kernel_mul_mat_q4_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne10,
-        constant   int64_t & ne0,
-        constant   int64_t & ne01[[buffer(4)]],
-        uint2 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+template<typename block_q_type>
+void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+                    int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+                    uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
     device const float      * y = (device const float      *) src1 + r1*ne10;
-    block_q4_0 qb_curr, qb_next;
     float4 y_curr[8];       // src1 vector cache
     float sumf[N_DST]={0.f}, all_sum;
     thread float * yl=(thread float *)y_curr;
 
-    // bootstrap
-    qb_curr = x[tiisg];
     // each thread in a SIMD group deals with 1 block.
     for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
         float sumy = 0;
         for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
+            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
-        sumy *= (-8.f);
 
         for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            float d = qb_curr.d;
-            float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
-            }
-            sumf[row] += d * acc;
-            qb_curr = qb_next;
+            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
         }
     }
 
-    if (nb % N_SIMDWIDTH == 0) {
-        for (int row = 0; row < N_DST; ++row) {
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    } else {
-
+    // from now loads two rows every time and 16 blocks per row
+    int ir = tiisg / (N_SIMDWIDTH / 2);
+    int ib = tiisg % (N_SIMDWIDTH / 2);
+    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
+        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
         float sumy = 0;
         for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
+            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
         }
-        sumy *= (-8.f);
 
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            float d = qb_curr.d;
-            float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
+        for (int row = 0; row < N_DST; row+=2) {
+            if (nb_start + ib < nb) {
+                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
             }
-            if (tiisg < nb % N_SIMDWIDTH) {
-                sumf[row] += d * acc;
-            }
-            qb_curr = qb_next;
+        }
+    }
 
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
+            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mat_q4_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -467,80 +483,21 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int nb = ne00/QK4_0;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
-    device const float      * y = (device const float      *) src1 + r1*ne10;
-    block_q4_1 qb_curr, qb_next;
-    float4 y_curr[8];       // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
-
-    // bootstrap
-    qb_curr = x[tiisg];
-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4  *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            const float d = qb_curr.d;
-            const float m = qb_curr.m;
-            float acc = 0.f;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
-            }
-            sumf[row] += d * acc + m * sumy;
-            qb_curr = qb_next;
-        }
-    }
-
-    if (nb % N_SIMDWIDTH == 0) {
-        for (int row = 0; row < N_DST; ++row) {
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    } else {
-
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            const float d = qb_curr.d;
-            const float m = qb_curr.m;
-            float acc = 0.f;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
-            }
-            if (tiisg < nb % N_SIMDWIDTH) {
-                sumf[row] += d * acc + m * sumy;
-            }
-            qb_curr = qb_next;
+    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+}
 
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    }
+kernel void kernel_mul_mat_q4_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mat_f16_f32(
@@ -1263,111 +1220,137 @@ kernel void kernel_mul_mat_q2_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
+        constant   int64_t & ne01[[buffer(4)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
 
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-
-    device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
 
-    float sumf = 0;
+    const int step = sizeof(block_q2_K) * nb;
 
 #if QK_K == 256
-    const int tid = tpitg.y;    // 0...16
-    const int il  = tid/4;      // 0...3
-    const int ir  = tid%4;      // 0...3
-    const int ip  = il/2;       // 0 or 1
-    const int shift1 = 4*(il%2);// 0 or 4
-    const int shift2 = shift1+2;// 2 or 6
-    const int n   = 8;
-    const int is  = 4*il + (n*ir)/16;
-
-    const int y_offset = 64*il + n*ir;
-    const int q_offset = 32*ip + n*ir;
-
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
-
-        device const uint8_t * q = x[i].qs + q_offset;
-        device const uint8_t * scales = x[i].scales + is;
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+    const int is = (8*ir)/16;// 0 or 1
+
+    device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+        }
 
-        uint8_t d1 = scales[0] & 0xF;
-        uint8_t d2 = scales[2] & 0xF;
-        uint8_t m1 = scales[0] >>  4;
-        uint8_t m2 = scales[2] >>  4;
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*im + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
 
-        device const float   * y = yy + i*QK_K + y_offset;
+        for (int row = 0; row < N_DST; row++) {
 
-        float2 s = {0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
-            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
-            smin += y[l+ 0] * m1 + y[l+32] * m2;
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+            float dall = dh[0];
+            float dmin = dh[1] * 1.f/16.f;
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
         }
 
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
-
-        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
-
+        y4 += 4 * QK_K;
     }
 #else
-    const int il = 4 * tpitg.x;
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0...1
 
-    uint32_t aux[2];
-    thread const uint8_t * d = (thread const uint8_t *)aux;
-    thread const uint8_t * m = (thread const uint8_t *)aux + 4;
+    device const float * y4 = y + ix * QK_K + 8 * it;
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    for (int ib = ix; ib < nb; ib += 16) {
 
-        device const uint8_t * q = x[i].qs + il;
-        device const float   * y = yy + i*QK_K + il;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
+        }
 
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = &x[ib].d;
 
-        device const uint32_t * a = (device const uint32_t *)x[i].scales;
-        aux[0] = a[0] & 0x0f0f0f0f;
-        aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+        for (int row = 0; row < N_DST; row++) {
 
-        for (int l = 0; l < 4; ++l) {
-            sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
-                  + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
-                  + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
-                  + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
         }
+
+        y4 += 16 * QK_K;
     }
 #endif
 
-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
     }
 }
 
+#if QK_K == 256
 kernel void kernel_mul_mat_q3_K_f32(
         device const  void * src0,
         device const float * src1,
@@ -1376,40 +1359,41 @@ kernel void kernel_mul_mat_q3_K_f32(
         constant   int64_t & ne10,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
 
-#if QK_K == 256
+    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
 
-    const uint8_t m3 = 3;
-    const int8_t  m4 = 4;
+    float yl[16];
 
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tpitg.y;        // expecting 16
+    const int tid = tiisg/2;
+    const int ix  = tiisg%2;
     const int ip  = tid/8;          // 0 or 1
     const int il  = tid/2 - 4*ip;   // 0...3
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint8_t m = 1 << (4*ip + il);
+    const uint16_t m1 = 1 << (4*ip + il);
+    const uint16_t m2 = m1 << 8;
 
     const int shift = 2*il;
+    const uint16_t qm1 = 0x0003 << shift;
+    const uint16_t qm2 = 0x0300 << shift;
+    const int32_t v1 = 4 << shift;
+    const int32_t v2 = 1024 << shift;
 
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1418,226 +1402,315 @@ kernel void kernel_mul_mat_q3_K_f32(
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
 
-    //float sumf = 0;
-    float sumf1 = 0, sumf2 = 0;
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
-
-        const float d_all = (float)(x[i].d);
+    const int step = sizeof(block_q3_K) * nb / 2;
 
-        device const uint8_t * q = x[i].qs + q_offset;
-        device const uint8_t * h = x[i].hmask + l0;
-        device const float   * y = yy + i * QK_K + y_offset;
+    device const float * y1 = yy + ix*QK_K + y_offset;
 
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 2) {
 
-        float s = 0;
-        for (int l = 0; l < n; ++l) {
-            s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0];
+            yl[l+8] = y1[l+16];
         }
-        float d = d_all * s;
-        sumf1 += d * scales[0];
-        sumf2 += d;
-        //sumf += d_all * s * (scales[0] - 32);
 
-        s = 0;
-        for (int l = 0; l < n; ++l) {
-            s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
-        }
-        d = d_all * s;
-        sumf1 += d * scales[1];
-        sumf2 += d;
-        //sumf += d_all * s * (scales[1] - 32);
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+        device const half * dh = &x[i].d;
 
-    }
+        for (int row = 0; row < 2; ++row) {
 
-    //sum[ith] = sumf;
-    sum[ith] = sumf1 - 32.f*sumf2;
-#else
-    const int il = 4 * tpitg.x;  // 0, 4, 8, 12
-    const int im = il/8;         // 0, 0, 1, 1
-    const int in = il%8;         // 0, 4, 0, 4
+            const float d_all = (float)dh[0];
+            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-    float sumf = 0;
-
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
-
-        const float d_all = (float)(x[i].d);
-
-        device const uint8_t * q = x[i].qs + il;
-        device const uint8_t * h = x[i].hmask + in;
-        device const float   * y = yy + i * QK_K + il;
+            float s1 = 0, s2 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const uint16_t qs = q[l/2];
+                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
+                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+            }
+            float d = d_all * (s1 + 1.f/256.f * s2);
+            sumf1[row] += d * scales[0];
+            sumf2[row] += d;
+
+            s1 = s2 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const uint16_t qs = q[l/2+8];
+                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
+                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+            }
+            d = d_all * (s1 + 1.f/256.f * s2);
+            sumf1[row] += d * scales[1];
+            sumf2[row] += d;
 
-        const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
-        const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
-        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
-        const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+            q  += step;
+            h  += step;
+            a  += step;
+            dh += step;
 
-        for (int l = 0; l < 4; ++l) {
-            const uint8_t hm = h[l] >> im;
-            sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
-                  + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
-                  + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
-                  + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
         }
 
-    }
-
-    sum[ith] = sumf;
-
-#endif
+        y1 += 2 * QK_K;
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
     }
 
+    for (int row = 0; row < 2; ++row) {
+        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
+        const float tot = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = tot;
+        }
+    }
 }
-
-kernel void kernel_mul_mat_q4_K_f32(
+#else
+kernel void kernel_mul_mat_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
+        constant   int64_t & ne1,
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    const int row = 2 * r0 + sgitg;
 
-    device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
+    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const int ix = tiisg/4;
+    const int il = 4 * (tiisg%4);// 0, 4, 8, 12
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4
 
-    float sumf = 0;
+    float2 sum = {0.f, 0.f};
+
+    for (int i = ix; i < nb; i += 8) {
+
+        const float d_all = (float)(x[i].d);
+
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
+        device const uint16_t * s = (device const uint16_t *)(x[i].scales);
+        device const float    * y = yy + i * QK_K + il;
+
+        const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
+        const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
+        const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
+        const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
+
+        for (int l = 0; l < 4; l += 2) {
+            const uint16_t hm = h[l/2] >> im;
+            sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 :  4))
+                    + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
+                    + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
+                    + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
+            sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
+                    + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
+                    + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
+                    + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
+        }
+
+    }
+    const float sumf = sum[0] + sum[1] * 1.f/256.f;
+
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + row] = tot;
+    }
+
+}
+#endif
 
 #if QK_K == 256
+kernel void kernel_mul_mat_q4_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tpitg.y;   // 0...16
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 4;
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
 
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST]={0.f}, all_sum;
 
-    const int l0 = n*(2*ir + in);
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
+    const int step = sizeof(block_q4_K) * nb / 2;
 
-    uchar2 sc1, sc2, sc3, sc4;
+    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
-        device const uint8_t * q1 = (x + i)->qs + q_offset;
-        device const uint8_t * q2 = q1 + 64;
-        device const float   * y1 = yy + i*QK_K + y_offset;
-        device const float   * y2 = y1 + 128;
+    for (int ib = ix; ib < nb; ib += 4) {
 
-        const float dall = (float)((x + i)->d);
-        const float dmin = (float)((x + i)->dmin);
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
+            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+        }
 
-        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
-        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
-        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
-        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
 
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
+        for (int row = 0; row < N_DST; row++) {
 
-            s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
-            s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
-            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            device const uint16_t * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+            }
 
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            sc += step;
+            dh += step;
         }
-        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
+        y4 += 4 * QK_K;
+    }
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
     }
+}
 #else
-    uint16_t aux16[2];
-    thread const uint8_t * scales = (thread const uint8_t *)aux16;
+kernel void kernel_mul_mat_q4_K_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne10,
+        constant   int64_t & ne0,
+        constant   int64_t & ne01[[buffer(4)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    const int il  = 4*tpitg.x;
+    const int ix = tiisg/4;  // 0...7
+    const int it = tiisg%4;  // 0...3
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
+    device const float      * y = (device const float      *) src1 + r1*ne10;
+    float yl[8];
+    float yh[8];
+    float sumf[N_DST]={0.f}, all_sum;
 
-        device const uint8_t * q = x[i].qs + il;
-        device const float   * y = yy + i * QK_K + il;
+    const int step = sizeof(block_q4_K) * nb / 2;
 
-        const float d = (float)x[i].d[0];
-        const float m = (float)x[i].d[1];
+    device const float * y4 = y + ix * QK_K + 8 * it;
 
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        aux16[0] = a[0] & 0x0f0f;
-        aux16[1] = (a[0] >> 4) & 0x0f0f;
+    uint16_t sc16[4];
 
-        for (int l = 0; l < 4; ++l) {
-            sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
-                  + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
+    for (int ib = ix; ib < nb; ib += 8) {
+
+        float2 sumy = {0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+            yh[i] = y4[i+32]; sumy[1] += yh[i];
         }
-    }
-#endif
 
-    sum[ith] = sumf;
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = x[ib].d;
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    // This version is slightly faster than the commented out one below,
-    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+        for (int row = 0; row < N_DST; row++) {
+
+            sc16[0] = sc[0] & 0x000f;
+            sc16[1] = sc[0] & 0x0f00;
+            sc16[2] = sc[0] & 0x00f0;
+            sc16[3] = sc[0] & 0xf000;
+
+            float2 acc1 = {0.f, 0.f};
+            float2 acc2 = {0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
+
+            qs += step;
+            sc += step;
+            dh += step;
+        }
+
+        y4 += 8 * QK_K;
     }
 
-    //// accumulate the sum from all threads in the threadgroup
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-    //for (uint i = nth/2; i > 0; i /= 2) {
-    //    if (ith < i) {
-    //        sum[ith] += sum[ith + i];
-    //    }
-    //    threadgroup_barrier(mem_flags::mem_threadgroup);
-    //}
-
-    //if (ith == 0) {
-    //    dst[r1*ne0 + r0] = sum[0];
-    //}
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = all_sum;
+        }
+    }
 }
+#endif
 
 kernel void kernel_mul_mat_q5_K_f32(
         device const  void * src0,
@@ -1646,39 +1719,39 @@ kernel void kernel_mul_mat_q5_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
     device const float     * yy = (device const float      *) src1 + r1*ne10;
 
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    float sumf[2]={0.f};
 
-    float sumf = 0;
+    const int step = sizeof(block_q5_K) * nb;
 
 #if QK_K == 256
+#
+    float yl[16], yh[16];
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tpitg.y;   // 0...16
-    const int il  = tid/4;     // 0...3
-    const int ir  = tid - 4*il;// 0...3
-    const int n   = 4;
-
-    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
-    const int in = il%2;
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int im  = tid/4;
+    const int ir  = tid%4;
+    const int n   = 8;
 
-    const int l0 = n*(2*ir + in);
+    const int l0 = n*ir;
     const int q_offset = 32*im + l0;
     const int y_offset = 64*im + l0;
 
@@ -1687,78 +1760,113 @@ kernel void kernel_mul_mat_q5_K_f32(
     const uint8_t hm3 = hm1 << 4;
     const uint8_t hm4 = hm2 << 4;
 
-    uchar2 sc1, sc2, sc3, sc4;
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    device const float * y1 = yy + ix*QK_K + y_offset;
 
-        device const uint8_t * q1 = (x + i)->qs + q_offset;
-        device const uint8_t * q2 = q1 + 64;
-        device const uint8_t * qh = (x + i)->qh + l0;
-        device const float   * y1 = yy + i*QK_K + y_offset;
-        device const float   * y2 = y1 + 128;
+    for (int i = ix; i < nb; i += 4) {
 
-        const float dall = (float)((x + i)->d);
-        const float dmin = (float)((x + i)->dmin);
+        device const uint8_t * q1 = x[i].qs + q_offset;
+        device const uint8_t * qh = x[i].qh + l0;
+        device const half * dh = &x[i].d;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
 
-        device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
-        sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
-        sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
-        sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
-        sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
+        device const float * y2 = y1 + 128;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+        }
 
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
+        for (int row = 0; row < 2; ++row) {
+
+            device const uint8_t * q2 = q1 + 64;
 
-            s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
-            s[1] += y1[l+32] * ((q1[l] >>  4) + (qh[l] & hm2 ? 16 : 0));
-            s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
-            s[3] += y2[l+32] * ((q2[l] >>  4) + (qh[l] & hm4 ? 16 : 0));
-            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
+            sc16[0] = a[0] & kmask1;
+            sc16[1] = a[2] & kmask1;
+            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            for (int l = 0; l < n; ++l) {
+                uint8_t h = qh[l];
+                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
+                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
+                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
+                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+            }
+            const float dall = dh[0];
+            const float dmin = dh[1];
+            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            qh += step;
+            dh += step/2;
+            a  += step/2;
 
         }
-        sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
+
+        y1 += 4 * QK_K;
 
     }
 #else
-    const int il  = 4 * tpitg.x;  // 0, 4, 8, 12
-    const int im  = il/8;         // 0, 0, 1, 1
-    const int in  = il%8;         // 0, 4, 0, 4
+    float yl[8], yh[8];
+
+    const int il = 4 * (tiisg/8);  // 0, 4, 8, 12
+    const int ix = tiisg%8;
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    device const float * y = yy + ix*QK_K + il;
 
-        const float d = (float)x[i].d;
+    for (int i = ix; i < nb; i += 8) {
+
+        for (int l = 0; l < 4; ++l) {
+            yl[l+0] = y[l+ 0];
+            yl[l+4] = y[l+16];
+            yh[l+0] = y[l+32];
+            yh[l+4] = y[l+48];
+        }
+
+        device const half * dh = &x[i].d;
         device const uint8_t * q = x[i].qs + il;
         device const uint8_t * h = x[i].qh + in;
         device const int8_t  * s = x[i].scales;
-        device const float   * y = yy + i*QK_K + il;
 
-        for (int l = 0; l < 4; ++l) {
-            const uint8_t hl = h[l] >> im;
-            sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
-                  + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
-                  + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
-                  + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
+        for (int row = 0; row < 2; ++row) {
+
+            const float d = dh[0];
+
+            float2 acc = {0.f, 0.f};
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t hl = h[l] >> im;
+                acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
+                        + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
+                acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
+                        + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
+            }
+            sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
+
+            q += step;
+            h += step;
+            s += step;
+            dh += step/2;
+
         }
+
+        y += 8 * QK_K;
     }
 #endif
-    sum[ith] = sumf;
 
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
+    for (int row = 0; row < 2; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = tot;
+        }
     }
 
 }
@@ -1770,10 +1878,9 @@ kernel void kernel_mul_mat_q6_K_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne10,
         constant   int64_t & ne0,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint2 tgpig[[threadgroup_position_in_grid]],
-        uint2 tpitg[[thread_position_in_threadgroup]],
-        uint2  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -1785,19 +1892,18 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
-    device const float     * yy = (device const float      *) src1 + r1*ne10;
+    const int row = 2 * r0 + sgitg;
 
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
 
     float sumf = 0;
 
 #if QK_K == 256
-    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
-    const int iqs  = 16 * tpitg.y;
-    const int ip   = iqs / 128;         // 0 or 1
-    const int il   = (iqs - 128*ip)/16; // 0...7
+    const int tid  = tiisg/2;
+    const int ix   = tiisg%2;
+    const int ip   = tid/8;         // 0 or 1
+    const int il   = tid%8;
     const int n    = 4;
     const int l0   = n*il;
     const int is   = 8*ip + l0/16;
@@ -1806,9 +1912,10 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    for (int i = ix; i < nb; i += 2) {
 
-        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
 
@@ -1818,19 +1925,21 @@ kernel void kernel_mul_mat_q6_K_f32(
 
         float4 sums = {0.f, 0.f, 0.f, 0.f};
         for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((ql[l+32]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
         }
 
         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
     }
+
 #else
-    const int il  = 4*tpitg.x;    // 0, 4, 8, 12
+    const int ix  = tiisg/4;
+    const int il  = 4*(tiisg%4);
 
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+    for (int i = ix; i < nb; i += 8) {
         device const float * y = yy + i * QK_K + il;
         device const uint8_t * ql = x[i].ql + il;
         device const uint8_t * qh = x[i].qh + il;
@@ -1850,23 +1959,8 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 #endif
 
-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + row] = tot;
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
-
 }
index c56a3d0e0c0a2c695f76f326b9ca8ca4b9ba5ee2..14658ae4fcbf2701452a37191464a65ac9c75dff 100644 (file)
@@ -4590,6 +4590,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
         /*.op           =*/ GGML_OP_NONE,
+        /*.op_params    =*/ {0},
         /*.is_param     =*/ false,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
@@ -4969,6 +4970,11 @@ struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char *
     return tensor;
 }
 
+static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
+    assert(params_size <= GGML_MAX_OP_PARAMS);
+    memcpy(tensor->op_params, params, params_size);
+}
+
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
@@ -5019,7 +5025,6 @@ struct ggml_tensor * ggml_dup_impl(
     result->op   = GGML_OP_DUP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5143,23 +5148,13 @@ struct ggml_tensor * ggml_acc_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
-
-    ((int32_t *) c->data)[0] = nb1;
-    ((int32_t *) c->data)[1] = nb2;
-    ((int32_t *) c->data)[2] = nb3;
-    ((int32_t *) c->data)[3] = offset;
-    ((int32_t *) c->data)[4] = inplace ? 1 : 0;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ACC;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 }
@@ -5332,7 +5327,6 @@ struct ggml_tensor * ggml_sqr_impl(
     result->op   = GGML_OP_SQR;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5366,7 +5360,6 @@ struct ggml_tensor * ggml_sqrt_impl(
     result->op   = GGML_OP_SQRT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5401,7 +5394,6 @@ struct ggml_tensor * ggml_log_impl(
     result->op   = GGML_OP_LOG;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5434,7 +5426,6 @@ struct ggml_tensor * ggml_sum(
     result->op   = GGML_OP_SUM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5461,7 +5452,6 @@ struct ggml_tensor * ggml_sum_rows(
     result->op   = GGML_OP_SUM_ROWS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5484,7 +5474,6 @@ struct ggml_tensor * ggml_mean(
     result->op   = GGML_OP_MEAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5508,7 +5497,6 @@ struct ggml_tensor * ggml_argmax(
     result->op   = GGML_OP_ARGMAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5586,7 +5574,6 @@ struct ggml_tensor * ggml_abs_impl(
     result->op   = GGML_OP_ABS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5621,7 +5608,6 @@ struct ggml_tensor * ggml_sgn_impl(
     result->op   = GGML_OP_SGN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5655,7 +5641,6 @@ struct ggml_tensor * ggml_neg_impl(
     result->op   = GGML_OP_NEG;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5689,7 +5674,6 @@ struct ggml_tensor * ggml_step_impl(
     result->op   = GGML_OP_STEP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5723,7 +5707,6 @@ struct ggml_tensor * ggml_tanh_impl(
     result->op   = GGML_OP_TANH;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5757,7 +5740,6 @@ struct ggml_tensor * ggml_elu_impl(
     result->op   = GGML_OP_ELU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5791,7 +5773,6 @@ struct ggml_tensor * ggml_relu_impl(
     result->op   = GGML_OP_RELU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5825,7 +5806,6 @@ struct ggml_tensor * ggml_gelu_impl(
     result->op   = GGML_OP_GELU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5859,7 +5839,6 @@ struct ggml_tensor * ggml_gelu_quick_impl(
     result->op   = GGML_OP_GELU_QUICK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5893,7 +5872,6 @@ struct ggml_tensor * ggml_silu_impl(
     result->op   = GGML_OP_SILU;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -5948,10 +5926,11 @@ struct ggml_tensor * ggml_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    // TODO: maybe store epsilon here?
+
     result->op   = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL; // TODO: maybe store epsilon here?
 
     return result;
 }
@@ -5980,10 +5959,11 @@ struct ggml_tensor * ggml_rms_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    // TODO: maybe store epsilon here?
+
     result->op   = GGML_OP_RMS_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL; // TODO: maybe store epsilon here?
 
     return result;
 }
@@ -6136,23 +6116,13 @@ struct ggml_tensor * ggml_set_impl(
     // make a view of the destination
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
-
-    (( int32_t * ) c->data)[0] = nb1;
-    (( int32_t * ) c->data)[1] = nb2;
-    (( int32_t * ) c->data)[2] = nb3;
-    (( int32_t * ) c->data)[3] = offset;
-    (( int32_t * ) c->data)[4] = inplace ? 1 : 0;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
+    ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_SET;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 }
@@ -6277,7 +6247,6 @@ struct ggml_tensor * ggml_cont_impl(
     result->op   = GGML_OP_CONT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6321,7 +6290,6 @@ struct ggml_tensor * ggml_reshape(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6346,7 +6314,6 @@ struct ggml_tensor * ggml_reshape_1d(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6372,7 +6339,6 @@ struct ggml_tensor * ggml_reshape_2d(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6399,7 +6365,6 @@ struct ggml_tensor * ggml_reshape_3d(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6428,7 +6393,6 @@ struct ggml_tensor * ggml_reshape_4d(
     result->op   = GGML_OP_RESHAPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6450,19 +6414,11 @@ struct ggml_tensor * ggml_view_1d(
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
     ggml_format_name(result, "%s (view)", a->name);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(offs, "offset");
-    memcpy(offs->data, &offset, 2*sizeof(int32_t));
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, &offset, sizeof(offset));
 
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = offs;
 
     return result;
 }
@@ -6488,13 +6444,7 @@ struct ggml_tensor * ggml_view_2d(
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
     ggml_format_name(result, "%s (view)", a->name);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(offs, "offset");
-    memcpy(offs->data, &offset, 2*sizeof(int32_t));
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, &offset, sizeof(offset));
 
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
@@ -6503,8 +6453,6 @@ struct ggml_tensor * ggml_view_2d(
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = offs;
 
     return result;
 }
@@ -6532,13 +6480,7 @@ struct ggml_tensor * ggml_view_3d(
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
     ggml_format_name(result, "%s (view)", a->name);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(offs, "offset");
-    memcpy(offs->data, &offset, 2*sizeof(int32_t));
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, &offset, sizeof(offset));
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -6547,8 +6489,6 @@ struct ggml_tensor * ggml_view_3d(
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = offs;
 
     return result;
 }
@@ -6578,13 +6518,7 @@ struct ggml_tensor * ggml_view_4d(
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
     ggml_format_name(result, "%s (view)", a->name);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(offs, "offset");
-    memcpy(offs->data, &offset, 2*sizeof(int32_t));
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, &offset, sizeof(offset));
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -6593,8 +6527,6 @@ struct ggml_tensor * ggml_view_4d(
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = offs;
 
     return result;
 }
@@ -6655,22 +6587,9 @@ struct ggml_tensor * ggml_permute(
     result->op   = GGML_OP_PERMUTE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-
-    if (is_node) {
-        ggml_scratch_save(ctx);
-
-        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
 
-        ((int32_t *) b->data)[0] = axis0;
-        ((int32_t *) b->data)[1] = axis1;
-        ((int32_t *) b->data)[2] = axis2;
-        ((int32_t *) b->data)[3] = axis3;
-
-        ggml_scratch_load(ctx);
-
-        result->src[2] = b;
-    }
+    int32_t params[] = { axis0, axis1, axis2, axis3 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     return result;
 }
@@ -6698,7 +6617,6 @@ struct ggml_tensor * ggml_transpose(
     result->op   = GGML_OP_TRANSPOSE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6776,7 +6694,6 @@ struct ggml_tensor * ggml_diag(
     result->op   = GGML_OP_DIAG;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6797,19 +6714,12 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { n_past, inplace ? 1 : 0 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -6844,20 +6754,12 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
-    ggml_set_name(b, "n_past, inplace");
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = inplace ? 1 : 0;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { n_past, inplace ? 1 : 0 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_DIAG_MASK_ZERO;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -6893,7 +6795,6 @@ struct ggml_tensor * ggml_soft_max_impl(
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
 
     return result;
 }
@@ -6956,9 +6857,9 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         float                 freq_base,
         float                 freq_scale,
-        int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6969,23 +6870,14 @@ struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = n_dims;
-    ((int32_t *) b->data)[2] = mode;
-    ((int32_t *) b->data)[3] = n_ctx;
-    memcpy((int32_t *) b->data + 4, &freq_base,  sizeof(float));
-    memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));
-
-    ggml_scratch_load(ctx);
+    int32_t params[6] = { n_past, n_dims, mode, n_ctx };
+    memcpy(params + 4, &freq_base, sizeof(float));
+    memcpy(params + 5, &freq_scale, sizeof(float));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -6997,7 +6889,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -7007,7 +6899,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7016,10 +6908,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         float                 freq_base,
-        float                 freq_scale,
-        int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
+        float                 freq_scale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
 }
 
 // ggml_rope_back
@@ -7029,7 +6921,8 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
+        int                   mode,
+        int                   n_ctx) {
     GGML_ASSERT(n_past >= 0);
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -7041,21 +6934,12 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-    ggml_set_name(b, "n_past, n_dims, mode");
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = n_dims;
-    ((int32_t *) b->data)[2] = mode;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { n_past, n_dims, mode, n_ctx };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_ROPE_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -7080,21 +6964,13 @@ struct ggml_tensor * ggml_alibi(
     //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = n_head;
-    GGML_ASSERT(sizeof(float) == sizeof(int32_t));
-    (((float *) b->data)[2]) = bias_max;
-
-    ggml_scratch_load(ctx);
+    int32_t op_params[3] = { n_past, n_head };
+    memcpy(op_params + 2, &bias_max, sizeof(float));
+    ggml_set_op_params(result, &op_params, sizeof(op_params));
 
     result->op   = GGML_OP_ALIBI;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -7116,19 +6992,12 @@ struct ggml_tensor * ggml_clamp(
     // TODO: when implement backward, fix this:
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
-
-    ((float *) b->data)[0] = min;
-    ((float *) b->data)[1] = max;
-
-    ggml_scratch_load(ctx);
+    float params[] = { min, max };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_CLAMP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -7161,18 +7030,13 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    ggml_scratch_save(ctx);
-    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-    ((int32_t*)c->data)[0] = s0;
-    ((int32_t*)c->data)[1] = p0;
-    ((int32_t*)c->data)[2] = d0;
-    ggml_scratch_load(ctx);
+    int32_t params[] = { s0, p0, d0 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_CONV_1D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 }
@@ -7205,21 +7069,13 @@ struct ggml_tensor* ggml_conv_2d(
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    ggml_scratch_save(ctx);
-    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
-    ((int32_t*)c->data)[0] = s0;
-    ((int32_t*)c->data)[1] = s1;
-    ((int32_t*)c->data)[2] = p0;
-    ((int32_t*)c->data)[3] = p1;
-    ((int32_t*)c->data)[4] = d0;
-    ((int32_t*)c->data)[5] = d1;
-    ggml_scratch_load(ctx);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_CONV_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = c;
 
     return result;
 
@@ -7243,7 +7099,7 @@ static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
     return (ins + 2 * p - ks) / s + 1;
 }
 
-// ggml_pool_2d
+// ggml_pool_1d
 
 struct ggml_tensor* ggml_pool_1d(
         struct ggml_context * ctx,
@@ -7266,18 +7122,12 @@ struct ggml_tensor* ggml_pool_1d(
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
 
-    ggml_scratch_save(ctx);
-    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
-    ((int32_t*)c->data)[0] = op;
-    ((int32_t*)c->data)[1] = k0;
-    ((int32_t*)c->data)[2] = s0;
-    ((int32_t*)c->data)[3] = p0;
-    ggml_scratch_load(ctx);
+    int32_t params[] = { op, k0, s0, p0 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_POOL_1D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = c;
 
     return result;
 }
@@ -7309,21 +7159,12 @@ struct ggml_tensor* ggml_pool_2d(
     };
     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
 
-    ggml_scratch_save(ctx);
-    struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7);
-    ((int32_t*)c->data)[0] = op;
-    ((int32_t*)c->data)[1] = k0;
-    ((int32_t*)c->data)[2] = k1;
-    ((int32_t*)c->data)[3] = s0;
-    ((int32_t*)c->data)[4] = s1;
-    ((int32_t*)c->data)[5] = p0;
-    ((int32_t*)c->data)[6] = p1;
-    ggml_scratch_load(ctx);
+    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op = GGML_OP_POOL_2D;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = c;
 
     return result;
 }
@@ -7482,21 +7323,12 @@ struct ggml_tensor * ggml_win_part(
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-
-    ((int32_t *) b->data)[0] = npx;
-    ((int32_t *) b->data)[1] = npy;
-    ((int32_t *) b->data)[2] = w;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { npx, npy, w };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_WIN_PART;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = b;
 
     return result;
 }
@@ -7521,19 +7353,12 @@ struct ggml_tensor * ggml_win_unpart(
     const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
 
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-
-    ((int32_t *) b->data)[0] = w;
-
-    ggml_scratch_load(ctx);
+    int32_t params[] = { w };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op   = GGML_OP_WIN_UNPART;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = NULL;
-    result->src[2] = b;
 
     return result;
 }
@@ -7551,19 +7376,13 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
-
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_UNARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[2] = addr_tensor;
 
     return result;
 }
@@ -7598,20 +7417,14 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_BINARY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = addr_tensor;
 
     return result;
 }
@@ -7645,19 +7458,13 @@ struct ggml_tensor * ggml_map_custom1_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_CUSTOM1;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[2] = addr_tensor;
 
     return result;
 }
@@ -7690,20 +7497,14 @@ struct ggml_tensor * ggml_map_custom2_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_CUSTOM2;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = addr_tensor;
 
     return result;
 }
@@ -7739,21 +7540,15 @@ struct ggml_tensor * ggml_map_custom3_impl_f32(
         is_node = true;
     }
 
-    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_scratch_save(ctx);
-
-    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
-    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_scratch_load(ctx);
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
 
     result->op = GGML_OP_MAP_CUSTOM3;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
-    result->src[2] = addr_tensor;
-    result->src[3] = c;
+    result->src[2] = c;
 
     return result;
 }
@@ -8981,21 +8776,17 @@ static void ggml_compute_forward_acc_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 
-    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(opt0) == 5);
-
     // view src0 and dst with these strides and data offset inbytes during acc
     // nb0 is implicitely element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) opt0->data)[0];
-    size_t nb2     = ((int32_t *) opt0->data)[1];
-    size_t nb3     = ((int32_t *) opt0->data)[2];
-    size_t offset  = ((int32_t *) opt0->data)[3];
-    bool   inplace = (bool) ((int32_t *) opt0->data)[4];
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -9064,13 +8855,12 @@ static void ggml_compute_forward_acc(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_acc_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_acc_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_Q4_0:
@@ -11090,21 +10880,17 @@ static void ggml_compute_forward_set_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 
-    GGML_ASSERT(opt0->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(opt0) == 5);
-
     // view src0 and dst with these strides and data offset inbytes during set
     // nb0 is implicitely element_size because src0 and dst are contiguous
-    size_t nb1     = ((int32_t *) opt0->data)[0];
-    size_t nb2     = ((int32_t *) opt0->data)[1];
-    size_t nb3     = ((int32_t *) opt0->data)[2];
-    size_t offset  = ((int32_t *) opt0->data)[3];
-    bool   inplace = (bool) ((int32_t *) opt0->data)[4];
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 
     if (!inplace && (params->type == GGML_TASK_INIT)) {
         // memcpy needs to be synchronized across threads to avoid race conditions.
@@ -11164,13 +10950,12 @@ static void ggml_compute_forward_set(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_set_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_set_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_Q4_0:
@@ -11566,17 +11351,14 @@ static void ggml_compute_forward_diag(
 static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 2);
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int  n_past  =       ((int32_t *) src1->data)[0];
-    const bool inplace = (bool)((int32_t *) src1->data)[1];
+    const int  n_past  =       ((int32_t *) dst->op_params)[0];
+    const bool inplace = (bool)((int32_t *) dst->op_params)[1];
 
     GGML_ASSERT(n_past >= 0);
 
@@ -11619,12 +11401,11 @@ static void ggml_compute_forward_diag_mask_f32(
 static void ggml_compute_forward_diag_mask_inf(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, -INFINITY);
+                ggml_compute_forward_diag_mask_f32(params, src0, dst, -INFINITY);
             } break;
         default:
             {
@@ -11636,12 +11417,11 @@ static void ggml_compute_forward_diag_mask_inf(
 static void ggml_compute_forward_diag_mask_zero(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, 0);
+                ggml_compute_forward_diag_mask_f32(params, src0, dst, 0);
             } break;
         default:
             {
@@ -11839,20 +11619,17 @@ static void ggml_compute_forward_soft_max_back(
 static void ggml_compute_forward_alibi_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
 
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int   n_past   = ((int32_t *) src1->data)[0];
-    const int   n_head   = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *)   src1->data)[2];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -11905,20 +11682,17 @@ static void ggml_compute_forward_alibi_f32(
 static void ggml_compute_forward_alibi_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
 
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 3);
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int   n_past   = ((int32_t *) src1->data)[0];
-    const int   n_head   = ((int32_t *) src1->data)[1];
-    const float max_bias = ((float *)   src1->data)[2];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -11971,16 +11745,15 @@ static void ggml_compute_forward_alibi_f16(
 static void ggml_compute_forward_alibi(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_alibi_f16(params, src0, src1, dst);
+                ggml_compute_forward_alibi_f16(params, src0, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_alibi_f32(params, src0, src1, dst);
+                ggml_compute_forward_alibi_f32(params, src0, dst);
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -12010,19 +11783,17 @@ static void ggml_compute_forward_alibi(
 static void ggml_compute_forward_clamp_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
 
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_nelements(src1) == 2);
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const float min = ((float *) src1->data)[0];
-    const float max = ((float *) src1->data)[1];
+    float min;
+    float max;
+    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -12052,12 +11823,11 @@ static void ggml_compute_forward_clamp_f32(
 static void ggml_compute_forward_clamp(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_clamp_f32(params, src0, src1, dst);
+                ggml_compute_forward_clamp_f32(params, src0, dst);
             } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_Q4_0:
@@ -12087,10 +11857,7 @@ static void ggml_compute_forward_clamp(
 static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 6);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12099,12 +11866,12 @@ static void ggml_compute_forward_rope_f32(
     float freq_base;
     float freq_scale;
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-    const int n_ctx  = ((int32_t *) src1->data)[3];
-    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12219,10 +11986,7 @@ static void ggml_compute_forward_rope_f32(
 static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 6);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12231,12 +11995,12 @@ static void ggml_compute_forward_rope_f16(
     float freq_base;
     float freq_scale;
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
-    const int n_ctx  = ((int32_t *) src1->data)[3];
-    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12351,16 +12115,15 @@ static void ggml_compute_forward_rope_f16(
 static void ggml_compute_forward_rope(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_f16(params, src0, src1, dst);
+                ggml_compute_forward_rope_f16(params, src0, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_f32(params, src0, src1, dst);
+                ggml_compute_forward_rope_f32(params, src0, dst);
             } break;
         default:
             {
@@ -12374,10 +12137,7 @@ static void ggml_compute_forward_rope(
 static void ggml_compute_forward_rope_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12387,9 +12147,9 @@ static void ggml_compute_forward_rope_back_f32(
     // dx = rope_back(dy, src1)
     // src0 is dy, src1 contains options
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
 
     assert(n_past >= 0);
 
@@ -12473,10 +12233,7 @@ static void ggml_compute_forward_rope_back_f32(
 static void ggml_compute_forward_rope_back_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12486,9 +12243,9 @@ static void ggml_compute_forward_rope_back_f16(
     // dx = rope_back(dy, src1)
     // src0 is dy, src1 contains options
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_dims = ((int32_t *) src1->data)[1];
-    const int mode   = ((int32_t *) src1->data)[2];
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t *) dst->op_params)[1];
+    const int mode   = ((int32_t *) dst->op_params)[2];
 
     assert(n_past >= 0);
 
@@ -12572,16 +12329,15 @@ static void ggml_compute_forward_rope_back_f16(
 static void ggml_compute_forward_rope_back(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
+                ggml_compute_forward_rope_back_f16(params, src0, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
+                ggml_compute_forward_rope_back_f32(params, src0, dst);
             } break;
         default:
             {
@@ -12778,7 +12534,7 @@ static void ggml_compute_forward_conv_1d_s1_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
@@ -12981,7 +12737,7 @@ static void ggml_compute_forward_conv_1d_s2_ph(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
@@ -13001,14 +12757,13 @@ static void ggml_compute_forward_conv_1d_s2_ph(
 // ggml_compute_forward_conv_1d
 
 static void ggml_compute_forward_conv_1d(
-    const struct ggml_compute_params * params,
-    const struct ggml_tensor * src0,
-    const struct ggml_tensor * src1,
-    const struct ggml_tensor * opt0,
-    struct ggml_tensor * dst) {
-    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
-    const int32_t p0 = ((const int32_t*)(opt0->data))[1];
-    const int32_t d0 = ((const int32_t*)(opt0->data))[2];
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
     GGML_ASSERT(d0 == 1); // dilation not supported
     GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported
     if (s0 == 1) {
@@ -13026,7 +12781,6 @@ static void ggml_compute_forward_conv_2d_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13046,12 +12800,12 @@ static void ggml_compute_forward_conv_2d_f16_f32(
     // size of the convolution row - the kernel size unrolled across all channels
     const int ew0 = nk0*nk1*ne02;
 
-    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
-    const int32_t s1 = ((const int32_t*)(opt0->data))[1];
-    const int32_t p0 = ((const int32_t*)(opt0->data))[2];
-    const int32_t p1 = ((const int32_t*)(opt0->data))[3];
-    const int32_t d0 = ((const int32_t*)(opt0->data))[4];
-    const int32_t d1 = ((const int32_t*)(opt0->data))[5];
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
@@ -13123,17 +12877,15 @@ static void ggml_compute_forward_conv_2d(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        const struct ggml_tensor * opt0,
-        struct ggml_tensor * dst
-        ) {
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst);
+                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst);
+                //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
                 GGML_ASSERT(false);
             } break;
         default:
@@ -13198,12 +12950,11 @@ static void ggml_compute_forward_pool_1d_sk_p0(
 // ggml_compute_forward_pool_1d
 
 static void ggml_compute_forward_pool_1d(
-    const struct ggml_compute_params* params,
-    const struct ggml_tensor* src0,
-    const struct ggml_tensor* opt0,
-    struct ggml_tensor* dst) {
-    GGML_ASSERT(opt0->ne[0] == 4);
-    const int* opts = (const int*)opt0->data;
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+              struct ggml_tensor * dst) {
+
+    const int32_t* opts = (const int32_t*)dst->op_params;
     enum ggml_op_pool op = opts[0];
     const int k0 = opts[1];
     const int s0 = opts[2];
@@ -13217,12 +12968,12 @@ static void ggml_compute_forward_pool_1d(
 // ggml_compute_forward_pool_2d_sk_p0
 
 static void ggml_compute_forward_pool_2d_sk_p0(
-    const struct ggml_compute_params * params,
-    const enum ggml_op_pool op,
-    const struct ggml_tensor * src,
-    const int k0,
-    const int k1,
-    struct ggml_tensor * dst) {
+        const struct ggml_compute_params * params,
+        const enum   ggml_op_pool op,
+        const struct ggml_tensor * src,
+        const int k0,
+        const int k1,
+        struct ggml_tensor * dst) {
     assert(src->type == GGML_TYPE_F32);
     assert(params->ith == 0);
 
@@ -13282,12 +13033,11 @@ static void ggml_compute_forward_pool_2d_sk_p0(
 // ggml_compute_forward_pool_2d
 
 static void ggml_compute_forward_pool_2d(
-    const struct ggml_compute_params * params,
-    const struct ggml_tensor * src0,
-    const struct ggml_tensor * opt0,
-    struct ggml_tensor * dst) {
-    GGML_ASSERT(opt0->ne[0] == 7);
-    const int* opts = (const int*)opt0->data;
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+              struct ggml_tensor * dst) {
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
     enum ggml_op_pool op = opts[0];
     const int k0 = opts[1];
     const int k1 = opts[2];
@@ -13312,7 +13062,7 @@ static void ggml_compute_forward_flash_attn_f32(
         const struct ggml_tensor * k,
         const struct ggml_tensor * v,
         const bool masked,
-             struct ggml_tensor * dst) {
+        struct ggml_tensor * dst) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
@@ -13490,7 +13240,7 @@ static void ggml_compute_forward_flash_attn_f16(
         const struct ggml_tensor * k,
         const struct ggml_tensor * v,
         const bool masked,
-             struct ggml_tensor * dst) {
+        struct ggml_tensor * dst) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
 
@@ -14255,7 +14005,6 @@ static void ggml_compute_forward_flash_attn_back(
 static void ggml_compute_forward_win_part_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -14264,9 +14013,9 @@ static void ggml_compute_forward_win_part_f32(
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne);
 
-    const int32_t nep0 = ((const int32_t *)(opt0->data))[0];
-    const int32_t nep1 = ((const int32_t *)(opt0->data))[1];
-    const int32_t w    = ((const int32_t *)(opt0->data))[2];
+    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
 
     assert(ne00 == ne0);
     assert(ne3  == nep0*nep1);
@@ -14300,12 +14049,11 @@ static void ggml_compute_forward_win_part_f32(
 static void ggml_compute_forward_win_part(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_win_part_f32(params, src0, opt0, dst);
+                ggml_compute_forward_win_part_f32(params, src0, dst);
             } break;
         default:
             {
@@ -14319,7 +14067,6 @@ static void ggml_compute_forward_win_part(
 static void ggml_compute_forward_win_unpart_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -14328,7 +14075,7 @@ static void ggml_compute_forward_win_unpart_f32(
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne);
 
-    const int32_t w = ((const int32_t *)(opt0->data))[0];
+    const int32_t w = ((const int32_t *)(dst->op_params))[0];
 
     // padding
     const int px = (w - ne1%w)%w;
@@ -14362,12 +14109,11 @@ static void ggml_compute_forward_win_unpart_f32(
 static void ggml_compute_forward_win_unpart(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * opt0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst);
+                ggml_compute_forward_win_unpart_f32(params, src0, dst);
             } break;
         default:
             {
@@ -14886,7 +14632,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_ACC:
             {
-                ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SUB:
             {
@@ -15006,7 +14752,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SET:
             {
-                ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CPY:
             {
@@ -15046,11 +14792,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_DIAG_MASK_INF:
             {
-                ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
             {
-                ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_SOFT_MAX:
             {
@@ -15062,35 +14808,35 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_ROPE:
             {
-                ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_rope(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_ROPE_BACK:
             {
-                ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_rope_back(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_ALIBI:
             {
-                ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_alibi(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_CLAMP:
             {
-                ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_clamp(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_CONV_1D:
             {
-                ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CONV_2D:
             {
-                ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_POOL_1D:
             {
-                ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_pool_1d(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_POOL_2D:
             {
-                ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_pool_2d(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_FLASH_ATTN:
             {
@@ -15112,40 +14858,45 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_WIN_PART:
             {
-                ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor);
+                ggml_compute_forward_win_part(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_WIN_UNPART:
             {
-                ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor);
+                ggml_compute_forward_win_unpart(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_MAP_UNARY:
             {
-                const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data);
+                ggml_unary_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
                 ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun);
             }
             break;
         case GGML_OP_MAP_BINARY:
             {
-                const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data);
+                ggml_binary_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
                 ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun);
             }
             break;
         case GGML_OP_MAP_CUSTOM1:
             {
-                const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data);
+                ggml_custom1_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
                 ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun);
             }
             break;
         case GGML_OP_MAP_CUSTOM2:
             {
-                const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data);
+                ggml_custom2_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
                 ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun);
             }
             break;
         case GGML_OP_MAP_CUSTOM3:
             {
-                const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data);
-                ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun);
+                ggml_custom3_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor, fun);
             }
             break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
@@ -15209,12 +14960,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
                 }
                 if (src1->grad) {
-                    GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
-                    GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
-                    const size_t nb1     = (( int32_t * ) tensor->src[2]->data)[0];
-                    const size_t nb2     = (( int32_t * ) tensor->src[2]->data)[1];
-                    const size_t nb3     = (( int32_t * ) tensor->src[2]->data)[2];
-                    const size_t offset  = (( int32_t * ) tensor->src[2]->data)[3];
+                    const size_t nb1     = ((int32_t *) tensor->op_params)[0];
+                    const size_t nb2     = ((int32_t *) tensor->op_params)[1];
+                    const size_t nb3     = ((int32_t *) tensor->op_params)[2];
+                    const size_t offset  = ((int32_t *) tensor->op_params)[3];
 
                     struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
                         tensor->grad,
@@ -15522,12 +15271,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_SET:
             {
-                GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
-                GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
-                const size_t nb1     = (( int32_t * ) tensor->src[2]->data)[0];
-                const size_t nb2     = (( int32_t * ) tensor->src[2]->data)[1];
-                const size_t nb3     = (( int32_t * ) tensor->src[2]->data)[2];
-                const size_t offset  = (( int32_t * ) tensor->src[2]->data)[3];
+                const size_t nb1     = ((int32_t *) tensor->op_params)[0];
+                const size_t nb2     = ((int32_t *) tensor->op_params)[1];
+                const size_t nb3     = ((int32_t *) tensor->op_params)[2];
+                const size_t offset  = ((int32_t *) tensor->op_params)[3];
 
                 struct ggml_tensor * tensor_grad_view = NULL;
 
@@ -15604,8 +15351,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     size_t offset;
 
-                    GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2]));
-                    memcpy(&offset, tensor->src[2]->data, sizeof(offset));
+                    memcpy(&offset, tensor->op_params, sizeof(offset));
 
                     size_t nb1     = tensor->nb[1];
                     size_t nb2     = tensor->nb[2];
@@ -15632,7 +15378,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    int32_t * axes = (int32_t *) tensor->src[2]->data;
+                    int32_t * axes = (int32_t *) tensor->op_params;
                     int axis0 = axes[0] & 0x3;
                     int axis1 = axes[1] & 0x3;
                     int axis2 = axes[2] & 0x3;
@@ -15688,33 +15434,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 2);
-                    const int n_past = ((int32_t *) src1->data)[0];
+                    const int n_past = ((int32_t *) tensor->op_params)[0];
                     src0->grad =
                         ggml_add_impl(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
                         inplace);
                 }
-                if (src1->grad) {
-                    // noop
-                }
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
             {
                 // necessary for llama
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 2);
-                    const int n_past = ((int32_t *) src1->data)[0];
+                    const int n_past = ((int32_t *) tensor->op_params)[0];
                     src0->grad =
                         ggml_add_impl(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
                         inplace);
                 }
-                if (src1->grad) {
-                    // noop
-                }
             } break;
         case GGML_OP_SOFT_MAX:
             {
@@ -15735,33 +15471,28 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 6);
-                    const int n_past = ((int32_t *) src1->data)[0];
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
+                    const int n_past = ((int32_t *) tensor->op_params)[0];
+                    const int n_dims = ((int32_t *) tensor->op_params)[1];
+                    const int mode   = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                n_ctx),
                             inplace);
                 }
-                if (src1->grad) {
-                    // noop
-                }
             } break;
         case GGML_OP_ROPE_BACK:
             {
                 if (src0->grad) {
-                    assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
-                    const int n_past = ((int32_t *) src1->data)[0];
-                    const int n_dims = ((int32_t *) src1->data)[1];
-                    const int mode   = ((int32_t *) src1->data)[2];
-                    const int n_ctx  = ((int32_t *) src1->data)[3];
+                    const int n_past = ((int32_t *) tensor->op_params)[0];
+                    const int n_dims = ((int32_t *) tensor->op_params)[1];
+                    const int mode   = ((int32_t *) tensor->op_params)[2];
+                    const int n_ctx  = ((int32_t *) tensor->op_params)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope(ctx,
@@ -15772,9 +15503,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_ctx),
                             inplace);
                 }
-                if (src1->grad) {
-                    // noop
-                }
             } break;
         case GGML_OP_ALIBI:
             {
@@ -16538,10 +16266,10 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_GET_ROWS:
             case GGML_OP_GET_ROWS_BACK:
             case GGML_OP_DIAG:
-            case GGML_OP_DIAG_MASK_ZERO:
                 {
                     n_tasks = 1;
                 } break;
+            case GGML_OP_DIAG_MASK_ZERO:
             case GGML_OP_DIAG_MASK_INF:
             case GGML_OP_SOFT_MAX:
             case GGML_OP_SOFT_MAX_BACK:
@@ -16988,7 +16716,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
                     fwrite(&nb, sizeof(uint64_t), 1, fout);
                 }
 
-                fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
+                fwrite(tensor->name,      sizeof(char), GGML_MAX_NAME,      fout);
+                fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout);
 
                 // dump the data
                 // TODO: pad this to 32 byte boundary
@@ -17021,7 +16750,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
                     fwrite(&nb, sizeof(uint64_t), 1, fout);
                 }
 
-                fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
+                fwrite(tensor->name,      sizeof(char), GGML_MAX_NAME,      fout);
+                fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout);
 
                 // output the op arguments
                 {
@@ -17202,7 +16932,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
 
                 tensor->op = (enum ggml_op) op;
 
-                memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
+                memcpy(tensor->name,      ptr, GGML_MAX_NAME);      ptr += GGML_MAX_NAME;
+                memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS;
 
                 tensor->data = (void *) ptr;
 
@@ -17247,7 +16978,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                     nb[j] = nb_cur;
                 }
 
-                const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
+                const char * ptr_name      = ptr; ptr += GGML_MAX_NAME;
+                const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS;
 
                 const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t);
 
@@ -17284,8 +17016,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                         {
                             tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
 
-                            uint64_t offs;
-                            memcpy(&offs, args[2]->data, sizeof(offs));
+                            size_t offs;
+                            memcpy(&offs, ptr_op_params, sizeof(offs));
 
                             tensor->data = ((char *) tensor->data) + offs;
                         } break;
@@ -17305,7 +17037,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                         } break;
                 }
 
-                memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
+                memcpy(tensor->name,      ptr_name,      GGML_MAX_NAME);
+                memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS);
 
                 for (int j = 0; j < GGML_MAX_DIMS; ++j) {
                     tensor->nb[j] = nb[j];