]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : support permuted matrix multiplicaions (llama/10033)
authorGeorgi Gerganov <redacted>
Fri, 25 Oct 2024 19:26:15 +0000 (22:26 +0300)
committerGeorgi Gerganov <redacted>
Fri, 1 Nov 2024 08:19:05 +0000 (10:19 +0200)
* metal : support permuted matrix multiplicaions

ggml-ci

* cont : use nb01 directly for row steps

ggml-ci

* cont : add comments [no ci]

* metal : minor refactor

* metal : minor

ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal

index e9541441c8f547af7eb5ec4d765c357296c3c9c7..80c08f15b2999aecffe61785b620d3f87bd1fa5e 100644 (file)
@@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
-    //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-    //if (src0) {
-    //    GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
-    //            ggml_is_contiguous(src0), src0->name);
-    //}
-    //if (src1) {
-    //    GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
-    //            ggml_is_contiguous(src1), src1->name);
-    //}
-    //if (dst) {
-    //    GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
-    //            dst->name);
-    //}
+#if 0
+    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+    if (src0) {
+        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                ggml_is_contiguous(src0), src0->name);
+    }
+    if (src1) {
+        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                ggml_is_contiguous(src1), src1->name);
+    }
+    if (dst) {
+        GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst->name);
+    }
+#endif
 
     id<MTLDevice> device = ctx_dev->mtl_device;
 
@@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
                             [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
                             [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
                             [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
+                            [encoder setBytes:&nb03    length:sizeof(nb03) atIndex:7];
+                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
+                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:9];
+                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:10];
+                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:11];
+                            [encoder setBytes:&nb13    length:sizeof(nb13) atIndex:12];
+                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
+                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
+                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:15];
+                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:16];
                             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
@@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
                             [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
                             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
                             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18];
+                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:19];
+                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:20];
 
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                                    src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
 
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+                GGML_ASSERT(ne03 == 1);
+                GGML_ASSERT(ne13 == 1);
+
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
                 // ne20 = n_used_experts
index 71b58be1fd8a4c440879c68aa52f1b996db43f14..defde6246f12945ae467f37b6e7f7d8ee64fdeff 100644 (file)
@@ -777,10 +777,10 @@ kernel void kernel_ssm_conv_f32(
     const int64_t i3 = tgpig.z;
 
     const int64_t nc  = ne10;
-    const int64_t ncs = ne00;
-    const int64_t nr  = ne01;
-    const int64_t n_t = ne1;
-    const int64_t n_s = ne2;
+  //const int64_t ncs = ne00;
+  //const int64_t nr  = ne01;
+  //const int64_t n_t = ne1;
+  //const int64_t n_s = ne2;
 
     device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
     device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
@@ -834,9 +834,9 @@ kernel void kernel_ssm_scan_f32(
     const int64_t i3 = tgpig.y;
 
     const int64_t nc  = d_state;
-    const int64_t nr  = d_inner;
+  //const int64_t nr  = d_inner;
     const int64_t n_t = n_seq_tokens;
-    const int64_t n_s = n_seqs;
+  //const int64_t n_s = n_seqs;
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
         device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
@@ -1064,17 +1064,18 @@ kernel void kernel_group_norm(
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
 
-    float2 acc = 0.f;
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
 
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
 
-    for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
-                + yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
-                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    for (int i = 0; i < 8; i += 2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (sumy * -8.f + acc[0] + acc[1]);
+
+    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
 }
 
 // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     float d = qb_curr->d;
     float m = qb_curr->m;
 
-    float2 acc = 0.f;
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
 
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
 
     for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
-                + yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
-                + yl[i + 9] * (qs[i / 2] & 0xF000);
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (acc[0] + acc[1]) + sumy * m;
+
+    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
 // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
 
-    float2 acc = 0.f;
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
 
     device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
            const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
 
     for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
-                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
-        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
-                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
+        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
     }
-    return d * (sumy * -16.f + acc[0] + acc[1]);
+
+    return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
 }
 
 // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
     float d = qb_curr->d;
     float m = qb_curr->m;
 
-    float2 acc = 0.f;
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
 
     device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
            const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
 
     for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
-                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
-        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
-                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
+        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
     }
-    return d * (acc[0] + acc[1]) + sumy * m;
+
+    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
 // putting them in the kernel cause a significant performance penalty
@@ -1156,14 +1160,22 @@ void mul_vec_q_n_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
                    uint      r3,
         threadgroup int8_t * shared_values,
