]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : optimize ggml_mul_mat_id (faster Mixtral PP) (llama/4725)
authorGeorgi Gerganov <redacted>
Tue, 2 Jan 2024 19:07:47 +0000 (21:07 +0200)
committerGeorgi Gerganov <redacted>
Wed, 3 Jan 2024 12:43:51 +0000 (14:43 +0200)
* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (llama/4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

* metal : optimizing ggml_mul_mat_id (wip)

* metal : minor fix

* metal : opt mul_mm_id

ggml-metal.m
ggml-metal.metal

index cd9d00456f7d4e78609d7d6db50a52a785e156ba..7a369b55e36282e2de083db6b860b9f38d4e059d 100644 (file)
@@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
                                         }
                                 };
 
+                                if (ggml_is_quantized(src0t)) {
+                                    GGML_ASSERT(ne00 >= nth0*nth1);
+                                }
+
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
                             // TODO: make this more general
                             GGML_ASSERT(n_as <= 8);
 
+                            // max size of the src1ids array in the kernel stack
+                            GGML_ASSERT(ne11 <= 512);
+
                             struct ggml_tensor * src2 = gf->nodes[i]->src[2];
 
                             const int64_t  ne20 = src2 ? src2->ne[0] : 0;
@@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(!ggml_is_transposed(src2));
                             GGML_ASSERT(!ggml_is_transposed(src1));
 
-                            GGML_ASSERT(ne20 % 32 == 0);
-                            // !!!!!!!!! TODO: this assert is probably required but not sure!
-                            //GGML_ASSERT(ne20 >= 64);
                             GGML_ASSERT(src1t == GGML_TYPE_F32);
 
                             const uint r2 = ne12/ne22;
@@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute(
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
-                            int ne11_mm_min = 1;
+                            int ne11_mm_min = n_as;
 
                             const int idx = ((int32_t *) dst->op_params)[0];
 
                             // batch size
                             GGML_ASSERT(ne01 == ne11);
 
-                            const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
-
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                             // !!!
                             // TODO: for now, always use mat-vec kernels until we figure out how to improve the
                             //       indirect matrix multiplication
                             // !!!
-                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                                ne20 % 32 == 0 && ne20 >= 64 &&
+                                ne11 > ne11_mm_min) {
                                 switch (src2->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
@@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:11];
                                 [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:12];
                                 [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                                [encoder setBytes:&_ne1    length:sizeof(_ne1) atIndex:14];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
                                 [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:15];
                                 [encoder setBytes:&r2      length:sizeof(r2)   atIndex:16];
                                 [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17];
@@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(
 
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
 
-                                // TODO: processing one row at a time (ne11 -> 1) is not efficient
-                                [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
@@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
                                         } break;
                                     default:
                                         {
-                                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
                                             GGML_ASSERT(false && "not implemented");
                                         }
                                 };
 
+                                if (ggml_is_quantized(src2t)) {
+                                    GGML_ASSERT(ne20 >= nth0*nth1);
+                                }
+
+                                const int64_t _ne1 = 1; // kernels needs a reference in constant memory
+
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
index 1d5b8f6f4131c35142814029341d28a9510b38dd..9aa7b502a9ea094e386be638cc6827d9188600a6 100644 (file)
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 #define N_SIMDGROUP 2  // number of SIMD groups in a thread group
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
-//      giard against the number of rows not being divisible by
+//      guard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32_impl(
@@ -3973,6 +3973,131 @@ void kernel_mul_mm_impl(device const  uchar * src0,
     }
 }
 
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+        device const  uchar * src0,
+        device const  uchar * src1,
+        thread        short * src1ids,
+        device        float * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne02,
+        constant   uint64_t & nb01,
+        constant   uint64_t & nb02,
+        constant    int64_t & ne12,
+        constant   uint64_t & nb10,
+        constant   uint64_t & nb11,
+        constant   uint64_t & nb12,
+        constant    int64_t & ne0,
+                    int64_t   ne1,
+        constant       uint & r2,
+        constant       uint & r3,
+        threadgroup   uchar * shared_memory,
+        uint3                 tgpig[[threadgroup_position_in_grid]],
+        uint                  tiitg[[thread_index_in_threadgroup]],
+        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+
+    if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+        if (sgitg == 0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const  uchar * src0,
                           device const  uchar * src1,
@@ -4019,7 +4144,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
 kernel void kernel_mul_mm_id(
         device const   uchar * ids,
         device const   uchar * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -4048,18 +4173,28 @@ kernel void kernel_mul_mm_id(
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
-    const int64_t bid = tgpig.z/(ne12*ne13);
+    // expert id
+    const int32_t id = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+    // row indices of src1 for expert id
+    int64_t _ne1 = 0;
+    short src1ids[512];
 
-    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[id],
-        src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+    for (int64_t i1 = 0; i1 < ne1; i1++) {
+        if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
+            src1ids[_ne1++] = i1;
+        }
+    }
+
+    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        src0s[id],
+        src1,
+        src1ids,
+        dst,
         ne00,
         ne02,
         nb01,
@@ -4069,7 +4204,7 @@ kernel void kernel_mul_mm_id(
         nb11,
         nb12,
         ne0,
-        ne1,
+        _ne1,
         r2,
         r3,
         shared_memory,
@@ -4158,7 +4293,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 typedef void (mat_mm_id_t)(
         device const   uchar * ids,
         device const   uchar * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -4207,7 +4342,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
 kernel void kernel_mul_mv_id_f32_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4251,7 +4386,7 @@ kernel void kernel_mul_mv_id_f32_f32(
     kernel_mul_mv_f32_f32_impl(
         src0[id],
         src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+        dst  + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4276,7 +4411,7 @@ kernel void kernel_mul_mv_id_f32_f32(
 kernel void kernel_mul_mv_id_f16_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4320,7 +4455,7 @@ kernel void kernel_mul_mv_id_f16_f32(
     kernel_mul_mv_f16_f32_impl(
         src0[id],
         src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+        dst  + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4345,7 +4480,7 @@ kernel void kernel_mul_mv_id_f16_f32(
 kernel void kernel_mul_mv_id_q8_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4389,7 +4524,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4408,7 +4543,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
 kernel void kernel_mul_mv_id_q4_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4452,7 +4587,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4471,7 +4606,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
 kernel void kernel_mul_mv_id_q4_1_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4515,7 +4650,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4534,7 +4669,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
 kernel void kernel_mul_mv_id_q5_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4578,7 +4713,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4597,7 +4732,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
 kernel void kernel_mul_mv_id_q5_1_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4641,7 +4776,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4660,7 +4795,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
 kernel void kernel_mul_mv_id_q2_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4704,7 +4839,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
     kernel_mul_mv_q2_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4723,7 +4858,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
 kernel void kernel_mul_mv_id_q3_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4767,7 +4902,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
     kernel_mul_mv_q3_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4786,7 +4921,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
 kernel void kernel_mul_mv_id_q4_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4830,7 +4965,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
     kernel_mul_mv_q4_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4849,7 +4984,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
 kernel void kernel_mul_mv_id_q5_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4893,7 +5028,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
     kernel_mul_mv_q5_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4912,7 +5047,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
 kernel void kernel_mul_mv_id_q6_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4956,7 +5091,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
     kernel_mul_mv_q6_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,