]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
test-backend-ops : add performance eval mode + improve CUDA repeat and binary broadca...
authorslaren <redacted>
Thu, 7 Dec 2023 08:51:46 +0000 (09:51 +0100)
committerGitHub <redacted>
Thu, 7 Dec 2023 08:51:46 +0000 (09:51 +0100)
* ggml-cuda : implement repeat with bin_bcast

* ggml-cuda : change supports_op for mul_mat to match compute_forward

* test-backend-ops : add performance eval mode

* improve formatting

* add sd test cases

* fix test case

* ggml-cuda : bin_bcast: better block sizes, two elements per thread

* metal : add dim3 broadcast support for mul mat

* cleanup

* typo

* metal : enable mul mat-vec for dim2 > 1

* metal : mul mat-vec support dim3 broadcasts

ggml-ci

* ggml-cuda : fix bin_bcast for ne0=1
ggml-ci

* ggml-cuda : limit block size z dim to 64

* test-backend-ops : add test cases

* test-backend-ops : add warmup run, print test type before trying to compute

* ggml-cuda : bin_bcast: collapse dimensions when possible, add fallback kernel for large tensors
ggml-ci

* test-backend-ops : avoid division by zero

---------

Co-authored-by: Georgi Gerganov <redacted>
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
tests/test-backend-ops.cpp

index dbe92d97eb395ddeab050845c55afc33c5ed0a8a..8fbd5c7e640829c596d0d41bac85d0f87a63ce21 100644 (file)
@@ -434,7 +434,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
-#define CUDA_ADDMUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
@@ -501,6 +500,10 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+    return b;
+}
+
 static __device__ __forceinline__ float op_add(const float a, const float b) {
     return a + b;
 }
@@ -515,29 +518,69 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
 
 template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
 static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0,/* int ne1, int ne2, */int ne3,
+        int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1,  int s2,  int s3,
         /*int s10,*/ int s11, int s12, int s13) {
-    const int i0 = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i1 = blockIdx.y;
-    const int i2 = blockIdx.z / ne3;
-    const int i3 = blockIdx.z % ne3;
+    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
+    const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
+    const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
 
-    if (i0 >= ne0) {
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13) {
+
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
 
-    const int i10 = i0 % ne10;
     const int i11 = i1 % ne11;
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;
 
-    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1 + i0;
-    const size_t i_src0 = i_dst;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11 + i10;
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
 
-    dst[i_dst] = (dst_t)bin_op((float)src0[i_src0], (float)src1[i_src1]);
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
@@ -4849,24 +4892,108 @@ struct bin_bcast_cuda {
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-        //size_t s0 = nb0 / sizeof(src1_t);
-        size_t s1 = nb1 / sizeof(src1_t);
-        size_t s2 = nb2 / sizeof(src1_t);
-        size_t s3 = nb3 / sizeof(src1_t);
-
-        //size_t s10 = nb10 / sizeof(src1_t);
-        size_t s11 = nb11 / sizeof(src1_t);
-        size_t s12 = nb12 / sizeof(src1_t);
-        size_t s13 = nb13 / sizeof(src1_t);
 
-        const int num_blocks_x = (ne0 + CUDA_ADDMUL_BLOCK_SIZE - 1) / CUDA_ADDMUL_BLOCK_SIZE;
-        dim3 num_blocks(num_blocks_x, ne1, ne2*ne3);
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+        {
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            //size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
+
+            //size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            //size_t s0 = nb0 / sizeof(src1_t);
+            size_t s1 = nb1 / sizeof(src1_t);
+            size_t s2 = nb2 / sizeof(src1_t);
+            size_t s3 = nb3 / sizeof(src1_t);
+
+            //size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            dim3 block_dims;
+            block_dims.x = std::min<unsigned int>(hne0, block_size);
+            block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+            block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+            dim3 block_nums(
+                (hne0 + block_dims.x - 1) / block_dims.x,
+                (ne1 + block_dims.y - 1) / block_dims.y,
+                (ne2*ne3 + block_dims.z - 1) / block_dims.z
+            );
 
-        k_bin_bcast<bin_op><<<num_blocks, CUDA_ADDMUL_BLOCK_SIZE, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-            ne0,/* ne1, ne2, */ne3,
-            ne10, ne11, ne12, ne13,
-            /* s0, */s1, s2, s3,
-            /* s10,*/ s11, s12, s13);
+            if (block_nums.z > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s10, */ s11, s12, s13);
+            } else {
+                k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s10, */ s11, s12, s13);
+            }
+        }
     }
 };
 
@@ -6096,63 +6223,6 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     }
 }
 
