GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
+ GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
+ GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
case GGML_OP_CLAMP:
+ case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_SQR:
case GGML_OP_SQRT:
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
}
} break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+
+ const int32_t IC = src1->ne[1];
+ const int32_t IL = src1->ne[0];
+
+ const int32_t K = src0->ne[0];
+
+ const int32_t OL = dst->ne[0];
+ const int32_t OC = dst->ne[1];
+
+ id<MTLComputePipelineState> pipeline;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
+ } break;
+ case GGML_TYPE_F16: {
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
+ } break;
+ default: GGML_ABORT("fatal error");
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
case GGML_OP_UPSCALE:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
+typedef void (conv_transpose_1d_t)(
+ device const float * src0,
+ device const float * src1,
+ device char * dst,
+ constant int32_t & IC,
+ constant int32_t & IL,
+ constant int32_t & K,
+ constant int32_t & s0,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_1d(
+ device const T * src0,
+ device const float * src1,
+ device char * dst,
+ constant int32_t & IC,
+ constant int32_t & IL,
+ constant int32_t & K,
+ constant int32_t & s0,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]]) {
+
+ float v = 0.0f;
+
+ for (int64_t c = 0; c < IC; c++) {
+ const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
+ const int32_t input_offset = c * IL;
+
+ for (int64_t i = 0; i < IL; i++) {
+ if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
+ v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+ }
+ }
+ }
+
+ device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+
+ dst_ptr[0] = v;
+}
+
+template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
+kernel void kernel_conv_transpose_1d<float>(
+ device const float * src0,
+ device const float * src1,
+ device char * dst,
+ constant int32_t & IC,
+ constant int32_t & IL,
+ constant int32_t & K,
+ constant int32_t & s0,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]]);
+
+template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
+kernel void kernel_conv_transpose_1d<half>(
+ device const half * src0,
+ device const float * src1,
+ device char * dst,
+ constant int32_t & IC,
+ constant int32_t & IL,
+ constant int32_t & K,
+ constant int32_t & s0,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]]);
+
kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,