]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : improve `MUL_MAT_ID` (#15541)
authorGeorgi Gerganov <redacted>
Tue, 26 Aug 2025 09:46:15 +0000 (12:46 +0300)
committerGitHub <redacted>
Tue, 26 Aug 2025 09:46:15 +0000 (12:46 +0300)
* metal : mul_mm_id remove hdst

* metal : remove mul_mm_id hsrc1

* metal : mul_mm_id simplify + add test

* metal : opt mul_mm_id map0

* metal : optimize mul_mm_id id gathering

* metal : mul/div opt

* metal : optimize mul_mm_id_map0

ggml-ci

ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index fc6526d6d5dc6eec1ea0b31438650fa62d17fc70..82c1ac1da0b13b490027fb0000cc92830ddcb314 100644 (file)
@@ -320,40 +320,31 @@ typedef struct {
 } ggml_metal_kargs_mul_mv_ext;
 
 typedef struct {
+    int32_t  ne02;
     int32_t  ne10;
     int32_t  ne11;  // n_expert_used (bcast)
     uint64_t nb11;
     uint64_t nb12;
-    int32_t  neh11; // n_tokens
-    uint64_t nbh11;
+    int32_t  ne21; // n_tokens
     int32_t  ne20;  // n_expert_used
     uint64_t nb21;
 } ggml_metal_kargs_mul_mm_id_map0;
 
-typedef struct {
-    int32_t  ne20; // n_expert_used
-    int32_t  neh0;
-    int32_t  neh1;
-    uint64_t nbh1;
-    uint64_t nbh2;
-    int32_t  ne0;
-    uint64_t nb1;
-    uint64_t nb2;
-} ggml_metal_kargs_mul_mm_id_map1;
-
 typedef struct {
     int32_t  ne00;
     int32_t  ne02;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
-    int32_t  neh12;
-    uint64_t nbh10;
-    uint64_t nbh11;
-    uint64_t nbh12;
-    uint64_t nbh13;
-    int32_t  neh0;
-    int32_t  neh1;
+    int32_t  ne11;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne20;
+    int32_t  ne21;
+    int32_t  ne0;
+    int32_t  ne1;
     int16_t  r2;
     int16_t  r3;
 } ggml_metal_kargs_mul_mm_id;
index dcd816a430f60c342ae4af510f35a25d94363498..7a05a98278f94699b68ef0d5f1471dea122770ac 100644 (file)
@@ -398,8 +398,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
@@ -1428,8 +1432,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,                mul_mm_iq1_m_f32,                has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,               mul_mm_iq4_nl_f32,               has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,               mul_mm_iq4_xs_f32,               has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,              mul_mm_id_map0_f16,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,              mul_mm_id_map1_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,       mul_mm_id_map0_f16_ne20_1,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,       mul_mm_id_map0_f16_ne20_2,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,       mul_mm_id_map0_f16_ne20_4,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,       mul_mm_id_map0_f16_ne20_6,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,       mul_mm_id_map0_f16_ne20_8,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,      mul_mm_id_map0_f16_ne20_16,      has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,               mul_mm_id_f32_f16,               has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,               mul_mm_id_f16_f16,               has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,              mul_mm_id_bf16_f16,              has_simdgroup_mm && use_bfloat);
@@ -3908,38 +3916,6 @@ static int ggml_metal_encode_node(
                         default: break;
                     }
 
-                    const int64_t neh10 = ne10; // n_embd
-                    const int64_t neh11 = ne21; // n_tokens
-                    const int64_t neh12 = ne02; // n_expert
-
-                    const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
-                    const uint64_t nbh11 = nbh10*neh10;
-                    const uint64_t nbh12 = nbh11*neh11;
-                    const uint64_t nbh13 = nbh12*neh12;
-
-                    const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
-                    id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
-                    if (!h_src1) {
-                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                        return 0;
-                    }
-
-                    const int64_t neh0 = ne0;
-                    const int64_t neh1 = ne21;
-                    const int64_t neh2 = ne02;
-
-                    const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
-                    const uint64_t nbh1 = nbh0*neh0;
-                    const uint64_t nbh2 = nbh1*neh1;
-                  //const uint64_t nbh3 = nbh2*neh2;
-
-                    const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
-                    id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
-                    if (!h_dst) {
-                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                        return 0;
-                    }
-
                     // tokens per expert
                     const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
                     id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -3949,8 +3925,8 @@ static int ggml_metal_encode_node(
                     }
 
                     // id map
