]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : fix mul-mm condition + fix mul-mv permuted kernels (llama/16494)
authorGeorgi Gerganov <redacted>
Sat, 11 Oct 2025 13:54:10 +0000 (16:54 +0300)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 08:16:23 +0000 (11:16 +0300)
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 1137e210773af54dd1969c5a1249773745b18484..5f9370449bb2deb48e7864c6f583152f4d77c0c4 100644 (file)
@@ -1546,9 +1546,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
         !ggml_is_transposed(op->src[1]) &&
         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-        props_dev->has_simdgroup_mm && ne00 >= 64 &&
-        (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
-        //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+        props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
+        //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
         // some Metal matrix data types require aligned pointers
         // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
index 45d91def88bf211ab0381117ff2b672480391233..ddc285042d284f081dbb357d719297f3419641b1 100644 (file)
@@ -7487,7 +7487,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
     kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
 void kernel_mul_mv_iq4_nl_f32_impl(
         args_t args,
         device const char * src0,
@@ -7500,13 +7500,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     const short NSG = FC_mul_mv_nsg;
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
-    const int nb = args.ne00/QK4_NL;
 
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * NSG + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * NR0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7517,6 +7516,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
+    const int nb   = args.ne00/QK4_NL;
+    const int ns01 = args.nb01/args.nb00;
+
     const short ix = tiisg/2;  // 0...15
     const short it = tiisg%2;  // 0 or 1
 
@@ -7524,24 +7526,25 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[nr0]={0.f};
+    float sumf[NR0]={0.f};
 
-    device const float * yb = y + ix * QK4_NL + it * 8;
+    device const float * yb = y + ix*QK4_NL + it*8;
 
     uint32_t aux32[2];
     thread const uint8_t * q8 = (thread const uint8_t *)aux32;
 
     float4 qf1, qf2;
 
-    for (int ib = ix; ib < nb; ib += 16) {
+    // [TAG_MUL_MV_WEIRD]
+    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
         device const float4 * y4 = (device const float4 *)yb;
         yl[0] = y4[0];
         yl[1] = y4[4];
         yl[2] = y4[1];
         yl[3] = y4[5];
 
-        for (short row = 0; row < nr0; row++) {
-            device const block_iq4_nl & xb = x[row*nb + ib];
+        for (short row = 0; row < NR0; row++) {
+            device const block_iq4_nl & xb = x[row*ns01 + ib];
             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
 
             float4 acc1 = {0.f}, acc2 = {0.f};
@@ -7572,7 +7575,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
         float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = sum_all;
@@ -7594,7 +7597,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
     kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
 void kernel_mul_mv_iq4_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -7607,12 +7610,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     const short NSG = FC_mul_mv_nsg;
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
-    const int nb = args.ne00/QK_K;
 
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * NSG + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * NR0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7623,6 +7625,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
+    const int nb   = args.ne00/QK_K;
+    const int ns01 = args.nb01/args.nb00;
+
     const short ix = tiisg/16;  // 0 or 1
     const short it = tiisg%16;  // 0...15
     const short ib = it/2;
@@ -7632,7 +7637,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[nr0]={0.f};
+    float sumf[NR0]={0.f};
 
     device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
 
@@ -7641,15 +7646,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     float4 qf1, qf2;
 
-    for (int ibl = ix; ibl < nb; ibl += 2) {
+    // [TAG_MUL_MV_WEIRD]
+    for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
         device const float4 * y4 = (device const float4 *)yb;
         yl[0] = y4[0];
         yl[1] = y4[4];
         yl[2] = y4[1];
         yl[3] = y4[5];
 
-        for (short row = 0; row < nr0; ++row) {
-            device const block_iq4_xs & xb = x[row*nb + ibl];
+        for (short row = 0; row < NR0; ++row) {
+            device const block_iq4_xs & xb = x[row*ns01 + ibl];
             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
 
             float4 acc1 = {0.f}, acc2 = {0.f};
@@ -7679,7 +7685,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
         float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = sum_all;
@@ -7701,7 +7707,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
     kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<int nr0, typename args_t>
+template<int NR0, typename args_t>
 void kernel_mul_mv_mxfp4_f32_impl(
         args_t args,
         device const char * src0,
@@ -7714,13 +7720,12 @@ void kernel_mul_mv_mxfp4_f32_impl(
     const short NSG = FC_mul_mv_nsg;
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
-    const int nb = args.ne00/QK_MXFP4;
 
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * NSG + sgitg) * nr0;
+    const int first_row = (r0 * NSG + sgitg) * NR0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -7731,6 +7736,9 @@ void kernel_mul_mv_mxfp4_f32_impl(
     device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
     device const float       * y = (device const float       *) (src1 + offset1);
 
+    const int nb   = args.ne00/QK_MXFP4;
+    const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors
+
     const short ix = tiisg/2;  // 0...15
     const short it = tiisg%2;  // 0 or 1
 
@@ -7738,20 +7746,22 @@ void kernel_mul_mv_mxfp4_f32_impl(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[nr0]={0.f};
+    float sumf[NR0]={0.f};
 
-    device const float * yb = y + ix * QK_MXFP4 + it * 8;
+    device const float * yb = y + ix*QK_MXFP4 + it*8;
+
+    // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
+    //       no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
+    for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
+        device const float4 * y4 = (device const float4 *) yb;
 
-    for (int ib = ix; ib < nb; ib += 16) {
-        device const float4 * y4 = (device const float4 *)yb;
         yl[0] = y4[0];
         yl[1] = y4[4];
         yl[2] = y4[1];
         yl[3] = y4[5];
 
-#pragma unroll(nr0)
-        for (short row = 0; row < nr0; row++) {
-            device const block_mxfp4 & xb = x[row*nb + ib];
+        FOR_UNROLL (short row = 0; row < NR0; row++) {
+            device const block_mxfp4 & xb = x[row*ns01 + ib];
             device const uint8_t     * q2 = (device const uint8_t *)(xb.qs + 8*it);
 
             float4 acc1 = yl[0]*float4(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);
@@ -7769,7 +7779,7 @@ void kernel_mul_mv_mxfp4_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
         float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = sum_all;