GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
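+ // the c4 kernels perform no simdgroup reduction, so they are not gated on has_simdgroup_reduction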
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
nsg = 1;
nr0 = 1;
nr1 = 4;
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
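+ // src0 rows of exactly 4 elements: take the c4 fast path with one simdgroup
+ // (32 threads) per threadgroup and one src0 row per thread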
+ if (ne00 == 4) {
+ nr0 = 32;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
+ } else {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+ }
} break;
case GGML_TYPE_F16:
{
nsg = 1;
nr0 = 1;
if (src1t == GGML_TYPE_F32) {
- if (ne11 * ne12 < 4) {
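+ // same c4 fast path as the F32 case above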
+ if (ne00 == 4) {
+ nr0 = 32;
+ nr1 = 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
+ } else if (ne11 * ne12 < 4) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
nsg = 1;
nr0 = 1;
if (src1t == GGML_TYPE_F32) {
- if (ne11 * ne12 < 4) {
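+ // same c4 fast path as the F32 case above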
+ if (ne00 == 4) {
+ nr0 = 32;
+ nr1 = 4;
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
+ } else if (ne11 * ne12 < 4) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
#endif
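+
+// mul_mv kernel specialized for src0 rows of exactly 4 elements (ne00 == 4):
+// each thread loads one row of src0 as a single 4-wide vector, so the entire
+// dot product is computed per-thread without any cross-thread reduction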
+template<typename T04, typename T14, typename args_t>
+void kernel_mul_mv_c4_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig,
+ ushort tiisg) {
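+ // one src0 row per thread: threadgroups are one simdgroup (32 threads) wide along dim 0,
+ // cover N_MV_T_T src1 rows along dim 1, and the batch dimension along dim 2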
+ const int r0 = tgpig.x*32 + tiisg;
+ const int rb = tgpig.y*N_MV_T_T;
+ const int im = tgpig.z;
+
+ if (r0 >= args.ne01) {
+ return;
+ }
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
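+ // r2/r3 are the broadcast ratios of src1 over src0 along dims 2 and 3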
+ const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+ device const T04 * x = (device const T04 *) (src0 + offset0);
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
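+ // each thread produces up to N_MV_T_T outputs, one per src1 row in this block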
+ for (int row = 0; row < N_MV_T_T; ++row) {
+ int r1 = rb + row;
+ if (r1 >= args.ne11) {
+ break;
+ }
+
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const T14 * y = (device const T14 *) (src1 + offset1);
+
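+ // the whole row fits in one 4-wide vector, so a single dot() yields the result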
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
+ }
+}
+
+template<typename T04, typename T14>
+kernel void kernel_mul_mv_c4(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
+ args,
+ src0,
+ src1,
+ dst,
+ tgpig,
+ tiisg);
+}
+
+typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
+
+template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
+template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
+#endif
+
template<typename T, typename T4>
kernel void kernel_mul_mv_1row(
constant ggml_metal_kargs_mul_mv & args,