-                    // [n_expert_used, n_tokens]
-                    const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
+                    // [n_tokens, n_expert]
+                    const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
                     id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                     if (!h_ids) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
@@ -3958,32 +3934,45 @@ static int ggml_metal_encode_node(
                     }
 
                     {
-                        const int nth = MIN(1024, ne10/4);
-
                         ggml_metal_kargs_mul_mm_id_map0 args = {
+                            ne02,
                             ne10,
-                            ne11,  // n_expert_used (bcast)
+                            ne11, // n_expert_used (bcast)
                             nb11,
                             nb12,
-                            neh11, // n_tokens
-                            nbh11,
-                            ne20,  // n_expert_used
+                            ne21, // n_tokens
+                            ne20, // n_expert_used
                             nb21,
                         };
 
                         id<MTLComputePipelineState> pipeline = nil;
 
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+                        pipeline = nil;
+
+                        switch (ne20) {
+                            case 1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
+                            case 2:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
+                            case 4:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
+                            case 6:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
+                            case 8:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
+                            case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
+                            default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
+                        }
+
+                        GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+                        const size_t smem = ne02*ne20*sizeof(uint16_t);
+
+                        GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
 
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBytes:&args    length:sizeof(args) atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                        [encoder setBuffer:id_src2 offset:offs_src2    atIndex:2];
-                        [encoder setBuffer: h_src1 offset:0            atIndex:3];
-                        [encoder setBuffer: h_tpe  offset:0            atIndex:4];
-                        [encoder setBuffer: h_ids  offset:0            atIndex:5];
+                        [encoder setBuffer:id_src2 offset:offs_src2    atIndex:1];
+                        [encoder setBuffer: h_tpe  offset:0            atIndex:2];
+                        [encoder setBuffer: h_ids  offset:0            atIndex:3];
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
 
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
                     }
 
                     {
@@ -4022,13 +4011,15 @@ static int ggml_metal_encode_node(
                             /*.nb01  =*/ nb01,
                             /*.nb02  =*/ nb02,
                             /*.nb03  =*/ nb03,
-                            /*.neh12 =*/ neh12,
-                            /*.nbh10 =*/ nbh10,
-                            /*.nbh11 =*/ nbh11,
-                            /*.nbh12 =*/ nbh12,
-                            /*.nbh13 =*/ nbh13,
-                            /*.neh0  =*/ neh0,
-                            /*.neh1  =*/ neh1,
+                            /*.ne11  =*/ ne11, // n_expert_used (bcast)
+                            /*.nb10  =*/ nb10,
+                            /*.nb11  =*/ nb11,
+                            /*.nb12  =*/ nb12,
+                            /*.nb13  =*/ nb13,
+                            /*.ne20  =*/ ne20, // n_expert_used
+                            /*.ne21  =*/ ne21, // n_tokens
+                            /*.ne0   =*/ ne0,
+                            /*.ne1   =*/ ne1,
                             /*.r2    =*/ r2,
                             /*.r3    =*/ r3,
                         };
@@ -4036,42 +4027,14 @@ static int ggml_metal_encode_node(
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBytes:&args    length:sizeof(args) atIndex:0];
                         [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
-                        [encoder setBuffer: h_src1 offset:0            atIndex:2];
+                        [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
                         [encoder setBuffer: h_tpe  offset:0            atIndex:3];
-                        [encoder setBuffer: h_dst  offset:0            atIndex:4];
+                        [encoder setBuffer: h_ids  offset:0            atIndex:4];
+                        [encoder setBuffer:id_dst  offset:offs_dst     atIndex:5];
 
                         [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                     }
-
-                    {
-                        GGML_ASSERT(ne0 % 4 == 0);
-
-                        const int nth = MIN(1024, ne0/4);
-
-                        ggml_metal_kargs_mul_mm_id_map1 args = {
-                            ne20, // n_expert_used
-                            neh0,
-                            neh1,
-                            nbh1,
-                            nbh2,
-                            ne0,
-                            nb1,
-                            nb2,
-                        };
-
-                        id<MTLComputePipelineState> pipeline = nil;
-
-                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBytes:&args   length:sizeof(args) atIndex:0];
-                        [encoder setBuffer: h_dst offset:0            atIndex:1];
-                        [encoder setBuffer: h_ids offset:0            atIndex:2];
-                        [encoder setBuffer:id_dst offset:offs_dst     atIndex:3];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    }
                 } else {
                     id<MTLComputePipelineState> pipeline = nil;
 
index 3dd55fd325ce24bcfa282f955b160e31654cd039..7037c1aa0d30d741bfdaa86f9eaced5cea37927f 100644 (file)
@@ -974,9 +974,16 @@ kernel void kernel_mul(
     device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+    if (args.ne10 == 1) {
+        const float x = *((device float *)(src1_ptr));
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+        }
+    } else {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            const int i10 = i0%args.ne10;
+            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+        }
     }
 }
 
