]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : enable shader debugging (cmake option) (llama/4705)
authorGeorgi Gerganov <redacted>
Tue, 2 Jan 2024 08:57:44 +0000 (10:57 +0200)
committerGeorgi Gerganov <redacted>
Wed, 3 Jan 2024 12:43:51 +0000 (14:43 +0200)
* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (llama/4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

ggml-ci

ggml-metal.m
ggml-metal.metal

index 51a72ae335745008aeb1ead8da2221eacb352a3f..cd9d00456f7d4e78609d7d6db50a52a785e156ba 100644 (file)
@@ -257,13 +257,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
 #endif
         NSError * error = nil;
-        NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
+        NSString * libPath = [bundle pathForResource:@"ggml" ofType:@"metallib"];
         if (libPath != nil) {
+            // pre-compiled library found
             NSURL * libURL = [NSURL fileURLWithPath:libPath];
             GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
             ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
         } else {
-            GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+            GGML_METAL_LOG_INFO("%s: ggml.metallib not found, loading from source\n", __func__);
 
             NSString * sourcePath;
             NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
@@ -291,6 +292,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             options = [MTLCompileOptions new];
             options.preprocessorMacros = @{ @"QK_K" : @(64) };
 #endif
+            // try to disable fast-math
+            // NOTE: this seems to have no effect whatsoever
+            //       instead, in order to disable fast-math, we have to build ggml.metallib from the command line
+            //       using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
+            //       and go through the "pre-compiled library found" path above
+            //[options setFastMathEnabled:false];
+
             ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
         }
 
@@ -1230,7 +1238,7 @@ void ggml_metal_graph_compute(
                                 // not sure how to avoid this
                                 // TODO: make a simpler cpy_bytes kernel
 
-                                const int nth = MIN(1024, ne00);
+                                const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
 
                                 [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
                                 [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
@@ -1285,7 +1293,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
                             [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
 
-                            const int nth = MIN(1024, ne0);
+                            const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1785,8 +1793,9 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17];
                                 [encoder setBytes:&idx     length:sizeof(idx)  atIndex:18];
                                 // TODO: how to make this an array? read Metal docs
-                                for (int j = 0; j < n_as; ++j) {
-                                    struct ggml_tensor * src_cur = dst->src[2 + j];
+                                for (int j = 0; j < 8; ++j) {
+                                    // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
+                                    struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
 
                                     size_t offs_src_cur = 0;
                                     id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
@@ -1909,8 +1918,9 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&r3   length:sizeof(r3)   atIndex:21];
                                 [encoder setBytes:&idx  length:sizeof(idx)  atIndex:22];
                                 // TODO: how to make this an array? read Metal docs
-                                for (int j = 0; j < n_as; ++j) {
-                                    struct ggml_tensor * src_cur = dst->src[2 + j];
+                                for (int j = 0; j < 8; ++j) {
+                                    // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
+                                    struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
 
                                     size_t offs_src_cur = 0;
                                     id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
@@ -2229,7 +2239,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
                             [encoder setBytes:&sf   length:sizeof(sf)   atIndex:18];
 
-                            const int nth = MIN(1024, ne0);
+                            const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
index d5b54e112ea37ebcfe61d1ed38ef659b6f240f83..1d5b8f6f4131c35142814029341d28a9510b38dd 100644 (file)
@@ -59,26 +59,26 @@ kernel void kernel_add(
         constant  int64_t & ne01,
         constant  int64_t & ne02,
         constant  int64_t & ne03,
-        constant  int64_t & nb00,
-        constant  int64_t & nb01,
-        constant  int64_t & nb02,
-        constant  int64_t & nb03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
         constant  int64_t & ne10,
         constant  int64_t & ne11,
         constant  int64_t & ne12,
         constant  int64_t & ne13,
-        constant  int64_t & nb10,
-        constant  int64_t & nb11,
-        constant  int64_t & nb12,
-        constant  int64_t & nb13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
         constant  int64_t & ne0,
         constant  int64_t & ne1,
         constant  int64_t & ne2,
         constant  int64_t & ne3,
-        constant  int64_t & nb0,
-        constant  int64_t & nb1,
-        constant  int64_t & nb2,
-        constant  int64_t & nb3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
         constant  int64_t & offs,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -109,26 +109,26 @@ kernel void kernel_mul(
         constant  int64_t & ne01,
         constant  int64_t & ne02,
         constant  int64_t & ne03,
-        constant  int64_t & nb00,
-        constant  int64_t & nb01,
-        constant  int64_t & nb02,
-        constant  int64_t & nb03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
         constant  int64_t & ne10,
         constant  int64_t & ne11,
         constant  int64_t & ne12,
         constant  int64_t & ne13,
-        constant  int64_t & nb10,
-        constant  int64_t & nb11,
-        constant  int64_t & nb12,
-        constant  int64_t & nb13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
         constant  int64_t & ne0,
         constant  int64_t & ne1,
         constant  int64_t & ne2,
         constant  int64_t & ne3,
-        constant  int64_t & nb0,
-        constant  int64_t & nb1,
-        constant  int64_t & nb2,
-        constant  int64_t & nb3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -158,26 +158,26 @@ kernel void kernel_div(
         constant  int64_t & ne01,
         constant  int64_t & ne02,
         constant  int64_t & ne03,
-        constant  int64_t & nb00,
-        constant  int64_t & nb01,
-        constant  int64_t & nb02,
-        constant  int64_t & nb03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
         constant  int64_t & ne10,
         constant  int64_t & ne11,
         constant  int64_t & ne12,
         constant  int64_t & ne13,
-        constant  int64_t & nb10,
-        constant  int64_t & nb11,
-        constant  int64_t & nb12,
-        constant  int64_t & nb13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
         constant  int64_t & ne0,
         constant  int64_t & ne1,
         constant  int64_t & ne2,
         constant  int64_t & ne3,
-        constant  int64_t & nb0,
-        constant  int64_t & nb1,
-        constant  int64_t & nb2,
-        constant  int64_t & nb3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -205,7 +205,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb [[buffer(28)]],
+        constant   uint64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -214,7 +214,7 @@ kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb  [[buffer(28)]],
+        constant   uint64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
@@ -223,7 +223,7 @@ kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb  [[buffer(28)]],
+        constant   uint64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
@@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
         constant  int64_t & ne01,
         constant  int64_t & ne02,
         constant  int64_t & ne03,
-        constant  int64_t & nb00,
-        constant  int64_t & nb01,
-        constant  int64_t & nb02,
-        constant  int64_t & nb03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
         constant  int64_t & ne10,
         constant  int64_t & ne11,
         constant  int64_t & ne12,
         constant  int64_t & ne13,
-        constant  int64_t & nb10,
-        constant  int64_t & nb11,
-        constant  int64_t & nb12,
-        constant  int64_t & nb13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
         constant  int64_t & ne0,
         constant  int64_t & ne1,
         constant  int64_t & ne2,
         constant  int64_t & ne3,
-        constant  int64_t & nb0,
-        constant  int64_t & nb1,
-        constant  int64_t & nb2,
-        constant  int64_t & nb3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
         uint3 tpig[[thread_position_in_grid]]) {
     int64_t i3 = tpig.z;
     int64_t i2 = tpig.y;
@@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
         constant   int64_t & ne10,
+        constant   int64_t & ne11,
         constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
     kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
     kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
     kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        constant   uint    & r2   [[buffer(17)]],
-        constant   uint    & r3   [[buffer(18)]],
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]]) {
 
@@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32(
     const int64_t i3 = n / (ne2*ne1*ne0);
     const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+  //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
     const int64_t k = i3*ne3 + i2;
 
     float m_k;
@@ -2410,22 +2446,6 @@ typedef struct {
 } block_q6_K;
 // 210 bytes / block
 
-static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
-    uchar4 r;
-    if (j < 4) {
-        r[0] = q[j+0] & 63;
-        r[2] = q[j+1] & 63;
-        r[1] = q[j+4] & 63;
-        r[3] = q[j+5] & 63;
-    } else {
-        r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
-        r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
-        r[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
-        r[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
-    }
-    return r;
-}
-
 //====================================== dot products =========================
 
 void kernel_mul_mv_q2_K_f32_impl(
@@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl(
         constant   uint    & r2,
         constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int ix = tiisg/4;  // 0...7
     const int it = tiisg%4;  // 0...3
@@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
@@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0 [[buffer(15)]],
-        constant   int64_t & ne1 [[buffer(16)]],
-        constant   uint    & r2  [[buffer(17)]],
-        constant   uint    & r3  [[buffer(18)]],
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
     device const int8_t * qs = ((device const int8_t *)xb->qs);
     const half d = xb->d;
 
-    for (int i=0;i<16;i++) {
+    for (int i = 0; i < 16; i++) {
         reg[i/4][i%4] = (qs[i + 16*il] * d);
     }
 }
@@ -3792,12 +3847,12 @@ void kernel_mul_mm_impl(device const  uchar * src0,
                         device        float * dst,
                         constant    int64_t & ne00,
                         constant    int64_t & ne02,
-                        constant    int64_t & nb01,
-                        constant    int64_t & nb02,
+                        constant   uint64_t & nb01,
+                        constant   uint64_t & nb02,
                         constant    int64_t & ne12,
-                        constant    int64_t & nb10,
-                        constant    int64_t & nb11,
-                        constant    int64_t & nb12,
+                        constant   uint64_t & nb10,
+                        constant   uint64_t & nb11,
+                        constant   uint64_t & nb12,
                         constant    int64_t & ne0,
                         constant    int64_t & ne1,
                         constant       uint & r2,
@@ -3924,12 +3979,12 @@ kernel void kernel_mul_mm(device const  uchar * src0,
                           device        float * dst,
                           constant    int64_t & ne00,
                           constant    int64_t & ne02,
-                          constant    int64_t & nb01,
-                          constant    int64_t & nb02,
+                          constant   uint64_t & nb01,
+                          constant   uint64_t & nb02,
                           constant    int64_t & ne12,
-                          constant    int64_t & nb10,
-                          constant    int64_t & nb11,
-                          constant    int64_t & nb12,
+                          constant   uint64_t & nb10,
+                          constant   uint64_t & nb11,
+                          constant   uint64_t & nb12,
                           constant    int64_t & ne0,
                           constant    int64_t & ne1,
                           constant       uint & r2,
@@ -3965,19 +4020,19 @@ kernel void kernel_mul_mm_id(
         device const   uchar * ids,
         device const   uchar * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
-        constant     int64_t & nb01,
-        constant     int64_t & nb02,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
         constant     int64_t & ne12,
         constant     int64_t & ne13,
-        constant     int64_t & nb10,
-        constant     int64_t & nb11,
-        constant     int64_t & nb12,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4070,12 +4125,12 @@ typedef void (mat_mm_t)(
         device        float * dst,
         constant    int64_t & ne00,
         constant    int64_t & ne02,
-        constant    int64_t & nb01,
-        constant    int64_t & nb02,
+        constant   uint64_t & nb01,
+        constant   uint64_t & nb02,
         constant    int64_t & ne12,
-        constant    int64_t & nb10,
-        constant    int64_t & nb11,
-        constant    int64_t & nb12,
+        constant   uint64_t & nb10,
+        constant   uint64_t & nb11,
+        constant   uint64_t & nb12,
         constant    int64_t & ne0,
         constant    int64_t & ne1,
         constant       uint & r2,
@@ -4104,19 +4159,19 @@ typedef void (mat_mm_id_t)(
         device const   uchar * ids,
         device const   uchar * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
-        constant     int64_t & nb01,
-        constant     int64_t & nb02,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
         constant     int64_t & ne12,
         constant     int64_t & ne13,
-        constant     int64_t & nb10,
-        constant     int64_t & nb11,
-        constant     int64_t & nb12,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4153,7 +4208,7 @@ kernel void kernel_mul_mv_id_f32_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4169,7 +4224,7 @@ kernel void kernel_mul_mv_id_f32_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4222,7 +4277,7 @@ kernel void kernel_mul_mv_id_f16_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4238,7 +4293,7 @@ kernel void kernel_mul_mv_id_f16_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4291,7 +4346,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4307,7 +4362,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4354,7 +4409,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4370,7 +4425,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4417,7 +4472,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4433,7 +4488,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4480,7 +4535,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4496,7 +4551,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4543,7 +4598,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4559,7 +4614,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4606,7 +4661,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4622,7 +4677,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4669,7 +4724,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4685,7 +4740,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4732,7 +4787,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4748,7 +4803,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4795,7 +4850,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4811,7 +4866,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
@@ -4858,7 +4913,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
         device const    char * ids,
         device const    char * src1,
         device         uchar * dst,
-        constant     int64_t & nbi1,
+        constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
         constant     int64_t & ne02,
@@ -4874,7 +4929,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
         constant    uint64_t & nb12,
         constant     int64_t & ne0,
         constant     int64_t & ne1,
-        constant     int64_t & nb1,
+        constant    uint64_t & nb1,
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,