-static void ggml_cuda_op_repeat(
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
-    // guaranteed to be an integer due to the check in ggml_can_repeat
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
-    const size_t nb2 = dst->nb[2];
-    const size_t nb3 = dst->nb[3];
-
-    const size_t nb00 = src0->nb[0];
-    const size_t nb01 = src0->nb[1];
-    const size_t nb02 = src0->nb[2];
-    const size_t nb03 = src0->nb[3];
-
-    const int nr0 = (int)(ne0/ne00);
-    const int nr1 = (int)(ne1/ne01);
-    const int nr2 = (int)(ne2/ne02);
-    const int nr3 = (int)(ne3/ne03);
-
-    // TODO: support for transposed / permuted tensors
-    GGML_ASSERT(nb0  == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors
-    for                         (int i3 = 0; i3 < nr3;  i3++) {
-        for                     (int k3 = 0; k3 < ne03; k3++) {
-            for                 (int i2 = 0; i2 < nr2;  i2++) {
-                for             (int k2 = 0; k2 < ne02; k2++) {
-                    for         (int i1 = 0; i1 < nr1;  i1++) {
-                        for     (int k1 = 0; k1 < ne01; k1++) {
-                            for (int i0 = 0; i0 < nr0;  i0++) {
-                                CUDA_CHECK(cudaMemcpyAsync(
-                                              (char *)  dst_d + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0,
-                                        (const char *) src0_d + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01,
-                                        ne00*nb0, cudaMemcpyDeviceToDevice, stream));
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    (void) src1;
-    (void) src1_d;
-}
-
 static void ggml_cuda_op_get_rows(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) {
@@ -6215,7 +6285,16 @@ inline void ggml_cuda_op_bin_bcast(
             ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
         GGML_ASSERT(false);
     }
+}
 
+static void ggml_cuda_op_repeat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) {
+
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
+
+    (void) src1;
+    (void) src1_d;
 }
 
 inline void ggml_cuda_op_add(
@@ -8393,7 +8472,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                     break;
                 default:
                     return false;
-            } break;
+            }
+            break;
         case GGML_OP_NORM:
             func = ggml_cuda_norm;
             break;
@@ -8842,10 +8922,10 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     UNUSED(backend);
 }
 
-static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * tensor) {
-    switch (tensor->op) {
+static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+    switch (op->op) {
         case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(tensor)) {
+            switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -8854,7 +8934,23 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                     return false;
             }
             break;
+        case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
+            {
+                struct ggml_tensor * a;
+                struct ggml_tensor * b;
+                if (op->op == GGML_OP_MUL_MAT) {
+                    a = op->src[0];
+                    b = op->src[1];
+                } else {
+                    a = op->src[2];
+                    b = op->src[1];
+                }
+                if (a->ne[3] != b->ne[3]) {
+                    return false;
+                }
+                return true;
+            } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -8868,7 +8964,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_MUL_MAT:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
@@ -8913,6 +9008,9 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
         return nullptr;
     }
 
+    // not strictly necessary, but it may reduce the overhead of the first graph_compute
+    ggml_cuda_set_main_device(device);
+
     ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda {
         /* .device = */ device
     };
index cff9d5bc68ed5ca6aceaf69e63be99982da25a6b..0c2f86ff57186b779d8c2369d13487ede8816e27 100644 (file)
@@ -805,32 +805,14 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
             return true;
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
                 return op->ne[0] % 4 == 0;
             } break;
-        case GGML_OP_MUL_MAT:
-        case GGML_OP_MUL_MAT_ID:
-            {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b; UNUSED(b);
-                if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
-                if (a->ne[3] != 1) {
-                    return false;
-                }
-                if (ggml_is_quantized(a->type) && a->ne[2] != 1) {
-                    return false;
-                }
-                return true;
-            } break;
         default:
             return false;
     }
@@ -1222,9 +1204,13 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL_MAT:
                         {
                             GGML_ASSERT(ne00 == ne10);
-                            GGML_ASSERT(ne03 == ne13);
 
-                            const uint gqa = ne12/ne02;
+                            // TODO: assert that dim2 and dim3 are contiguous
+                            GGML_ASSERT(ne12 % ne02 == 0);
+                            GGML_ASSERT(ne13 % ne03 == 0);
+
+                            const uint r2 = ne12/ne02;
+                            const uint r3 = ne13/ne03;
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
@@ -1289,9 +1275,10 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
                                 [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
@@ -1327,90 +1314,60 @@ void ggml_metal_graph_compute(
                                         } break;
                                     case GGML_TYPE_Q4_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
                                         } break;
                                     case GGML_TYPE_Q4_1:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                                         } break;
                                     case GGML_TYPE_Q5_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
                                         } break;
                                     case GGML_TYPE_Q5_1:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
                                         } break;
                                     case GGML_TYPE_Q8_0:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 8;
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
                                         } break;
                                     case GGML_TYPE_Q2_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 4; //1;
                                             nth1 = 8; //32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
