]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : minor code formatting
authorGeorgi Gerganov <redacted>
Mon, 25 Nov 2024 13:08:04 +0000 (15:08 +0200)
committerGeorgi Gerganov <redacted>
Mon, 25 Nov 2024 13:08:04 +0000 (15:08 +0200)
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index d1abb3cef0ec4a375755e3ae1cf7d3e0d98f0846..3a533d7f9c9afb1125076d755024d895d8a05476 100644 (file)
@@ -1951,316 +1951,316 @@ static void ggml_metal_encode_node(
                         }
 #endif
 
-                        // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
-                        // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                        if ([device supportsFamily:MTLGPUFamilyApple7] &&
-                                !ggml_is_transposed(src0) &&
-                                !ggml_is_transposed(src1) &&
-                                src1t == GGML_TYPE_F32 &&
-                                ne00 % 32 == 0 && ne00 >= 64 &&
-                                (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
-                            //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
-                            // some Metal matrix data types require aligned pointers
-                            // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
-                            switch (src0->type) {
-                                case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
-                                case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
-                                case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;
-                                default: break;
-                            }
+                // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                if ([device supportsFamily:MTLGPUFamilyApple7] &&
+                        !ggml_is_transposed(src0) &&
+                        !ggml_is_transposed(src1) &&
+                        src1t == GGML_TYPE_F32 &&
+                        ne00 % 32 == 0 && ne00 >= 64 &&
+                        (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
+                    //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
-                            id<MTLComputePipelineState> pipeline = nil;
-
-                            switch (src0->type) {
-                                case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32    ].pipeline; break;
-                                case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32    ].pipeline; break;
-                                case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32   ].pipeline; break;
-                                case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break;
-                                case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break;
-                                case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32   ].pipeline; break;
-                                case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
-                                case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
-                                case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
-                                case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
-                                case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32  ].pipeline; break;
-                                case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
-                                case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
-                                default: GGML_ABORT("MUL MAT-MAT not implemented");
-                            }
+                    // some Metal matrix data types require aligned pointers
+                    // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+                    switch (src0->type) {
+                        case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
+                        case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
+                        case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;
+                        default: break;
+                    }
 
-                            ggml_metal_kargs_mul_mm args = {
-                                /*.ne00 =*/ ne00,
-                                /*.ne02 =*/ ne02,
-                                /*.nb01 =*/ nb01,
-                                /*.nb02 =*/ nb02,
-                                /*.nb03 =*/ nb03,
-                                /*.ne12 =*/ ne12,
-                                /*.nb10 =*/ nb10,
-                                /*.nb11 =*/ nb11,
-                                /*.nb12 =*/ nb12,
-                                /*.nb13 =*/ nb13,
-                                /*.ne0  =*/ ne0,
-                                /*.ne1  =*/ ne1,
-                                /*.r2   =*/ r2,
-                                /*.r3   =*/ r3,
-                            };
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32    ].pipeline; break;
+                        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32    ].pipeline; break;
+                        case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break;
+                        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
+                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
+                        default: GGML_ABORT("MUL MAT-MAT not implemented");
+                    }
+
+                    ggml_metal_kargs_mul_mm args = {
+                        /*.ne00 =*/ ne00,
+                        /*.ne02 =*/ ne02,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.nb03 =*/ nb03,
+                        /*.ne12 =*/ ne12,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.nb13 =*/ nb13,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.r2   =*/ r2,
+                        /*.r3   =*/ r3,
+                    };
 