-                   uint3 tgpig, uint tiisg, uint sgitg) {
+                     uint3   tgpig,
+                     uint    tiisg,
+                     uint    sgitg) {
     const int nb = ne00/QK4_0;
 
     const int r0 = tgpig.x;
@@ -1175,10 +1187,19 @@ void mul_vec_q_n_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+  //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+  //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
+    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
+
+    // pointers to src0 rows
+    device const block_q_type * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
+    }
 
     float yl[16]; // src1 vector cache
     float sumf[nr] = {0.f};
@@ -1190,19 +1211,22 @@ void mul_vec_q_n_f32_impl(
 
     // each thread in a SIMD group deals with half a block.
     for (int ib = ix; ib < nb; ib += nw/2) {
-        float sumy = 0;
+        float sumy[2] = { 0.f, 0.f };
+
+#pragma unroll
         for (int i = 0; i < 8; i += 2) {
-            sumy += yb[i] + yb[i+1];
-            yl[i+0] = yb[i+ 0];
-            yl[i+1] = yb[i+ 1]/256.f;
+            sumy[0]  += yb[i +  0] + yb[i +  1];
+            yl[i + 0] = yb[i +  0];
+            yl[i + 1] = yb[i +  1]/256.f;
 
-            sumy += yb[i+16] + yb[i+17];
-            yl[i+8] = yb[i+16]/16.f;
-            yl[i+9] = yb[i+17]/4096.f;
+            sumy[1]  += yb[i + 16] + yb[i + 17];
+            yl[i + 8] = yb[i + 16]/16.f;
+            yl[i + 9] = yb[i + 17]/4096.f;
         }
 
+#pragma unroll
         for (int row = 0; row < nr; row++) {
-            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
+            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }
 
         yb += QK4_0 * 16;
@@ -1226,12 +1250,14 @@ kernel void kernel_mul_mv_q4_0_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1239,7 +1265,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -1252,12 +1278,14 @@ kernel void kernel_mul_mv_q4_1_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1265,7 +1293,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -1278,12 +1306,14 @@ kernel void kernel_mul_mv_q5_0_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1291,7 +1321,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -1304,12 +1334,14 @@ kernel void kernel_mul_mv_q5_1_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1317,7 +1349,7 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 
@@ -1330,8 +1362,14 @@ void kernel_mul_mv_q8_0_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -1354,10 +1392,19 @@ void kernel_mul_mv_q8_0_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+  //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+  //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
+    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
+
+    // pointers to src0 rows
+    device const block_q8_0 * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+        ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
+    }
 
     float yl[NB_Q8_0];
     float sumf[nr]={0.f};
@@ -1374,12 +1421,12 @@ void kernel_mul_mv_q8_0_f32_impl(
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+            device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
             float sumq = 0.f;
             for (int iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
-            sumf[row] += sumq*x[ib+row*nb].d;
+            sumf[row] += sumq*ax[row][ib].d;
         }
 
         yb += NB_Q8_0 * nw;
@@ -1404,12 +1451,14 @@ kernel void kernel_mul_mv_q8_0_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1417,7 +1466,7 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 #define N_MV_T_T 4
@@ -1433,12 +1482,14 @@ void kernel_mul_mv_impl(
                   uint64_t   nb00,
                   uint64_t   nb01,
                   uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne11,
                    int64_t   ne12,
                   uint64_t   nb10,
                   uint64_t   nb11,
                   uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -1452,7 +1503,7 @@ void kernel_mul_mv_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
 
     device const T0 * x = (device const T0 *) (src0 + offset0);
 
@@ -1463,7 +1514,9 @@ void kernel_mul_mv_impl(
                 break;
             }
 
-            device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
+            const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            device const T1 * y = (device const T1 *) (src1 + offset1);
 
             float sumf = 0;
             for (int i = tiisg; i < ne00; i += 32) {
@@ -1483,7 +1536,9 @@ void kernel_mul_mv_impl(
                 break;
             }
 
-            device const T1  * y  = (device const T1  *) (src1 + r1*nb11 + im*nb12);
+            const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            device const T1  * y  = (device const T1  *) (src1 + offset1);
             device const T14 * y4 = (device const T14 *) y;
 
             float sumf = 0;
@@ -1511,12 +1566,14 @@ kernel void kernel_mul_mv(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1533,12 +1590,14 @@ kernel void kernel_mul_mv(
         nb00,
         nb01,
         nb02,
+        nb03,
         ne10,
         ne11,
         ne12,
         nb10,
         nb11,
         nb12,
+        nb13,
         ne0,
         ne1,
         r2,
@@ -1564,12 +1623,14 @@ kernel void kernel_mul_mv_1row(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1584,10 +1645,11 @@ kernel void kernel_mul_mv_1row(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
     device const T     * x = (device const T     *) (src0 + offset0);
-    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+    device const float * y = (device const float *) (src1 + offset1);
 
     float sumf = 0;
     if (ne00 < 128) {
@@ -1631,12 +1693,14 @@ kernel void kernel_mul_mv_l4(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1651,12 +1715,14 @@ kernel void kernel_mul_mv_l4(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
 
     device const T4 * x4 = (device const T4 *) (src0 + offset0);
 
     for (int r1 = 0; r1 < nrows; ++r1) {
-        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+        const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+        device const float4 * y4 = (device const float4 *) (src1 + offset1);
 
         float sumf = 0;
         for (int i = tiisg; i < ne00/4; i += 32) {
@@ -3416,8 +3482,14 @@ void kernel_mul_mv_q2_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3433,21 +3505,19 @@ void kernel_mul_mv_q2_K_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
+    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
-    const int step = sizeof(block_q2_K) * nb;
-
     const int ix = tiisg/8;  // 0...3
     const int it = tiisg%8;  // 0...7
     const int iq = it/4;     // 0 or 1
@@ -3492,9 +3562,9 @@ void kernel_mul_mv_q2_K_f32_impl(
                                  (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
                          dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
 
-            qs += step/2;
-            sc += step;
-            dh += step/2;
+            qs += nb01/2;
+            sc += nb01;
+            dh += nb01/2;
         }
 
         y4 += 4 * QK_K;
@@ -3519,12 +3589,14 @@ kernel void kernel_mul_mv_q2_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -3533,7 +3605,7 @@ kernel void kernel_mul_mv_q2_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q3_K_f32_impl(
@@ -3543,8 +3615,14 @@ void kernel_mul_mv_q3_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3565,10 +3643,11 @@ void kernel_mul_mv_q3_K_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
+    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
 
     float yl[32];
 
@@ -3608,8 +3687,6 @@ void kernel_mul_mv_q3_K_f32_impl(
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
 
-    const int step = sizeof(block_q3_K) * nb / 2;
-
     device const float * y1 = yy + ix*QK_K + y_offset;
 
     uint32_t scales32, aux32;
@@ -3619,7 +3696,6 @@ void kernel_mul_mv_q3_K_f32_impl(
     float sumf1[2] = {0.f};
     float sumf2[2] = {0.f};
     for (int i = ix; i < nb; i += 4) {
-
         for (int l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
@@ -3633,7 +3709,6 @@ void kernel_mul_mv_q3_K_f32_impl(
         device const half * dh = &x[i].d;
 
         for (int row = 0; row < 2; ++row) {
-
             const float d_all = (float)dh[0];
 
             scales16[0] = a[4];
@@ -3673,15 +3748,13 @@ void kernel_mul_mv_q3_K_f32_impl(
             sumf1[row] += d1 * (scales[1] - 32);
             sumf2[row] += d2 * (scales[3] - 32);
 
-            q  += step;
-            h  += step;
-            a  += step;
-            dh += step;
-
+            q  += nb01/2;
+            h  += nb01/2;
+            a  += nb01/2;
+            dh += nb01/2;
         }
 
         y1 += 4 * QK_K;
-
     }
 
     for (int row = 0; row < 2; ++row) {
@@ -3706,12 +3779,14 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -3720,7 +3795,7 @@ kernel void kernel_mul_mv_q3_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q4_K_f32_impl(
@@ -3730,8 +3805,14 @@ void kernel_mul_mv_q4_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3756,29 +3837,26 @@ void kernel_mul_mv_q4_K_f32_impl(
     const int im = tgpig.z;
     //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int first_row = r0 * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
+    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
 
     float yl[16];
     float yh[16];
     float sumf[N_DST]={0.f}, all_sum;
 
-    const int step = sizeof(block_q4_K) * nb / 2;
-
     device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
 
     uint16_t sc16[4];
     thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
     for (int ib = ix; ib < nb; ib += 4) {
-
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
         for (int i = 0; i < 8; ++i) {
             yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
@@ -3792,7 +3870,6 @@ void kernel_mul_mv_q4_K_f32_impl(
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
-
             sc16[0] = sc[0] & kmask1;
             sc16[1] = sc[2] & kmask1;
             sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -3821,9 +3898,9 @@ void kernel_mul_mv_q4_K_f32_impl(
                                  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
-            q1 += step;
-            sc += step;
-            dh += step;
+            q1 += nb01/2;
+            sc += nb01/2;
+            dh += nb01/2;
         }
 
         y4 += 4 * QK_K;
@@ -3848,12 +3925,14 @@ kernel void kernel_mul_mv_q4_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -3862,7 +3941,7 @@ kernel void kernel_mul_mv_q4_K_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q5_K_f32_impl(
@@ -3872,8 +3951,14 @@ void kernel_mul_mv_q5_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3894,15 +3979,14 @@ void kernel_mul_mv_q5_K_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
+    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
 
     float sumf[2]={0.f};
 
-    const int step = sizeof(block_q5_K) * nb;
-
     float yl[16], yh[16];
 
     const uint16_t kmask1 = 0x3f3f;
@@ -3930,7 +4014,6 @@ void kernel_mul_mv_q5_K_f32_impl(
     device const float * y1 = yy + ix*QK_K + y_offset;
 
     for (int i = ix; i < nb; i += 4) {
-
         device const uint8_t * q1 = x[i].qs + q_offset;
         device const uint8_t * qh = x[i].qh + l0;
         device const half * dh = &x[i].d;
@@ -3946,7 +4029,6 @@ void kernel_mul_mv_q5_K_f32_impl(
         }
 
         for (int row = 0; row < 2; ++row) {
-
             device const uint8_t * q2 = q1 + 64;
 
             sc16[0] = a[0] & kmask1;
@@ -3975,15 +4057,13 @@ void kernel_mul_mv_q5_K_f32_impl(
                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
-            q1 += step;
-            qh += step;
-            dh += step/2;
-            a  += step/2;
-
+            q1 += nb01;
+            qh += nb01;
+            dh += nb01/2;
+            a  += nb01/2;
         }
 
         y1 += 4 * QK_K;
-
     }
 
     for (int row = 0; row < 2; ++row) {
@@ -4005,12 +4085,14 @@ kernel void kernel_mul_mv_q5_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4019,7 +4101,7 @@ kernel void kernel_mul_mv_q5_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q6_K_f32_impl(
@@ -4029,8 +4111,14 @@ void kernel_mul_mv_q6_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4056,10 +4144,11 @@ void kernel_mul_mv_q6_K_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =  r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
+    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
 
     float sumf = 0;
 
@@ -4115,12 +4204,14 @@ kernel void kernel_mul_mv_q6_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4129,7 +4220,7 @@ kernel void kernel_mul_mv_q6_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
@@ -4141,8 +4232,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4158,15 +4255,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
-    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
+    device const float         * y = (device const float         *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4219,8 +4316,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
             }
             sumf[row] += d * sum;
 
-            dh += nb*sizeof(block_iq2_xxs)/2;
-            q2 += nb*sizeof(block_iq2_xxs)/2;
+            dh += nb01/2;
+            q2 += nb01/2;
         }
 
         y4 += 32 * 32;
@@ -4245,12 +4342,14 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4260,7 +4359,7 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq2_xs_f32_impl(
@@ -4270,8 +4369,14 @@ void kernel_mul_mv_iq2_xs_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4287,15 +4392,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
+    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4357,9 +4462,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
             }
             sumf[row] += d1 * sum1 + d2 * sum2;
 
-            dh += nb*sizeof(block_iq2_xs)/2;
-            q2 += nb*sizeof(block_iq2_xs)/2;
-            sc += nb*sizeof(block_iq2_xs);
+            dh += nb01/2;
+            q2 += nb01/2;
+            sc += nb01;
         }
 
         y4 += 32 * 32;
@@ -4384,12 +4489,14 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4399,7 +4506,7 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq3_xxs_f32_impl(
@@ -4409,8 +4516,14 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4426,15 +4539,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
-    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
+    device const float         * y = (device const float         *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4489,9 +4602,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
             }
             sumf[row] += d * (sum[0] + sum[1]);
 
-            dh  += nb*sizeof(block_iq3_xxs)/2;
-            q3  += nb*sizeof(block_iq3_xxs);
-            gas += nb*sizeof(block_iq3_xxs)/2;
+            dh  += nb01/2;
+            q3  += nb01;
+            gas += nb01/2;
         }
 
         y4 += 32 * 32;
@@ -4516,12 +4629,14 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4531,7 +4646,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq3_s_f32_impl(
@@ -4541,8 +4656,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4558,15 +4679,15 @@ void kernel_mul_mv_iq3_s_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
-    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
+    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4619,11 +4740,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
             }
             sumf[row] += d * (sum[0] + sum[1]);
 
-            dh  += nb*sizeof(block_iq3_s)/2;
-            qs  += nb*sizeof(block_iq3_s);
-            qh  += nb*sizeof(block_iq3_s);
-            sc  += nb*sizeof(block_iq3_s);
-            signs += nb*sizeof(block_iq3_s);
+            dh    += nb01/2;
+            qs    += nb01;
+            qh    += nb01;
+            sc    += nb01;
+            signs += nb01;
         }
 
         y4 += 32 * 32;
@@ -4648,12 +4769,14 @@ kernel void kernel_mul_mv_iq3_s_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4663,7 +4786,7 @@ kernel void kernel_mul_mv_iq3_s_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq2_s_f32_impl(
@@ -4673,8 +4796,14 @@ void kernel_mul_mv_iq2_s_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4690,15 +4819,15 @@ void kernel_mul_mv_iq2_s_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
-    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
+    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4752,11 +4881,11 @@ void kernel_mul_mv_iq2_s_f32_impl(
             }
             sumf[row] += d1 * sum[0] + d2 * sum[1];
 
-            dh  += nb*sizeof(block_iq2_s)/2;
-            qs  += nb*sizeof(block_iq2_s);
-            qh  += nb*sizeof(block_iq2_s);
-            sc  += nb*sizeof(block_iq2_s);
-            signs += nb*sizeof(block_iq2_s);
+            dh    += nb01/2;
+            qs    += nb01;
+            qh    += nb01;
+            sc    += nb01;
+            signs += nb01;
         }
 
         y4 += 32 * 32;
@@ -4781,12 +4910,14 @@ kernel void kernel_mul_mv_iq2_s_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4796,7 +4927,7 @@ kernel void kernel_mul_mv_iq2_s_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq1_s_f32_impl(
@@ -4806,8 +4937,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4823,14 +4960,15 @@ void kernel_mul_mv_iq1_s_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-    device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
-    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
+    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4873,9 +5011,9 @@ void kernel_mul_mv_iq1_s_f32_impl(
             }
             sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
 
-            dh += nb*sizeof(block_iq1_s)/2;
-            qs += nb*sizeof(block_iq1_s);
-            qh += nb*sizeof(block_iq1_s)/2;
+            dh += nb01/2;
+            qs += nb01;
+            qh += nb01/2;
         }
 
         y4 += 32 * 32;
@@ -4896,8 +5034,14 @@ void kernel_mul_mv_iq1_m_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4913,14 +5057,15 @@ void kernel_mul_mv_iq1_m_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-    device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
-    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
+    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4972,9 +5117,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
             sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                              (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
 
-            sc += nb*sizeof(block_iq1_m)/2;
-            qs += nb*sizeof(block_iq1_m);
-            qh += nb*sizeof(block_iq1_m);
+            sc += nb01/2;
+            qs += nb01;
+            qh += nb01;
         }
 
         y4 += 32 * 32;
@@ -4995,8 +5140,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -5012,14 +5163,15 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
     const int first_row = (r0 * 2 + sgitg) * 2;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-    device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
+    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
 
     const int ix = tiisg/2;  // 0...15
     const int it = tiisg%2;  // 0 or 1
@@ -5089,8 +5241,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -5106,14 +5264,15 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
     const int first_row = (r0 * 2 + sgitg) * 2;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-    device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
+    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
 
     const int ix = tiisg/16;  // 0 or 1
     const int it = tiisg%16;  // 0...15
@@ -5188,12 +5347,14 @@ kernel void kernel_mul_mv_iq1_s_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -5202,7 +5363,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -5216,12 +5377,14 @@ kernel void kernel_mul_mv_iq1_m_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -5230,7 +5393,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -5244,12 +5407,14 @@ kernel void kernel_mul_mv_iq4_nl_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -5259,7 +5424,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq4_xs_f32")]]
@@ -5273,12 +5438,14 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -5288,7 +5455,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 //============================= templates and their specializations =============================
@@ -5833,10 +6000,12 @@ kernel void kernel_mul_mm(device const  uchar * src0,
                           constant    int64_t & ne02,
                           constant   uint64_t & nb01,
                           constant   uint64_t & nb02,
+                          constant   uint64_t & nb03,
                           constant    int64_t & ne12,
                           constant   uint64_t & nb10,
                           constant   uint64_t & nb11,
                           constant   uint64_t & nb12,
+                          constant   uint64_t & nb13,
                           constant    int64_t & ne0,
                           constant    int64_t & ne1,
                           constant       uint & r2,
@@ -5873,12 +6042,13 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
     ushort offset1 = il/nl;
 
     device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
     device const float   * y = (device const float   *)(src1
-        + nb12 * im
+        + nb13 * i13
+        + nb12 * i12
         + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
@@ -6257,12 +6427,14 @@ typedef void (kernel_mul_mv_impl_t)(
                   uint64_t   nb00,
                   uint64_t   nb01,
                   uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne11,
                    int64_t   ne12,
                   uint64_t   nb10,
                   uint64_t   nb11,
                   uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -6277,8 +6449,14 @@ typedef void (kernel_mul_mv2_impl_t)(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -6299,6 +6477,7 @@ void mmv_fn(
                     uint64_t   nb00,
                     uint64_t   nb01,
                     uint64_t   nb02,
+                    uint64_t   nb03,
                      int64_t   ne10,
                      int64_t   ne11,
                      int64_t   ne12,
@@ -6306,6 +6485,7 @@ void mmv_fn(
                     uint64_t   nb10,
                     uint64_t   nb11,
                     uint64_t   nb12,
+                    uint64_t   nb13,
                      int64_t   ne0,
                      int64_t   ne1,
                     uint64_t   nb1,
@@ -6316,7 +6496,7 @@ void mmv_fn(
         uint                   tiitg,
         uint                   tiisg,
         uint                   sgitg) {
-    impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
+    impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
 }
 
 template<kernel_mul_mv2_impl_t impl_fn>
@@ -6330,6 +6510,7 @@ void mmv_fn(
                     uint64_t   nb00,
                     uint64_t   nb01,
                     uint64_t   nb02,
+                    uint64_t   nb03,
                      int64_t   ne10,
                      int64_t   ne11,
                      int64_t   ne12,
@@ -6337,6 +6518,7 @@ void mmv_fn(
                     uint64_t   nb10,
                     uint64_t   nb11,
                     uint64_t   nb12,
+                    uint64_t   nb13,
                      int64_t   ne0,
                      int64_t   ne1,
                     uint64_t   nb1,
@@ -6347,7 +6529,7 @@ void mmv_fn(
         uint                   tiitg,
         uint                   tiisg,
         uint                   sgitg) {
-    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
+    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
 }
 
 typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
@@ -6396,8 +6578,8 @@ kernel void kernel_mul_mv_id(
     const int64_t i2 = i12;
 
     device const char * src0_cur = src0s + i02*nb02;
-    device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
-    device      float * dst_cur  = dst + i1*ne0 + i2*ne1*ne0;
+    device const char * src1_cur = src1  + i11*nb11 + i12*nb12;
+    device      float *  dst_cur = dst   + i1*ne0   + i2*ne1*ne0;
 
     impl_fn(
         /* src0 */ src0_cur,
@@ -6405,19 +6587,21 @@ kernel void kernel_mul_mv_id(
         /* dst  */ dst_cur,
         /* ne00 */ ne00,
         /* ne01 */ ne01,
-        /* ne02 */ 1,//ne02,
+        /* ne02 */ 1, // ne02,
         /* nb00 */ nb00,
         /* nb01 */ nb01,
         /* nb02 */ nb02,
+        /* nb03 */ nb02, // ne02 == 1
         /* ne10 */ ne10,
-        /* ne11 */ 1,//ne11,
-        /* ne12 */ 1,//ne12,
-        /* ne13 */ 1,//ne13,
+        /* ne11 */ 1, // ne11,
+        /* ne12 */ 1, // ne12,
+        /* ne13 */ 1, // ne13,
         /* nb10 */ nb10,
         /* nb11 */ nb11,
         /* nb12 */ nb12,
+        /* ne13 */ nb12, // ne12 == 1
         /* ne0  */ ne0,
-        /* ne1  */ 1,//ne1,
+        /* ne1  */ 1, // ne1,
         /* nb1  */ nb1,
         /* r2   */ 1,
         /* r3   */ 1,