GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
+ GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
+ GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
+ GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
ne00%32 == 0 &&
ne11 > 1) {
switch (src0->type) {
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
// use custom matrix x vector kernel
switch (src0t) {
+ case GGML_TYPE_F32:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+ nrows = 4;
+ } break;
case GGML_TYPE_F16:
{
nth0 = 32;
}
}
+#define N_F32_F32 4
+
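+// F32 x F32 matrix-vector multiply
+// one threadgroup computes one row of src0 against up to N_F32_F32 rows of src1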
+kernel void kernel_mul_mat_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
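+ // r0: src0 row, rb: first src1 row handled by this threadgroup, im: matrix index in the batch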
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = tgpig.y*N_F32_F32;
+ const int64_t im = tgpig.z;
+
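+ // im/(ne12/ne02) maps the src1 matrix index back to the src0 matrix it multiplies (src0 is broadcast when ne12 > ne02)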
+ device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
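+ // short rows take a plain scalar loop; longer rows use float4 loads with the tail handled separately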
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F32_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+
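+ // reduce the per-lane partial sums across the simdgroup; lane 0 stores the result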
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const float4 * x4 = (device const float4 *)x;
+ for (int row = 0; row < N_F32_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float4 * y4 = (device const float4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
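+ // lane 0 adds the leftover elements (ne00 not a multiple of 4) before storing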
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ }
+}
+
kernel void kernel_mul_mat_f16_f32_1row(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0[[buffer(15)]],
- constant int64_t & ne1[[buffer(16)]],
- constant uint & gqa[[buffer(17)]],
+ constant int64_t & ne01 [[buffer(4)]],
+ constant int64_t & ne02 [[buffer(5)]],
+ constant int64_t & ne10 [[buffer(9)]],
+ constant int64_t & ne12 [[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & gqa [[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
constant uint & gqa,
threadgroup uchar *, uint3, uint, uint);
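+// the F32 x F32 matrix-matrix path reuses the templated kernel_mul_mm, same as F16 and the quantized types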
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;