-                            [encoder setComputePipelineState:pipeline];
-                            [encoder setBytes:&args    length:sizeof(args) atIndex:0];
-                            [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
-                            [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
-                            [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
 
-                            [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-                        } else {
-                            int nth0 = 32;
-                            int nth1 = 1;
-                            int nrows = 1;
-                            //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                } else {
+                    int nth0 = 32;
+                    int nth1 = 1;
+                    int nrows = 1;
+                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
-                            id<MTLComputePipelineState> pipeline = nil;
+                    id<MTLComputePipelineState> pipeline = nil;
 
-                            // use custom matrix x vector kernel
-                            switch (src0t) {
-                                case GGML_TYPE_F32:
-                                    {
-                                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                    // use custom matrix x vector kernel
+                    switch (src0t) {
+                        case GGML_TYPE_F32:
+                            {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                                nrows = 4;
+                            } break;
+                        case GGML_TYPE_F16:
+                            {
+                                nth0 = 32;
+                                nth1 = 1;
+                                if (src1t == GGML_TYPE_F32) {
+                                    if (ne11 * ne12 < 4) {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
+                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
+                                        nrows = ne11;
+                                    } else {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
                                         nrows = 4;
-                                    } break;
-                                case GGML_TYPE_F16:
-                                    {
-                                        nth0 = 32;
-                                        nth1 = 1;
-                                        if (src1t == GGML_TYPE_F32) {
-                                            if (ne11 * ne12 < 4) {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
-                                                nrows = ne11;
-                                            } else {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
-                                                nrows = 4;
-                                            }
-                                        } else {
-                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
-                                            nrows = 4;
-                                        }
-                                    } break;
-                                case GGML_TYPE_BF16:
-                                    {
-                                        nth0 = 32;
-                                        nth1 = 1;
-                                        if (src1t == GGML_TYPE_F32) {
-                                            if (ne11 * ne12 < 4) {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
-                                            } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
-                                                nrows = ne11;
-                                            } else {
-                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
-                                                nrows = 4;
-                                            }
-                                        } else {
-                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
-                                            nrows = 4;
-                                        }
-                                    } break;
-                                case GGML_TYPE_Q4_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q4_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_1:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q8_0:
-                                    {
-                                        nth0 = 8;
-                                        nth1 = 8;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q2_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q3_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q4_K:
-                                    {
-                                        nth0 = 4; //1;
-                                        nth1 = 8; //32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q5_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_Q6_K:
-                                    {
-                                        nth0 = 2;
-                                        nth1 = 32;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_XXS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_XS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ3_XXS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ3_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ2_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ1_S:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ1_M:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ4_NL:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
-                                    } break;
-                                case GGML_TYPE_IQ4_XS:
-                                    {
-                                        nth0 = 4;
-                                        nth1 = 16;
-                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
-                                    } break;
-                                default:
-                                    {
-                                        GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
-                                        GGML_ABORT("not implemented");
                                     }
-                            };
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
+                                    nrows = 4;
+                                }
+                            } break;
+                        case GGML_TYPE_BF16:
+                            {
+                                nth0 = 32;
+                                nth1 = 1;
+                                if (src1t == GGML_TYPE_F32) {
+                                    if (ne11 * ne12 < 4) {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
+                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
+                                        nrows = ne11;
+                                    } else {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
+                                        nrows = 4;
+                                    }
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
+                                    nrows = 4;
+                                }
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q2_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q3_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_K:
+                            {
+                                nth0 = 4; //1;
+                                nth1 = 8; //32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q6_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_XS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ3_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ3_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ1_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ1_M:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ4_NL:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ4_XS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
+                            } break;
+                        default:
+                            {
+                                GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                GGML_ABORT("not implemented");
+                            }
+                    };
 
-                            ggml_metal_kargs_mul_mv args = {
-                                /*.ne00 =*/ ne00,
-                                /*.ne01 =*/ ne01,
-                                /*.ne02 =*/ ne02,
-                                /*.nb00 =*/ nb00,
-                                /*.nb01 =*/ nb01,
-                                /*.nb02 =*/ nb02,
-                                /*.nb03 =*/ nb03,
-                                /*.ne10 =*/ ne10,
-                                /*.ne11 =*/ ne11,
-                                /*.ne12 =*/ ne12,
-                                /*.nb10 =*/ nb10,
-                                /*.nb11 =*/ nb11,
-                                /*.nb12 =*/ nb12,
-                                /*.nb13 =*/ nb13,
-                                /*.ne0  =*/ ne0,
-                                /*.ne1  =*/ ne1,
-                                /*.r2   =*/ r2,
-                                /*.r3   =*/ r3,
-                            };
+                    ggml_metal_kargs_mul_mv args = {
+                        /*.ne00 =*/ ne00,
+                        /*.ne01 =*/ ne01,
+                        /*.ne02 =*/ ne02,
+                        /*.nb00 =*/ nb00,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.nb03 =*/ nb03,
+                        /*.ne10 =*/ ne10,
+                        /*.ne11 =*/ ne11,
+                        /*.ne12 =*/ ne12,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.nb13 =*/ nb13,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.r2   =*/ r2,
+                        /*.r3   =*/ r3,
+                    };
 
-                            [encoder setComputePipelineState:pipeline];
-                            [encoder setBytes:&args length:sizeof(args) atIndex:0];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
-                            if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                                const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                                const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                                const int mem_size = 32*sizeof(float);
-                                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q3_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                            else if (src0t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            } else {
-                                const int64_t ny = (ne11 + nrows - 1)/nrows;
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                            }
-                        }
+                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
+                        src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+                        const int mem_size = 32*sizeof(float);
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q3_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q5_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q6_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    } else {
+                        const int64_t ny = (ne11 + nrows - 1)/nrows;
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                }
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
index 971f5054bce2a6384bcebaf8619d8a7f9eb56ca2..eaca38864bd6548384acc6a29c430dfb36c5aaf2 100644 (file)
@@ -5447,12 +5447,12 @@ kernel void kernel_mul_mm(
     const int im = tgpig.z;
 
     // if this block is of 64x32 shape or smaller
-    short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
     // a thread shouldn't load data outside of the matrix
-    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
     simdgroup_T8x8     ma[4];
     simdgroup_float8x8 mb[2];
@@ -5467,20 +5467,23 @@ kernel void kernel_mul_mm(
     const int i12 = im%args.ne12;
     const int i13 = im/args.ne12;
 
-    uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    short    offset1 = il/nl;
+    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const short    offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0
+        + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
 
-    device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
     device const float   * y = (device const float   *)(src1
         + args.nb13*i13
         + args.nb12*i12
-        + args.nb11*(r1 * BLOCK_SIZE_N + thread_col)
+        + args.nb11*(r1*BLOCK_SIZE_N + thread_col)
         + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
         T4x4 temp_a;
         dequantize_func(x, il, temp_a);
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         #pragma unroll(16)
@@ -5490,44 +5493,46 @@ kernel void kernel_mul_mm(
             +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4];
         }
 
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
+        *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
-        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
+        threadgroup const T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+        threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
 
         #pragma unroll(4)
-        for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
             #pragma unroll(4)
             for (short i = 0; i < 4; i++) {
                 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }
+
             simdgroup_barrier(mem_flags::mem_none);
+
             #pragma unroll(2)
             for (short i = 0; i < 2; i++) {
                 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }
 
-            lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE;
-            lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE;
-
             #pragma unroll(8)
             for (short i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
+
+            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
         }
     }
 
     if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
         device float * C = (device float *) dst +
-            (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) + \
-            (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
+            (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \
+            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
 
         for (short i = 0; i < 8; i++) {
             simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
@@ -5536,7 +5541,7 @@ kernel void kernel_mul_mm(
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
         threadgroup float * temp_str = ((threadgroup float *) shmem) \
-                                      + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
+                                     + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
         for (short i = 0; i < 8; i++) {
             simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
         }