GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+ GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_SQR,
GGML_METAL_KERNEL_TYPE_SQRT,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
default:
return false;
}
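+ // quantized src0 can be copied (dequantized) to F32 or F16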
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ switch (op->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ return true;
+ default:
+ return false;
+ }
default:
return false;
};
case GGML_OP_CPY:
case GGML_OP_CONT:
{
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
id<MTLComputePipelineState> pipeline = nil;
switch (src0t) {
case GGML_TYPE_BF16:
{
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
- default: GGML_ASSERT(false && "not implemented");
+ default: GGML_ABORT("not implemented");
+ };
+ } break;
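+ // quantized sources: select the dequantizing copy kernel matching the destination type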
+ case GGML_TYPE_Q4_0:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ };
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ };
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ };
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ };
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+ default: GGML_ABORT("not implemented");
};
} break;
default: GGML_ABORT("not implemented");
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
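+ // row length (ne00) must be a whole number of src blocks; use one thread per block, capped at 1024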
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+ int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
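+ // one threadgroup per row of src0 (ne01 x ne02 x ne03 groups)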
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
} break;
case GGML_OP_SET:
{
}
}
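+ // Metal shader side: generic template kernel for copying a quantized tensor into F32/F16,
+ // dequantizing one 4x4 group (16 elements) at a time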
+template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_cpy_q_f32(
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
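+ // each threadgroup handles one source row: tgpig = (i01, i02, i03)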
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
+
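+ // linear index of the first element of this row within the flattened tensor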
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
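+ // recover the destination coordinates from the linear index
+ // (src and dst hold the same number of elements but may differ in shape)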
+ const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+ device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+ device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
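+ // each thread dequantizes one 4x4 group (16 values) per iteration; a quant block holds nl such groups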
+ for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+ T4x4 temp;
+ dequantize_func(src_data + i00/nl, i00%nl, temp);
+ dst_data[i00] = temp;
+ }
+}
+
+typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
+
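+ // all instantiations use nl = 2: each 32-weight quant block splits into two 16-element (4x4) groups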
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
+
kernel void kernel_concat(
constant ggml_metal_kargs_concat & args,
device const char * src0,