]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add special-case mat-vec mul for ne00 == 4 (#14385) upstream/0.0.5760
authorGeorgi Gerganov <redacted>
Thu, 26 Jun 2025 12:51:19 +0000 (15:51 +0300)
committerGitHub <redacted>
Thu, 26 Jun 2025 12:51:19 +0000 (15:51 +0300)
ggml-ci

ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index 248fa378ef0f17b8c63be834476e580e42407632..d8d30cc0b41ca5b1b6fef3a623b0ed4cef8a6c73 100644 (file)
@@ -211,11 +211,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -1175,11 +1178,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,                   rwkv_wkv6_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,                   rwkv_wkv7_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                  mul_mv_f32_f32,                  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,               mul_mv_f32_f32_c4,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,                 mul_mv_bf16_f32,                 has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,              mul_mv_bf16_f32_c4,              use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,            mul_mv_bf16_f32_1row,            has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,              mul_mv_bf16_f32_l4,              has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,                mul_mv_bf16_bf16,                has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                  mul_mv_f16_f32,                  has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,               mul_mv_f16_f32_c4,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,             mul_mv_f16_f32_1row,             has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,               mul_mv_f16_f32_l4,               has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                  mul_mv_f16_f16,                  has_simdgroup_reduction);
@@ -3111,14 +3117,23 @@ static bool ggml_metal_encode_node(
                                 nsg = 1;
                                 nr0 = 1;
                                 nr1 = 4;
-                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                                if (ne00 == 4) {
+                                    nr0 = 32;
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                                }
                             } break;
                         case GGML_TYPE_F16:
                             {
                                 nsg = 1;
                                 nr0 = 1;
                                 if (src1t == GGML_TYPE_F32) {
-                                    if (ne11 * ne12 < 4) {
+                                    if (ne00 == 4) {
+                                        nr0 = 32;
+                                        nr1 = 4;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
+                                    } else if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3137,7 +3152,11 @@ static bool ggml_metal_encode_node(
                                 nsg = 1;
                                 nr0 = 1;
                                 if (src1t == GGML_TYPE_F32) {
-                                    if (ne11 * ne12 < 4) {
+                                    if (ne00 == 4) {
+                                        nr0 = 32;
+                                        nr1 = 4;
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
+                                    } else if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
index f028276068ef4beb36c62d14672aa2275894a1fe..5f004a856bde63897ce34552d2eb75a9ad7ac698 100644 (file)
@@ -2532,6 +2532,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv<
 template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
 #endif
 
+template<typename T04, typename T14, typename args_t>
+void kernel_mul_mv_c4_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg) {
+    const int r0 = tgpig.x*32 + tiisg;
+    const int rb = tgpig.y*N_MV_T_T;
+    const int im = tgpig.z;
+
+    if (r0 >= args.ne01) {
+        return;
+    }
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+    device const T04 * x = (device const T04 *) (src0 + offset0);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    for (int row = 0; row < N_MV_T_T; ++row) {
+        int r1 = rb + row;
+        if (r1 >= args.ne11) {
+            break;
+        }
+
+        const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+
+        device const T14 * y = (device const T14 *) (src1 + offset1);
+
+        dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
+    }
+}
+
+template<typename T04, typename T14>
+kernel void kernel_mul_mv_c4(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
+        args,
+        src0,
+        src1,
+        dst,
+        tgpig,
+        tiisg);
+}
+
+typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
+
+template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>;
+template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
+#endif
+
 template<typename T, typename T4>
 kernel void kernel_mul_mv_1row(
         constant ggml_metal_kargs_mul_mv & args,
index 7be7f2205fa044dac2955fa09fef48cdc8c9eb60..615c2dc008a8dee5ec3bfe415395e5aaf749023c 100644 (file)
@@ -4252,39 +4252,45 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-            // test cases without permutation
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 2}));
-
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
-
-            // test cases with permutation
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
-
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
-
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+            std::vector<int> ks = { 256 };
+            if (ggml_blck_size(type_a) == 1) {
+                ks.push_back(4);
+            }
+            for (auto k : ks) {
+                // test cases without permutation
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {1, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {2, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {1, 1}, {1, 2}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 1}, {1, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 1}, {2, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {1, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {2, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {1, 2}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {3, 2}, {2, 2}));
+
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 2}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {1, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {2, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 1}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 2}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 2}));
+
+                // test cases with permutation
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+            }
 
             // test cases with large ne00/ne10 to cover stream-k fixup
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));