-                                            GGML_ASSERT(ne02 == 1);
-                                            GGML_ASSERT(ne12 == 1);
-
                                             nth0 = 2;
                                             nth1 = 32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1439,31 +1396,32 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                                 [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
+                                [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
+                                [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                     src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
                                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q4_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                                 }
                                 else if (src0t == GGML_TYPE_Q5_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q6_K) {
-                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
                                     int64_t ny = (ne11 + nrows - 1)/nrows;
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
                         } break;
@@ -1501,7 +1459,8 @@ void ggml_metal_graph_compute(
                             //GGML_ASSERT(ne20 >= 64);
                             GGML_ASSERT(src1t == GGML_TYPE_F32);
 
-                            const uint gqa = ne12/ne22;
+                            const uint r2 = ne12/ne22;
+                            const uint r3 = ne13/ne23;
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
@@ -1541,8 +1500,9 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
                                 [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
-                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:14];
+                                [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
+                                [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
+                                [encoder setBytes:&idx     length:sizeof(idx)  atIndex:15];
                                 // TODO: how to make this an array? read Metal docs
                                 for (int j = 0; j < n_as; ++j) {
                                     struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1550,11 +1510,11 @@ void ggml_metal_graph_compute(
                                     size_t offs_src_cur = 0;
                                     id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
 
-                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:15 + j];
+                                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
                                 }
 
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             }
                         } break;
                     case GGML_OP_GET_ROWS:
index 4499b1bfc432f0fb0406a5c550a60de78a07a238..2def5f8202bae069560f0cb6b2fa975d312455fe 100644 (file)
@@ -730,9 +730,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 //      giard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
-                    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
-                    uint3 tgpig, uint tiisg, uint sgitg) {
+void mul_vec_q_n_f32(
+        device const void  * src0,
+        device const float * src1,
+        device       float * dst,
+                   int64_t   ne00,
+                   int64_t   ne01,
+                   int64_t   ne02,
+                   int64_t   ne10,
+                   int64_t   ne12,
+                   int64_t   ne0,
+                   int64_t   ne1,
+                   uint      r2,
+                   uint      r3,
+                   uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
 
     const int r0 = tgpig.x;
@@ -741,7 +752,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
 
     const int first_row = (r0 * nsg + sgitg) * nr;
 
-    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
 
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
@@ -791,13 +805,14 @@ kernel void kernel_mul_mv_q4_0_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -809,13 +824,14 @@ kernel void kernel_mul_mv_q4_1_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -827,13 +843,14 @@ kernel void kernel_mul_mv_q5_0_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -845,13 +862,14 @@ kernel void kernel_mul_mv_q5_1_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
 }
 
 
@@ -866,9 +884,10 @@ kernel void kernel_mul_mv_q8_0_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -880,8 +899,14 @@ kernel void kernel_mul_mv_q8_0_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
-    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
@@ -939,6 +964,8 @@ kernel void kernel_mul_mv_f32_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -946,7 +973,12 @@ kernel void kernel_mul_mv_f32_f32(
     const int64_t rb = tgpig.y*N_F32_F32;
     const int64_t im = tgpig.z;
 
-    device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const float * x = (device const float *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F32_F32; ++row) {
@@ -1012,6 +1044,8 @@ kernel void kernel_mul_mv_f16_f16(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1019,7 +1053,12 @@ kernel void kernel_mul_mv_f16_f16(
     const int64_t rb = tgpig.y*N_F16_F16;
     const int64_t im = tgpig.z;
 
-    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F16_F16; ++row) {
@@ -1083,6 +1122,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1090,7 +1131,12 @@ kernel void kernel_mul_mv_f16_f32_1row(
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;
 
-    device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half  * x = (device const half  *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
     float sumf = 0;
@@ -1137,6 +1183,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1144,7 +1192,12 @@ kernel void kernel_mul_mv_f16_f32(
     const int64_t rb = tgpig.y*N_F16_F32;
     const int64_t im = tgpig.z;
 
-    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half * x = (device const half *) (src0 + offset0);
 
     if (ne00 < 128) {
         for (int row = 0; row < N_F16_F32; ++row) {
@@ -1209,6 +1262,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1216,7 +1271,12 @@ kernel void kernel_mul_mv_f16_f32_l4(
     const int64_t r0 = tgpig.x;
     const int64_t im = tgpig.z;
 
-    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+
+    device const half4 * x4 = (device const half4 *) (src0 + offset0);
 
     for (int r1 = 0; r1 < nrows; ++r1) {
         device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
@@ -1276,7 +1336,7 @@ kernel void kernel_alibi_f32(
     } else {
         m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
     }
-    
+
     device       char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
     device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
@@ -1821,23 +1881,30 @@ kernel void kernel_mul_mv_q2_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
@@ -1846,11 +1913,11 @@ kernel void kernel_mul_mv_q2_K_f32(
 #if QK_K == 256
     const int ix = tiisg/8;  // 0...3
     const int it = tiisg%8;  // 0...7
-    const int im = it/4;     // 0 or 1
+    const int iq = it/4;     // 0 or 1
     const int ir = it%4;     // 0...3
     const int is = (8*ir)/16;// 0 or 1
 
-    device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
 
     for (int ib = ix; ib < nb; ib += 4) {
 
@@ -1862,8 +1929,8 @@ kernel void kernel_mul_mv_q2_K_f32(
             yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
         }
 
-        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*im + is;
-        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
@@ -1950,7 +2017,7 @@ kernel void kernel_mul_mv_q2_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -1965,9 +2032,10 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1976,12 +2044,17 @@ kernel void kernel_mul_mv_q3_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
-    const int64_t r2 = tgpig.z;
+    const int64_t im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
     float yl[32];
 
@@ -2103,7 +2176,7 @@ kernel void kernel_mul_mv_q3_K_f32(
     }
     if (tiisg == 0) {
         for (int row = 0; row < 2; ++row) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }
@@ -2117,26 +2190,33 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int nb = ne00/QK_K;
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
-    const int64_t r2 = tgpig.z;
+    const int64_t im = tgpig.z;
 
     const int row = 2 * r0 + sgitg;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
     const int ix = tiisg/4;
     const int il = 4 * (tiisg%4);// 0, 4, 8, 12
-    const int im = il/8;         // 0, 0, 1, 1
+    const int iq = il/8;         // 0, 0, 1, 1
     const int in = il%8;         // 0, 4, 0, 4
 
     float2 sum = {0.f, 0.f};
@@ -2156,7 +2236,7 @@ kernel void kernel_mul_mv_q3_K_f32(
         const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
 
         for (int l = 0; l < 4; l += 2) {
-            const uint16_t hm = h[l/2] >> im;
+            const uint16_t hm = h[l/2] >> iq;
             sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 :  4))
                     + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
                     + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
@@ -2172,7 +2252,7 @@ kernel void kernel_mul_mv_q3_K_f32(
 
     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
     }
 
 }
@@ -2190,10 +2270,11 @@ kernel void kernel_mul_mv_q4_K_f32(
         constant   int64_t & ne12 [[buffer(11)]],
         constant   int64_t & ne0  [[buffer(15)]],
         constant   int64_t & ne1  [[buffer(16)]],
-        constant   uint    & gqa  [[buffer(17)]],
+        constant   uint    & r2   [[buffer(17)]],
+        constant   uint    & r3   [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
@@ -2201,26 +2282,32 @@ kernel void kernel_mul_mv_q4_K_f32(
 
     const int ix = tiisg/8;  // 0...3
     const int it = tiisg%8;  // 0...7
-    const int im = it/4;     // 0 or 1
+    const int iq = it/4;     // 0 or 1
     const int ir = it%4;     // 0...3
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int im = tgpig.z;
     //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
     float yl[16];
     float yh[16];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int step = sizeof(block_q4_K) * nb / 2;
 
-    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
+    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
 
     uint16_t sc16[4];
     thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
@@ -2235,8 +2322,8 @@ kernel void kernel_mul_mv_q4_K_f32(
             yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
         }
 
-        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
-        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
@@ -2280,7 +2367,7 @@ kernel void kernel_mul_mv_q4_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -2294,9 +2381,10 @@ kernel void kernel_mul_mv_q4_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2307,12 +2395,18 @@ kernel void kernel_mul_mv_q4_K_f32(
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int im = tgpig.z;
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
     float yl[8];
     float yh[8];
     float sumf[N_DST]={0.f}, all_sum;
@@ -2368,7 +2462,7 @@ kernel void kernel_mul_mv_q4_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
+            dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -2383,9 +2477,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2394,12 +2489,17 @@ kernel void kernel_mul_mv_q5_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
     float sumf[2]={0.f};
 
@@ -2415,15 +2515,15 @@ kernel void kernel_mul_mv_q5_K_f32(
 
     const int tid = tiisg/4;
     const int ix  = tiisg%4;
-    const int im  = tid/4;
+    const int iq  = tid/4;
     const int ir  = tid%4;
     const int n   = 8;
 
     const int l0 = n*ir;
-    const int q_offset = 32*im + l0;
-    const int y_offset = 64*im + l0;
+    const int q_offset = 32*iq + l0;
+    const int y_offset = 64*iq + l0;
 
-    const uint8_t hm1 = 1u << (2*im);
+    const uint8_t hm1 = 1u << (2*iq);
     const uint8_t hm2 = hm1 << 1;
     const uint8_t hm3 = hm1 << 4;
     const uint8_t hm4 = hm2 << 4;
@@ -2438,7 +2538,7 @@ kernel void kernel_mul_mv_q5_K_f32(
         device const uint8_t * q1 = x[i].qs + q_offset;
         device const uint8_t * qh = x[i].qh + l0;
         device const half * dh = &x[i].d;
-        device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
 
         device const float * y2 = y1 + 128;
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
@@ -2494,7 +2594,7 @@ kernel void kernel_mul_mv_q5_K_f32(
 
     const int il = 4 * (tiisg/8);  // 0, 4, 8, 12
     const int ix = tiisg%8;
-    const int im = il/8;         // 0, 0, 1, 1
+    const int iq = il/8;         // 0, 0, 1, 1
     const int in = il%8;         // 0, 4, 0, 4
 
     device const float * y = yy + ix*QK_K + il;
@@ -2519,7 +2619,7 @@ kernel void kernel_mul_mv_q5_K_f32(
 
             float2 acc = {0.f, 0.f};
             for (int l = 0; l < 4; ++l) {
-                const uint8_t hl = h[l] >> im;
+                const uint8_t hl = h[l] >> iq;
                 acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
                         + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
                 acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
@@ -2541,7 +2641,7 @@ kernel void kernel_mul_mv_q5_K_f32(
     for (int row = 0; row < 2; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         }
     }
 
@@ -2556,9 +2656,10 @@ kernel void kernel_mul_mv_q6_K_f32(
         constant   int64_t & ne02[[buffer(5)]],
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne0 [[buffer(15)]],
+        constant   int64_t & ne1 [[buffer(16)]],
+        constant   uint    & r2  [[buffer(17)]],
+        constant   uint    & r3  [[buffer(18)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2572,12 +2673,17 @@ kernel void kernel_mul_mv_q6_K_f32(
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
-    const int r2 = tgpig.z;
+    const int     im = tgpig.z;
 
     const int row = 2 * r0 + sgitg;
-    const uint offset0 = r2/gqa*(nb*ne0);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
     device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
     float sumf = 0;
 
@@ -2643,7 +2749,7 @@ kernel void kernel_mul_mv_q6_K_f32(
 
     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
     }
 }
 
@@ -2954,23 +3060,24 @@ kernel void kernel_get_rows(
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 void kernel_mul_mm_impl(device const  uchar * src0,
-                          device const  uchar * src1,
-                          device        float * dst,
-                          constant    int64_t & ne00,
-                          constant    int64_t & ne02,
-                          constant    int64_t & nb01,
-                          constant    int64_t & nb02,
-                          constant    int64_t & ne12,
-                          constant    int64_t & nb10,
-                          constant    int64_t & nb11,
-                          constant    int64_t & nb12,
-                          constant    int64_t & ne0,
-                          constant    int64_t & ne1,
-                          constant       uint & gqa,
-                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                          uint3                 tgpig[[threadgroup_position_in_grid]],
-                          uint                  tiitg[[thread_index_in_threadgroup]],
-                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+                        device const  uchar * src1,
+                        device        float * dst,
+                        constant    int64_t & ne00,
+                        constant    int64_t & ne02,
+                        constant    int64_t & nb01,
+                        constant    int64_t & nb02,
+                        constant    int64_t & ne12,
+                        constant    int64_t & nb10,
+                        constant    int64_t & nb11,
+                        constant    int64_t & nb12,
+                        constant    int64_t & ne0,
+                        constant    int64_t & ne1,
+                        constant       uint & r2,
+                        constant       uint & r3,
+                        threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                        uint3                 tgpig[[threadgroup_position_in_grid]],
+                        uint                  tiitg[[thread_index_in_threadgroup]],
+                        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     threadgroup half  * sa = (threadgroup half  *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
@@ -2996,7 +3103,10 @@ void kernel_mul_mm_impl(device const  uchar * src0,
 
     short il = (tiitg % THREAD_PER_ROW);
 
-    uint   offset0 = im/gqa*nb02;
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
     ushort offset1 = il/nl;
 
     device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
@@ -3094,7 +3204,8 @@ kernel void kernel_mul_mm(device const  uchar * src0,
                           constant    int64_t & nb12,
                           constant    int64_t & ne0,
                           constant    int64_t & ne1,
-                          constant       uint & gqa,
+                          constant       uint & r2,
+                          constant       uint & r3,
                           threadgroup   uchar * shared_memory [[threadgroup(0)]],
                           uint3                 tgpig[[threadgroup_position_in_grid]],
                           uint                  tiitg[[thread_index_in_threadgroup]],
@@ -3113,7 +3224,8 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         nb12,
         ne0,
         ne1,
-        gqa,
+        r2,
+        r3,
         shared_memory,
         tgpig,
         tiitg,
@@ -3135,7 +3247,8 @@ kernel void kernel_mul_mm_id(
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant        uint & gqa,
+        constant        uint & r2,
+        constant        uint & r3,
         constant         int & idx,
         device const   uchar * src00,
         device const   uchar * src01,
@@ -3165,7 +3278,8 @@ kernel void kernel_mul_mm_id(
         nb12,
         ne0,
         ne1,
-        gqa,
+        r2,
+        r3,
         shared_memory,
         tgpig,
         tiitg,
@@ -3214,7 +3328,8 @@ typedef void (mat_mm_t)(
         constant    int64_t & nb12,
         constant    int64_t & ne0,
         constant    int64_t & ne1,
-        constant       uint & gqa,
+        constant       uint & r2,
+        constant       uint & r3,
         threadgroup   uchar *,
         uint3, uint, uint);
 
@@ -3245,7 +3360,8 @@ typedef void (mat_mm_id_t)(
         constant     int64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant        uint & gqa,
+        constant        uint & r2,
+        constant        uint & r3,
         constant         int & idx,
         device const   uchar * src00,
         device const   uchar * src01,
index 22ed469e9785f130711c5d571e5454b8d20a0ddf..a5110c663e2a6ce8c05f0450e007eea3aae57c39 100644 (file)
@@ -12,6 +12,7 @@
 #include <stdio.h>
 #include <stdlib.h>
 #include <string>
+#include <thread>
 #include <vector>
 
 
@@ -20,12 +21,35 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
     std::vector<float> data(size);
 
     std::random_device rd;
+
+#if 0
     std::default_random_engine generator(rd());
     std::uniform_real_distribution<float> distribution(min, max);
 
     for (size_t i = 0; i < size; i++) {
         data[i] = distribution(generator);
     }
+#endif
+    auto init_thread = [&](size_t start, size_t end) {
+        std::default_random_engine generator(rd());
+        std::uniform_real_distribution<float> distribution(min, max);
+
+        for (size_t i = start; i < end; i++) {
+            data[i] = distribution(generator);
+        }
+    };
+
+    size_t n_threads = std::thread::hardware_concurrency();
+    std::vector<std::thread> threads;
+    threads.reserve(n_threads);
+    for (size_t i = 0; i < n_threads; i++) {
+        size_t start =     i*size/n_threads;
+        size_t end   = (i+1)*size/n_threads;
+        threads.emplace_back(init_thread, start, end);
+    }
+    for (auto & t : threads) {
+        t.join();
+    }
 
     if (tensor->type == GGML_TYPE_F32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
@@ -202,6 +226,10 @@ static bool isinf_or_max(float f) {
     return std::isinf(f) || f == FLT_MAX || f == -FLT_MAX;
 }
 
+static bool ggml_is_view_op(enum ggml_op op) {
+    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
 struct test_case {
     virtual ~test_case() {}
 
@@ -221,12 +249,23 @@ struct test_case {
         }
     }
 
+    virtual size_t op_size(ggml_tensor * t) {
+        size_t size = ggml_nbytes(t);
+        // add source tensors
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (t->src[i] != NULL) {
+                size += ggml_nbytes(t->src[i]);
+            }
+        }
+        return size;
+    }
+
     bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) {
         ggml_init_params params = {
             /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
             /* .mem_base = */ NULL,
             /* .no_alloc = */ true,
-       };
+        };
         ggml_context * ctx = ggml_init(params);
 
         ggml_tensor * out = build_graph(ctx);
@@ -237,10 +276,13 @@ struct test_case {
             return true;
         }
 
+        printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        fflush(stdout);
+
         // check if backends support op
         for (ggml_backend_t backend : {backend1, backend2}) {
             if (!ggml_backend_supports_op(backend, out)) {
-                printf("  %s: not supported\n", ggml_op_desc(out));
+                printf("not supported\n");
                 ggml_free(ctx);
                 return true;
             }
@@ -275,7 +317,7 @@ struct test_case {
             for (size_t i = 0; i < f1.size(); i++) {
                 // check for nans
                 if (std::isnan(f1[i]) || std::isnan(f2[i])) {
-                    printf("    Error: %s: NaN at index %zu\n", ggml_op_desc(t1), i);
+                    printf("NaN at index %zu ", i);
                     ud->ok = false;
                     return true;
                 }
@@ -283,12 +325,12 @@ struct test_case {
                 if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
                     if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                         if (std::signbit(f1[i]) != std::signbit(f2[i])) {
-                            printf("    Error: %s: inf sign mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]);
+                            printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
                             ud->ok = false;
                             return true;
                         }
                     } else {
-                        printf("    Error: %s: inf mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]);
+                        printf("inf mismatch: %f %f ", f1[i], f2[i]);
                         ud->ok = false;
                         return true;
                     }
@@ -297,15 +339,14 @@ struct test_case {
 
             double err = nmse(f1.data(), f2.data(), f1.size());
             if (err > ud->max_err) {
-                printf("    Error: %s: NMSE = %f\n", ggml_op_desc(t1), err);
+                printf("NMSE = %f ", err);
                 ud->ok = false;
             }
             return true;
-       };
+        };
 
         ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
 
-        printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
         if (ud.ok) {
             printf("\033[1;32mOK\033[0m\n");
         } else {
@@ -318,6 +359,103 @@ struct test_case {
 
         return ud.ok;
     }
+
+    bool eval_perf(ggml_backend_t backend, const char * op_name) {
+        static const size_t graph_nodes = 8192;
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+
+        ggml_tensor * out = build_graph(ctx);
+
+        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
+            //printf("  %s: skipping\n", ggml_op_desc(out));
+            ggml_free(ctx);
+            return true;
+        }
+
+        int len = printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        fflush(stdout);
+
+        // check if backends support op
+        if (!ggml_backend_supports_op(backend, out)) {
+            printf("not supported\n");
+            ggml_free(ctx);
+            return true;
+        }
+
+        // align while also leaving some margin for variations in parameters
+        int align = 20;
+        int last = (len + align - 1) / align * align;
+        if (last - len < 5) {
+            last += align;
+        }
+        last = std::max(last, 60);
+        printf("%*s", last - len, "");
+
+        // allocate
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+
+        // randomize tensors
+        initialize_tensors(ctx);
+
+        // build graph
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false);
+        ggml_build_forward_expand(gf, out);
+
+        // warmup run
+        ggml_backend_graph_compute(backend, gf);
+
+        // duplicate the op
+        size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
+        int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+        for (int i = 1; i < n_runs; i++) {
+            gf->nodes[gf->n_nodes++] = out;
+        }
+
+        // calculate memory
+        size_t mem = n_runs * op_size(out);
+        auto tensor_op_size = [](ggml_tensor * t) {
+            size_t size = ggml_nbytes(t);
+            // add source tensors
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                if (t->src[i] != NULL) {
+                    size += ggml_nbytes(t->src[i]);
+                }
+            }
+            return size;
+        };
+        for (int i = 0; i < gf->n_nodes; i++) {
+            if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out)
+                continue;
+            mem += tensor_op_size(gf->nodes[i]);
+        }
+
+        // run
+        ggml_backend_synchronize(backend);
+
+        int64_t start_time = ggml_time_us();
+        ggml_backend_graph_compute(backend, gf);
+        ggml_backend_synchronize(backend);
+        int64_t end_time = ggml_time_us();
+        double time_us = end_time - start_time;
+
+        printf("    %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n",
+            n_runs,
+            time_us / n_runs,
+            op_size(out) / 1024,
+            mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0);
+
+        ggml_backend_buffer_free(buf);
+
+        ggml_free(ctx);
+
+        return true;
+    }
 };
 
 // GGML_OP_UNARY
@@ -389,6 +527,10 @@ struct test_repeat : public test_case {
         return VARS_TO_STR3(type, ne, nr);
     }
 
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
     test_repeat(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
             std::array<int, 4> nr = {2, 2, 2, 2})
@@ -432,6 +574,10 @@ struct test_cpy : public test_case {
         return VARS_TO_STR3(type_src, type_dst, ne);
     }
 
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
+    }
+
     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1})
         : type_src(type_src), type_dst(type_dst), ne(ne) {}
@@ -470,16 +616,20 @@ struct test_cont : public test_case {
 // GGML_OP_MUL
 // GGML_OP_DIV
 struct test_bin_bcast : public test_case {
-    using op_t = std::function<ggml_tensor * (ggml_context *, ggml_tensor *, ggml_tensor *)>;
+    using op_t = ggml_tensor * (*) (ggml_context *, ggml_tensor *, ggml_tensor *);
     op_t op;
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    const std::array<int,4> nr;
+    const std::array<int, 4> nr;
 
     std::string vars() override {
         return VARS_TO_STR3(type, ne, nr);
     }
 
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 3;
+    }
+
     test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 1, 1},
             std::array<int, 4> nr = {1, 2, 1, 1})
@@ -491,6 +641,17 @@ struct test_bin_bcast : public test_case {
         ggml_tensor * out = op(ctx, a, b);
         return out;
     }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (op == ggml_div) {
+                // avoid division by zero
+                init_tensor_uniform(t, 1.0f, 2.0f);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
 };
 
 // GGML_OP_SCALE
@@ -576,6 +737,15 @@ struct test_mul_mat : public test_case {
         return 5e-4;
     }
 
+    size_t op_size(ggml_tensor * t) override {
+        size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1];
+        size_t b = ggml_nbytes(t->src[1]) * m;
+        size_t c  = ggml_nbytes(t);
+        return a + b + c;
+
+        GGML_UNUSED(t);
+    }
+
     test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
             int64_t m = 32, int64_t n = 32, int64_t k = 32,
             std::array<int64_t, 2> bs = {10, 10},
@@ -584,13 +754,80 @@ struct test_mul_mat : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
-        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]      , bs[1]);
         ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
         ggml_tensor * out = ggml_mul_mat(ctx, a, b);
         return out;
     }
 };
 
+// GGML_OP_MUL_MAT_ID
+struct test_mul_mat_id : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int n_mats;
+    const int id;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array<int64_t, 2> bs; // dims 3 and 4
+    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+
+    std::string vars() override {
+        return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1];
+        size_t b = ggml_nbytes(t->src[1]) * m;
+        size_t c  = ggml_nbytes(t);
+        return a + b + c;
+
+        GGML_UNUSED(t);
+    }
+
+    test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int n_mats = 2, int id = 0,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array<int64_t, 2> bs = {10, 10},
+            std::array<int64_t, 2> nr = {2, 2})
+        : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
+            m(m), n(n), k(k), bs(bs), nr(nr) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+        std::vector<ggml_tensor *> mats;
+        for (int i = 0; i < n_mats; i++) {
+            ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+            mats.push_back(a);
+        }
+        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                // ids
+                std::vector<int> data(n_mats);
+                for (int i = 0; i < n_mats; i++) {
+                    data[i] = i;
+                }
+                std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
+                ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_SQR
 struct test_sqr : public test_case {
     const ggml_type type;
@@ -852,64 +1089,6 @@ struct test_argsort : public test_case {
     }
 };
 
-// GGML_OP_MUL_MAT_ID
-struct test_mul_mat_id : public test_case {
-    const ggml_type type_a;
-    const ggml_type type_b;
-    const int n_mats;
-    const int id;
-    const int64_t m;
-    const int64_t n;
-    const int64_t k;
-    const std::array<int64_t, 2> bs; // dims 3 and 4
-    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
-
-    std::string vars() override {
-        return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
-    }
-
-    double max_nmse_err() override {
-        return 5e-4;
-    }
-
-    test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
-            int n_mats = 2, int id = 0,
-            int64_t m = 32, int64_t n = 32, int64_t k = 32,
-            std::array<int64_t, 2> bs = {10, 10},
-            std::array<int64_t, 2> nr = {2, 2})
-        : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
-            m(m), n(n), k(k), bs(bs), nr(nr) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
-        std::vector<ggml_tensor *> mats;
-        for (int i = 0; i < n_mats; i++) {
-            ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]*nr[0], bs[1]*nr[1]);
-            mats.push_back(a);
-        }
-        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-        ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
-        return out;
-    }
-
-    void initialize_tensors(ggml_context * ctx) override {
-        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            if (t->type == GGML_TYPE_I32) {
-                // ids
-                std::vector<int> data(n_mats);
-                for (int i = 0; i < n_mats; i++) {
-                    data[i] = i;
-                }
-                std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
-                ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
-            } else {
-                init_tensor_uniform(t);
-            }
-        }
-    }
-};
-
 // GGML_OP_SUM_ROWS
 struct test_sum_rows : public test_case {
     const ggml_type type;
@@ -936,8 +1115,6 @@ enum test_mode {
 };
 
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
-    ggml_backend_t backend_cpu = ggml_backend_cpu_init();
-
     std::vector<std::unique_ptr<test_case>> test_cases;
 
     // unary ops
@@ -950,7 +1127,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
     }
 
-    test_cases.emplace_back(new test_repeat());
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 2, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
+
     test_cases.emplace_back(new test_dup());
     test_cases.emplace_back(new test_cpy());
     test_cases.emplace_back(new test_cont());
@@ -961,6 +1143,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     };
 
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1});
@@ -972,6 +1156,23 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2});
     add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2});
 
+    // stable diffusion
+    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+
     test_cases.emplace_back(new test_scale());
 
     for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) {
@@ -979,7 +1180,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
-    ggml_type all_types[] = {
+    const ggml_type all_types[] = {
         GGML_TYPE_F32, GGML_TYPE_F16,
         GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
         GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
@@ -1010,6 +1211,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    for (ggml_type type_a : all_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
+            for (int n_mats : {1, 2, 4}) {
+                for (int id = 0; id < n_mats; id++) {
+                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
+                }
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_sqr());
     test_cases.emplace_back(new test_clamp());
 
@@ -1028,7 +1239,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512)); // neox (falcon 7B)
         test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512)); // neox (falcon 40B)
         test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512)); // neox (falcon 40B)
-        //test_cases.emplace_back(new test_rope(type, {80, 32, 10, 1}, 20, 2, 512)); // neox rope (stablelm) (TODO: enable after llama.cpp sync)
+        //test_cases.emplace_back(new test_rope(type, {80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm) (TODO: enable after llama.cpp sync)
     }
 
     test_cases.emplace_back(new test_alibi());
@@ -1039,39 +1250,37 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
     }
 
-    for (ggml_type type_a : all_types) {
-        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
-            for (int n_mats : {1, 2, 4}) {
-                for (int id = 0; id < n_mats; id++) {
-                    test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
-                }
-            }
-        }
-    }
-
     test_cases.emplace_back(new test_sum_rows());
 
     // run tests
-    size_t n_ok = 0;
-    for (auto & test : test_cases) {
-        if (test->eval(backend, backend_cpu, op_name)) {
-            n_ok++;
-        }
-    }
+    if (mode == MODE_TEST) {
+        ggml_backend_t backend_cpu = ggml_backend_cpu_init();
 
-    printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
+        size_t n_ok = 0;
+        for (auto & test : test_cases) {
+            if (test->eval(backend, backend_cpu, op_name)) {
+                n_ok++;
+            }
+        }
+        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
 
-    ggml_backend_free(backend_cpu);
+        ggml_backend_free(backend_cpu);
 
-    return n_ok == test_cases.size();
+        return n_ok == test_cases.size();
+    } else if (mode == MODE_PERF) {
+        for (auto & test : test_cases) {
+            test->eval_perf(backend, op_name);
+        }
+        return true;
+    } else {
+        GGML_ASSERT(false);
+    }
 }
 
 static void usage(char ** argv) {
-    // command line: test-backend-ops [mode] [-o op] [-b backend]
-    // modes are correctness (compare with CPU) or performance
     printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
-    printf("  valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation) [not implemented]\n");
-    printf("  op names are as given ggml_op_desc()\n");
+    printf("  valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation)\n");
+    printf("  op names are as given by ggml_op_desc()\n");
 }
 
 int main(int argc, char ** argv) {