@@ -1000,9 +1007,16 @@ kernel void kernel_div(
     device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+    if (args.ne10 == 1) {
+        const float x = 1.0f / *((device float *)(src1_ptr));
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+        }
+    } else {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            const int i10 = i0%args.ne10;
+            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+        }
     }
 }
 
@@ -7491,97 +7505,81 @@ kernel void kernel_mul_mm(
     }
 }
 
-template<typename T4>
+template<short ne20> // n_expert_used
 kernel void kernel_mul_mm_id_map0(
         constant ggml_metal_kargs_mul_mm_id_map0 & args,
-        device  const char * src1,
         device  const char * src2,
-        device        char * hsrc1,
         device        char * htpe,
         device        char * hids,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int ide = tgpig[0]; // expert id
-
-    int n_all = 0;
+        threadgroup   char * shmem [[threadgroup(0)]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    const short ide = tpitg; // expert id
 
-    device int32_t * ids_i32 = (device int32_t *) (hids);
+    uint32_t n_all = 0;
 
-    for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
-        device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
+    device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
 
-        for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
-            if (src2_i32[i20] != ide) {
-                continue;
-            }
+    for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
+        if (i21 + tpitg < args.ne21) {
+            device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
 
-            device const float4 *  src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
-            device       T4     * hsrc1_f32x4 = (device       T4     *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
+            threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
 
-            for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
-                hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
+            #pragma unroll(ne20)
+            for (short i20 = 0; i20 < ne20; i20++) {
+                sids[i20] = src2_i32[i20];
             }
-
-            if (tpitg.x == 0) {
-                ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
-            }
-
-            ++n_all;
         }
-    }
 
-    if (tpitg.x == 0) {
-        device int32_t * tpe_i32 = (device int32_t *) (htpe);
-        tpe_i32[ide] = n_all;
-    }
-}
-
-typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
-
-template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-template<typename T>
-kernel void kernel_mul_mm_id_map1(
-        constant ggml_metal_kargs_mul_mm_id_map1 & args,
-        device  const char * hdst,
-        device  const char * hids,
-        device        char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i20 = tgpig[0]; // used expert
-    const int i21 = tgpig[1]; // token
+        for (short t = 0; t < ntg; t++) {
+            if (i21 + t >= args.ne21) {
+                break;
+            }
 
-    device const int32_t * ids_i32    = (device const int32_t *) (hids);
-    device       float4  * dst_f32x4  = (device       float4  *) (dst + i20*args.nb1 + i21*args.nb2);
+            threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
 
-    const int id = ids_i32[i21*args.ne20 + i20];
+            short sel = 0;
+            #pragma unroll(ne20)
+            for (short i20 = 0; i20 < ne20; i20++) {
+                sel += (sids[i20] == ide)*(i20 + 1);
+            }
 
-    const int ide = id / args.neh1;
-    const int idt = id % args.neh1;
+            ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
 
-    device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
+            n_all += sel > 0;
+        }
 
-    for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
-        dst_f32x4[i0] = hdst_f32x4[i0];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
+
+    device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
+    tpe_u32[ide] = n_all;
 }
 
-typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
+typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
 
-template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
 
 template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
 kernel void kernel_mul_mm_id(
         constant ggml_metal_kargs_mul_mm_id & args,
         device const char * src0,
         device const char * src1,
-        device const char * tpe,
+        device const char * htpe,
+        device const char * hids,
         device       char * dst,
         threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
     threadgroup T    * sa = (threadgroup T    *)(shmem);
@@ -7589,19 +7587,20 @@ kernel void kernel_mul_mm_id(
 
     const int r0 = tgpig.y;
     const int r1 = tgpig.x;
-    const int im = tgpig.z;
+    const int im = tgpig.z; // expert
 
-    device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
+    device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
+    device const int32_t  * ids_i32 = (device const int32_t  *) (hids);
 
-    const int neh1 = tpe_i32[im];
+    const int32_t neh1 = tpe_u32[im];
 
     if (r1*BLOCK_SIZE_N >= neh1) {
         return;
     }
 
     // if this block is of 64x32 shape or smaller
-    const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    const short n_cols = (     neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (     neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    const short n_cols = (    neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (    neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
     // a thread shouldn't load data outside of the matrix
     const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -7617,20 +7616,23 @@ kernel void kernel_mul_mm_id(
 
     short il = (tiitg % THREAD_PER_ROW);
 
-    const int i12 = im%args.neh12;
-    const int i13 = im/args.neh12;
+    const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
 
-    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const short i11 = (id % args.ne20) % args.ne11;
+    const short i12 = (id / args.ne20);
+    const short i13 = 0;
+
+    const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
     const short    offset1 = il/nl;
 
     device const block_q * x = (device const block_q *)(src0
         + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
 
-    device const half   * y = (device const half   *)(src1
-        + args.nbh13*i13
-        + args.nbh12*i12
-        + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
-        + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+    device const float   * y = (device const float   *)(src1
+        + args.nb13*i13
+        + args.nb12*i12
+        + args.nb11*i11
+        + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
@@ -7646,7 +7648,7 @@ kernel void kernel_mul_mm_id(
             +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4];
         }
 
-        *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
+        *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y));
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
@@ -7682,43 +7684,38 @@ kernel void kernel_mul_mm_id(
         }
     }
 
-    if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
-        device float * C = (device float *) dst +
-            (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \
-            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
-        }
-    } else {
-        // block is smaller than 64x32, we should avoid writing data outside of the matrix
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *) shmem) \
-                                     + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
-        for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
-        }
+    threadgroup float * temp_str = ((threadgroup float *) shmem) \
+                                 + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    #pragma unroll(8)
+    for (short i = 0; i < 8; i++) {
+        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+    }
 
-        if (sgitg == 0) {
-            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
-                device float4 * D4 = (device float4 *) D;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
-                threadgroup float4 * C4 = (threadgroup float4 *) C;
+    for (short j = sgitg; j < n_cols; j += 4) {
+        const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
 
-                int i = 0;
-                for (; i < n_rows/4; i++) {
-                    *(D4 + i) = *(C4 + i);
-                }
+        const short ide = id % args.ne20;
+        const short idt = id / args.ne20;
 
-                i *= 4;
-                for (; i < n_rows; i++) {
-                    *(D + i) = *(C + i);
-                }
-            }
+        device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
+        device float4 * D4 = (device float4 *) D;
+
+        threadgroup float  * C  = (threadgroup float  *) shmem + (j*BLOCK_SIZE_M);
+        threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+        int i = tiisg;
+        for (; i < n_rows/4; i += 32) {
+            *(D4 + i) = *(C4 + i);
+        }
+
+        i = (4*(n_rows/4)) + tiisg;
+        for (; i < n_rows; i += 32) {
+            *(D + i) = *(C + i);
         }
     }
 }
index 765521ffebb4e4744a8a73e61880d2ed69aaeb76..4b4299d49df610df80b1b07624b15f9d77b50c79 100644 (file)
@@ -6018,6 +6018,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (bool b : {false, true}) {
         test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
         test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
+        test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 50, 200, 64));
     }
 
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));