id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
- // TODO: add ggml_metal_kargs struct
+
+ ggml_metal_kargs_sum_rows args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.ne12 =*/ ne12,
+ /*.ne13 =*/ ne13,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- // TODO: add ggml_metal_kargs struct
- // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+ ggml_metal_kargs_soft_max args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
}
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_diag_mask_inf args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.n_past =*/ n_past,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
if (ne00%8 == 0) {
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_ssm_conv args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_ssm_scan args = {
+ /*.d_state =*/ d_state,
+ /*.d_inner =*/ d_inner,
+ /*.n_seq_tokens =*/ n_seq_tokens,
+ /*.n_seqs =*/ n_seqs,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb20 =*/ nb20,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb30 =*/ nb30,
+ /*.nb31 =*/ nb31,
+ /*.nb40 =*/ nb40,
+ /*.nb41 =*/ nb41,
+ /*.nb42 =*/ nb42,
+ /*.nb50 =*/ nb50,
+ /*.nb51 =*/ nb51,
+ /*.nb52 =*/ nb52,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
-
- [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
- [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
- [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
- [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
-
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
- [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
- [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
- [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
- [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
- [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
- [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
- [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
- [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
+ [encoder setBytes:&args length:sizeof(args) atIndex:7];
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default: GGML_ABORT("not implemented");
}
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_get_rows args = {
+ /*.ne00 =*/ ne00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.ne10 =*/ ne10,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_group_norm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.n_groups =*/ n_groups,
+ /*.eps =*/ eps,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
- [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
const int32_t CHW = IC * KH * KW;
- const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
- const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+ const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
default: GGML_ABORT("fatal error");
};
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_im2col args = {
+ /*.ofs0 =*/ ofs0,
+ /*.ofs1 =*/ ofs1,
+ /*.IW =*/ IW,
+ /*.IH =*/ IH,
+ /*.CHW =*/ CHW,
+ /*.s0 =*/ s0,
+ /*.s1 =*/ s1,
+ /*.p0 =*/ p0,
+ /*.p1 =*/ p1,
+ /*.d0 =*/ d0,
+ /*.d1 =*/ d1,
+ /*.N =*/ N,
+ /*.KH =*/ KH,
+ /*.KW =*/ KW,
+ /*.KHW =*/ KH * KW,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
- [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
- [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
- [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
- [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
- [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
- [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
- [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
- [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
- [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
- [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
if (is_gt_mttpt) {
- [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
- [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
- [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
-
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
default: GGML_ABORT("fatal error");
};
+ ggml_metal_kargs_conv_transpose_1d args = {
+ /*.IC =*/ IC,
+ /*.IL =*/ IL,
+ /*.K =*/ K,
+ /*.s0 =*/ s0,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
- [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
- [encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_upscale args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.sf0 =*/ sf0,
+ /*.sf1 =*/ sf1,
+ /*.sf2 =*/ sf2,
+ /*.sf3 =*/ sf3
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
- [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
- [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
- [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
- [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_pad args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
const int nth = MIN(1024, ne0);
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
+ ggml_metal_kargs_pad_reflect_1d args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.p0 =*/ p0,
+ /*.p1 =*/ p1
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
- [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
- [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
- [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
- [encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
- [encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
const int nth = MIN(1024, ne0);
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_arange args = {
+ /*.ne0 =*/ ne0,
+ /*.start =*/ start,
+ /*.step =*/ step
+ };
+
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
- [encoder setBytes:&start length:sizeof(start) atIndex:2];
- [encoder setBytes:&step length:sizeof(step) atIndex:3];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
+ [encoder setBytes:&args length:sizeof(args) atIndex:1];
const int nth = MIN(1024, ne0);
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_timestep_embedding args = {
+ /*.nb1 =*/ nb1,
+ /*.dim =*/ dim,
+ /*.max_period =*/ max_period
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
- [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
- [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
const int nth = MIN(1024, half);
default: GGML_ABORT("fatal error");
};
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_argsort args = {
+ /*.ncols =*/ ne00,
+ /*.ncols_pad =*/ ne00_padded
+ };
+
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_leaky_relu args = {
+ /*.slope =*/ slope
+ };
+
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:2];
const int64_t n = ggml_nelements(dst);
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
- // TODO: add ggml_metal_kargs struct
+ ggml_metal_kargs_pool_2d args_pool_2d = {
+ /* .k0 = */ k0,
+ /* .k1 = */ k1,
+ /* .s0 = */ s0,
+ /* .s1 = */ s1,
+ /* .p0 = */ p0,
+ /* .p1 = */ p1,
+ /* .IH = */ IH,
+ /* .IW = */ IW,
+ /* .OH = */ OH,
+ /* .OW = */ OW,
+ /* .parallel_elements = */ parallel_elements
+ };
+
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
- [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
- [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
- [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
- [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
- [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
- [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
- [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
- [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
- [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
- [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} break;
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
+ constant ggml_metal_kargs_sum_rows & args,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
int64_t i1 = tpig.x;
- if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
- device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+ device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float row_sum = 0;
- for (int64_t i0 = 0; i0 < ne00; i0++) {
+ for (int64_t i0 = 0; i0 < args.ne00; i0++) {
row_sum += src_row[i0];
}
device const char * src0,
device const char * src1,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint32_t & n_head_log2,
+ constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (ne02*ne01);
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+ const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
+ const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
+ const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
- device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
- device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
+ device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
float slope = 1.0f;
// ALiBi
- if (max_bias > 0.0f) {
+ if (args.max_bias > 0.0f) {
const int64_t h = i02;
- const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float lmax = -INFINITY;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+ lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
// parallel sum
float lsum = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+ const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
const float inv_sum = 1.0f/sum;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
pdst[i00] *= inv_sum;
}
}
device const char * src0,
device const char * src1,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint32_t & n_head_log2,
+ constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (ne02*ne01);
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+ const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
+ const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
+ const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
- device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
+ device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
float slope = 1.0f;
- if (max_bias > 0.0f) {
+ if (args.max_bias > 0.0f) {
const int64_t h = i02;
- const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float4 lmax4 = -INFINITY;
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
// parallel sum
float4 lsum4 = 0.0f;
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+ const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
const float inv_sum = 1.0f/sum;
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
pdst4[i00] *= inv_sum;
}
}
// Diagonal mask kernel, one thread per element.
//
// The thread position in the grid is read as (i00, i01, i02) =
// (tpig[0], tpig[1], tpig[2]). Elements whose i00 is greater than
// n_past + i01 are written as -INFINITY; every other element is copied
// unchanged from src0 to dst. Both buffers are addressed with the same flat
// index i02*ne01*ne00 + i01*ne00 + i00.
//
// This patch replaces the three individual `constant` scalar parameters
// (ne00, ne01, n_past) with a single ggml_metal_kargs_diag_mask_inf struct
// `args` carrying the same three fields.
// NOTE(review): no bounds check against the tensor extents is visible here;
// presumably the dispatch size matches the tensor exactly - confirm on the host side.
kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int & n_past,
+ constant ggml_metal_kargs_diag_mask_inf & args,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i02 = tpig[2];
const int64_t i01 = tpig[1];
const int64_t i00 = tpig[0];
// mask everything past the diagonal, which is shifted right by n_past
- if (i00 > n_past + i01) {
- dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+ if (i00 > args.n_past + i01) {
+ dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY;
} else {
// at or before the diagonal: pass the source value through
- dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+ dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00];
}
}
// Vectorized diagonal mask kernel: each thread handles two consecutive float4
// values (8 floats), starting at float4 index i = 2*tpig[0].
//
// Both float4 values are first copied from src0 to dst, then individual lanes
// that lie past the diagonal (scalar column > n_past + row) are overwritten
// with -INFINITY.
//
// This patch replaces the three individual `constant` scalar parameters
// (ne00, ne01, n_past) with a single ggml_metal_kargs_diag_mask_inf struct
// `args` carrying the same three fields.
// NOTE(review): the index math only lines up with row boundaries if ne00 is a
// multiple of 8 - presumably guaranteed by the host-side kernel selection; confirm.
kernel void kernel_diag_mask_inf_8(
device const float4 * src0,
device float4 * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int & n_past,
+ constant ggml_metal_kargs_diag_mask_inf & args,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i = 2*tpig[0];
// unconditional copy of both float4 values; masking is applied afterwards
dst[i+0] = src0[i+0];
dst[i+1] = src0[i+1];
// i4 is the scalar (float) index of the first of the 8 elements; it is
// decomposed below into (i02, i01, i00), leaving the remainder in i4
int64_t i4 = 4*i;
- const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
- const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
+ const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
+ const int64_t i01 = i4/(args.ne00); i4 -= i01*args.ne00;
const int64_t i00 = i4;
// walk the lanes from highest to lowest column
for (int k = 3; k >= 0; --k) {
// lane k of the second float4 sits at column i00 + 4 + k; once it is within
// the unmasked range, every remaining (lower) column of this thread is too
- if (i00 + 4 + k <= n_past + i01) {
+ if (i00 + 4 + k <= args.n_past + i01) {
break;
}
dst[i+1][k] = -INFINITY;
// lane k of the first float4 sits at column i00 + k
- if (i00 + k > n_past + i01) {
+ if (i00 + k > args.n_past + i01) {
dst[i][k] = -INFINITY;
}
}
}
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
-// TODO: optimize
kernel void kernel_ssm_conv_f32(
device const void * src0,
device const void * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
+ constant ggml_metal_kargs_ssm_conv & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
- const int64_t nc = ne10;
- //const int64_t ncs = ne00;
- //const int64_t nr = ne01;
- //const int64_t n_t = ne1;
- //const int64_t n_s = ne2;
+ const int64_t nc = args.ne10;
+ //const int64_t ncs = args.ne00;
+ //const int64_t nr = args.ne01;
+ //const int64_t n_t = args.ne1;
+ //const int64_t n_s = args.ne2;
- device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
- device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
- device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
+ device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+ device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
+ device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
-// TODO: optimize
kernel void kernel_ssm_scan_f32(
device const void * src0,
device const void * src1,
device const void * src4,
device const void * src5,
device float * dst,
- constant int64_t & d_state,
- constant int64_t & d_inner,
- constant int64_t & n_seq_tokens,
- constant int64_t & n_seqs,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant uint64_t & nb20,
- constant uint64_t & nb21,
- constant uint64_t & nb22,
- constant uint64_t & nb30,
- constant uint64_t & nb31,
- constant uint64_t & nb40,
- constant uint64_t & nb41,
- constant uint64_t & nb42,
- constant uint64_t & nb50,
- constant uint64_t & nb51,
- constant uint64_t & nb52,
+ constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i3 = tgpig.y;
- const int64_t nc = d_state;
- //const int64_t nr = d_inner;
- const int64_t n_t = n_seq_tokens;
- //const int64_t n_s = n_seqs;
+ const int64_t nc = args.d_state;
+ // const int64_t nr = args.d_inner;
+ const int64_t n_t = args.n_seq_tokens;
+ // const int64_t n_s = args.n_seqs;
for (int64_t i2 = 0; i2 < n_t; ++i2) {
- device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
- device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
- device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
- device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
- device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
- device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
- device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
- device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
+ device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
+ device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
+ device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
+ device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
+ device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
+ device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
+ device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
+ device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
if (i2 > 0) {
s0 = s;
kernel void kernel_group_norm(
device const float * src0,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int32_t & n_groups,
- constant float & eps,
+ constant ggml_metal_kargs_group_norm & args,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
- const int64_t ne = ne00*ne01*ne02;
- const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+ const int64_t ne = args.ne00*args.ne01*args.ne02;
+ const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
int start = tgpig * gs;
int end = start + gs;
}
const float variance = tmp / gs;
- const float scale = 1.0f/sqrt(variance + eps);
+ const float scale = 1.0f/sqrt(variance + args.eps);
for (int j = start; j < end; j += ntg) {
dst[j] *= scale;
}
typedef void (im2col_t)(
device const float * x,
device char * dst,
- constant int32_t & ofs0,
- constant int32_t & ofs1,
- constant int32_t & IW,
- constant int32_t & IH,
- constant int32_t & CHW,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int32_t & d0,
- constant int32_t & d1,
+ constant ggml_metal_kargs_im2col & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
kernel void kernel_im2col(
device const float * x,
device char * dst,
- constant int32_t & ofs0,
- constant int32_t & ofs1,
- constant int32_t & IW,
- constant int32_t & IH,
- constant int32_t & CHW,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int32_t & d0,
- constant int32_t & d1,
+ constant ggml_metal_kargs_im2col & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
const int64_t ioh = tgpig[1];
const int64_t iow = tgpig[2];
- const int64_t iiw = iow*s0 + ikw*d0 - p0;
- const int64_t iih = ioh*s1 + ikh*d1 - p1;
+ const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
+ const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
- const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
+ const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
device T * pdst = (device T *) (dst);
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
pdst[offset_dst] = 0.0f;
} else {
- const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
+ const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
pdst[offset_dst] = x[offset_src];
}
}
typedef void (im2col_ext_t)(
device const float * x,
device char * dst,
- constant int32_t & ofs0,
- constant int32_t & ofs1,
- constant int32_t & IW,
- constant int32_t & IH,
- constant int32_t & CHW,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int32_t & d0,
- constant int32_t & d1,
- constant int32_t & N,
- constant int32_t & KH,
- constant int32_t & KW,
+ constant ggml_metal_kargs_im2col & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
// im2col variant in which the first grid dimension, tgpig[0], encodes a chunk index d,
// the input channel and the kernel tap, while tgpig[1] / tgpig[2] feed the input row /
// column computation and the destination offset.
// Each thread writes one element of dst: the input sample at (iih, iiw), or 0 when that
// coordinate falls outside [0, IH) x [0, IW).
// All scalar parameters are read from the args struct.
// NOTE(review): T (the element type written to dst) is declared outside the lines shown here.
kernel void kernel_im2col_ext(
device const float * x,
device char * dst,
- constant int32_t & ofs0,
- constant int32_t & ofs1,
- constant int32_t & IW,
- constant int32_t & IH,
- constant int32_t & CHW,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int32_t & d0,
- constant int32_t & d1,
- constant int32_t & N,
- constant int32_t & KH,
- constant int32_t & KW,
+ constant ggml_metal_kargs_im2col & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
// KHW is now taken from args instead of being computed as KH * KW in the kernel
- const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
+ const int64_t KHW = (int64_t)args.KHW;
// decompose tgpig[0]: d = chunk index, chw = position inside one CHW block
- const int64_t d = tgpig[0] / CHW;
- const int64_t chw = tgpig[0] % CHW;
+ const int64_t d = tgpig[0] / args.CHW;
+ const int64_t chw = tgpig[0] % args.CHW;
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
// tap index inside the kernel window
const int64_t HW = tgpig[0] % KHW;
// index compared against N below (presumably the batch index - verify against the caller)
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
// threads past the end of the last chunk have nothing to do
- if (tpitg_0 >= N) {
+ if (tpitg_0 >= args.N) {
return;
}
// kernel row / kernel column of the tap
- const int64_t tpitg_1 = HW / KW;
- const int64_t tpitg_2 = HW % KW;
+ const int64_t tpitg_1 = HW / args.KW;
+ const int64_t tpitg_2 = HW % args.KW;
// input coordinates: position * stride + tap * dilation - padding
- const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
- const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+ const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
+ const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
const int64_t offset_dst =
- (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
- (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+ (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
+ (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
device T * pdst = (device T *) (dst);
// coordinates outside the input are written as zero
- if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
pdst[offset_dst] = 0.0f;
} else {
- const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
- pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+ const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
+ pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
}
}
device const float * src0,
device const float * src1,
device char * dst,
- constant int32_t & IC,
- constant int32_t & IL,
- constant int32_t & K,
- constant int32_t & s0,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
+ constant ggml_metal_kargs_conv_transpose_1d & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
device const T * src0,
device const float * src1,
device char * dst,
- constant int32_t & IC,
- constant int32_t & IL,
- constant int32_t & K,
- constant int32_t & s0,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
+ constant ggml_metal_kargs_conv_transpose_1d & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]) {
float v = 0.0f;
- for (int64_t c = 0; c < IC; c++) {
- const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
- const int32_t input_offset = c * IL;
+ for (int64_t c = 0; c < args.IC; c++) {
+ const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
+ const int32_t input_offset = c * args.IL;
- for (int64_t i = 0; i < IL; i++) {
- if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
- v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+ for (int64_t i = 0; i < args.IL; i++) {
+ if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
+ v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
}
}
}
- device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+ device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
dst_ptr[0] = v;
}
device const float * src0,
device const float * src1,
device char * dst,
- constant int32_t & IC,
- constant int32_t & IL,
- constant int32_t & K,
- constant int32_t & s0,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
+ constant ggml_metal_kargs_conv_transpose_1d & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
device const half * src0,
device const float * src1,
device char * dst,
- constant int32_t & IC,
- constant int32_t & IL,
- constant int32_t & K,
- constant int32_t & s0,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
+ constant ggml_metal_kargs_conv_transpose_1d & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant float & sf0,
- constant float & sf1,
- constant float & sf2,
- constant float & sf3,
+ constant ggml_metal_kargs_upscale & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
- const int64_t i03 = i3/sf3;
- const int64_t i02 = i2/sf2;
- const int64_t i01 = i1/sf1;
+ const int64_t i03 = i3/args.sf3;
+ const int64_t i02 = i2/args.sf2;
+ const int64_t i01 = i1/args.sf1;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int64_t i00 = i0/sf0;
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int64_t i00 = i0/args.sf0;
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+ device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_ptr[0] = src0_ptr[0];
}
kernel void kernel_pad_f32(
device const char * src0,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
+ constant ggml_metal_kargs_pad & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i02 = i2;
const int64_t i01 = i1;
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+ device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+ device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
- if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- if (i0 < ne00) {
+ if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ if (i0 < args.ne00) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
return;
}
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_ptr[i0] = 0.0f;
}
}
kernel void kernel_pad_reflect_1d_f32(
device const char * src0,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant int64_t & ne0,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int32_t & p0,
- constant int32_t & p1,
+ constant ggml_metal_kargs_pad_reflect_1d & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
const int64_t i02 = i2;
const int64_t i01 = i1;
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+ device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+ device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
- if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- if (i0 < p0) {
- dst_ptr[i0] = src0_ptr[p0 - i0];
- } else if (i0 < ne0 - p1) {
- dst_ptr[i0] = src0_ptr[i0 - p0];
+ if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ if (i0 < args.p0) {
+ dst_ptr[i0] = src0_ptr[args.p0 - i0];
+ } else if (i0 < args.ne0 - args.p1) {
+ dst_ptr[i0] = src0_ptr[i0 - args.p0];
} else {
- dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
+ dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
}
}
}
// Fills dst (viewed as float) with the arithmetic sequence start + step * i0
// for i0 in [0, ne0).
// Each thread starts at its own tpitg.x and advances by ntg.x, so the threads of a
// threadgroup cover the range between them. tgpig is not used by the body.
kernel void kernel_arange_f32(
device char * dst,
- constant int64_t & ne0,
- constant float & start,
- constant float & step,
+ constant ggml_metal_kargs_arange & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
device float * dst_ptr = (device float *) dst;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- dst_ptr[i0] = start + step * i0;
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ dst_ptr[i0] = args.start + args.step * i0;
}
}
// Computes a sinusoidal embedding for one timestep per threadgroup.
// Threadgroup i reads the float timestep src0[i] and writes to the row starting at
// dst + i*nb1 (nb1 is a byte stride):
//   row[j]         = cos(timestep * freq_j)
//   row[j + half_] = sin(timestep * freq_j)
// for j in [0, half_), with half_ = dim / 2 and
// freq_j = exp(-log(max_period) * j / half_).
// The threads of the group share the j range, starting at tpitg.x and stepping by ntg.x.
// NOTE(review): for odd dim the loop covers indices [0, dim - 1) and the tail write
// targets index dim, so index dim - 1 is never written here and the row must hold at
// least dim + 1 floats - confirm against the host-side allocation.
kernel void kernel_timestep_embedding_f32(
device const char * src0,
device char * dst,
- constant uint64_t & nb1,
- constant int & dim,
- constant int & max_period,
+ constant ggml_metal_kargs_timestep_embedding & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// one threadgroup per timestep / output row
int i = tgpig.x;
- device float * embed_data = (device float *)(dst + i*nb1);
+ device float * embed_data = (device float *)(dst + i*args.nb1);
- int half_ = dim / 2;
+ int half_ = args.dim / 2;
for (int j = tpitg.x; j < half_; j += ntg.x) {
float timestep = ((device float *)src0)[i];
- float freq = (float)exp(-log((float)max_period) * j / half_);
+ float freq = (float)exp(-log((float)args.max_period) * j / half_);
float arg = timestep * freq;
embed_data[j ] = cos(arg);
embed_data[j + half_] = sin(arg);
}
// odd dim: a single thread zeroes one extra element
- if (dim % 2 != 0 && tpitg.x == 0) {
- embed_data[dim] = 0.f;
+ if (args.dim % 2 != 0 && tpitg.x == 0) {
+ embed_data[args.dim] = 0.f;
}
}
typedef void (argsort_t)(
device const float * x,
device int32_t * dst,
- constant int64_t & ncols,
- constant int64_t & ncols_pad,
+ constant ggml_metal_kargs_argsort & args,
threadgroup int32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
kernel void kernel_argsort_f32_i32(
device const float * x,
device int32_t * dst,
- constant int64_t & ncols,
- constant int64_t & ncols_pad,
+ constant ggml_metal_kargs_argsort & args,
threadgroup int32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
int col = tpitg[0];
int row = tgpig[1];
- if (col >= ncols_pad) return;
+ if (col >= args.ncols_pad) return;
- device const float * x_row = x + row * ncols;
+ device const float * x_row = x + row * args.ncols;
threadgroup int32_t * dst_row = shared_values;
// initialize indices
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int k = 2; k <= args.ncols_pad; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
- if (dst_row[col] >= ncols ||
- (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ if (dst_row[col] >= args.ncols ||
+ (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
) {
SWAP(dst_row[col], dst_row[ixj]);
}
} else {
- if (dst_row[ixj] >= ncols ||
- (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ if (dst_row[ixj] >= args.ncols ||
+ (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
) {
}
// copy the result to dst without the padding
- if (col < ncols) {
- dst[row * ncols + col] = dst_row[col];
+ if (col < args.ncols) {
+ dst[row * args.ncols + col] = dst_row[col];
}
}
// Element-wise leaky ReLU: one thread per element, indexed by tpig.
// Positive inputs are passed through unchanged; all other inputs are multiplied by slope,
// which is read from the args struct.
kernel void kernel_leaky_relu_f32(
device const float * src0,
device float * dst,
- constant float & slope,
+ constant ggml_metal_kargs_leaky_relu & args,
uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
}
// ref: https://arxiv.org/pdf/2307.08691.pdf
device const void * src0,
device const void * src1,
device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
+ constant ggml_metal_kargs_get_rows & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
const int64_t i02 = i11;
- for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+ for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
float4x4 temp;
- dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
- *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
}
}
device const void * src0,
device const void * src1,
device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
+ constant ggml_metal_kargs_get_rows & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
const int64_t i02 = i11;
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
- (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
- ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+ for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
+ (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
+ ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
}
}
device const void * src0,
device const void * src1,
device int32_t * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
+ constant ggml_metal_kargs_get_rows & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;
- const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+ const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
const int64_t i02 = i11;
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
- (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
- ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+ for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
+ (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
+ ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
}
}
// 2D max pooling, one thread per output element.
// gid is a flat output index: gid / (OH*OW) selects the plane, the remainder selects
// the output row / column inside that plane. Input and output planes are addressed as
// contiguous IH*IW and OH*OW float blocks.
// The pooling window starts at (cur_oh*s1 - p1, cur_ow*s0 - p0), spans k1 x k0 and is
// clamped to the input bounds; an empty clamped window yields -INFINITY.
// All scalar parameters are read from the args struct.
kernel void kernel_pool_2d_max_f32(
device const float * src0,
device float * dst,
- constant int32_t & k0,
- constant int32_t & k1,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int64_t & IH,
- constant int64_t & IW,
- constant int64_t & OH,
- constant int64_t & OW,
- constant int64_t & parallel_elements,
+ constant ggml_metal_kargs_pool_2d & args,
uint gid[[thread_position_in_grid]]) {
// threads beyond the number of output elements do nothing
- if (gid >= parallel_elements) {
+ if (gid >= args.parallel_elements) {
return;
}
const int idx = gid;
// elements per input plane / per output plane
- const int I_HW = IH * IW;
- const int O_HW = OH * OW;
+ const int I_HW = args.IH * args.IW;
+ const int O_HW = args.OH * args.OW;
// plane index, then output row and column within the plane
const int nc = idx / O_HW;
- const int cur_oh = idx % O_HW / OW;
- const int cur_ow = idx % O_HW % OW;
+ const int cur_oh = idx % O_HW / args.OW;
+ const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
// vertical window [bh, eh), clamped to [0, IH)
- const int start_h = cur_oh * s1 - p1;
+ const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
- const int eh = MIN(IH, start_h + k1);
- const int start_w = cur_ow * s0 - p0;
+ const int eh = MIN(args.IH, start_h + args.k1);
// horizontal window [bw, ew), clamped to [0, IW)
+ const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
- const int ew = MIN(IW, start_w + k0);
+ const int ew = MIN(args.IW, start_w + args.k0);
float res = -INFINITY;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
- res = MAX(res, i_ptr[i * IW + j]);
+ res = MAX(res, i_ptr[i * args.IW + j]);
}
}
- o_ptr[cur_oh * OW + cur_ow] = res;
+ o_ptr[cur_oh * args.OW + cur_ow] = res;
}
// 2D average pooling, one thread per output element.
// Indexing is identical to kernel_pool_2d_max_f32: gid / (OH*OW) selects the plane and
// the remainder selects the output row / column; planes are contiguous float blocks.
// The window is clamped to the input bounds, but every sample is weighted by
// 1 / (k0 * k1), i.e. the divisor is always the full kernel size and positions cut off
// by the clamp contribute nothing to the sum.
// All scalar parameters are read from the args struct.
kernel void kernel_pool_2d_avg_f32(
device const float * src0,
device float * dst,
- constant int32_t & k0,
- constant int32_t & k1,
- constant int32_t & s0,
- constant int32_t & s1,
- constant int32_t & p0,
- constant int32_t & p1,
- constant int64_t & IH,
- constant int64_t & IW,
- constant int64_t & OH,
- constant int64_t & OW,
- constant int64_t & parallel_elements,
+ constant ggml_metal_kargs_pool_2d & args,
uint gid[[thread_position_in_grid]]) {
// threads beyond the number of output elements do nothing
- if (gid >= parallel_elements) {
+ if (gid >= args.parallel_elements) {
return;
}
const int idx = gid;
// elements per input plane / per output plane
- const int I_HW = IH * IW;
- const int O_HW = OH * OW;
+ const int I_HW = args.IH * args.IW;
+ const int O_HW = args.OH * args.OW;
// plane index, then output row and column within the plane
const int nc = idx / O_HW;
- const int cur_oh = idx % O_HW / OW;
- const int cur_ow = idx % O_HW % OW;
+ const int cur_oh = idx % O_HW / args.OW;
+ const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
// vertical window [bh, eh), clamped to [0, IH)
- const int start_h = cur_oh * s1 - p1;
+ const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
- const int eh = MIN(IH, start_h + k1);
- const int start_w = cur_ow * s0 - p0;
+ const int eh = MIN(args.IH, start_h + args.k1);
// horizontal window [bw, ew), clamped to [0, IW)
+ const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
- const int ew = MIN(IW, start_w + k0);
+ const int ew = MIN(args.IW, start_w + args.k0);
// alternative divisor (size of the clamped window), kept for reference but disabled:
// const float scale = 1. / ((eh - bh) * (ew - bw));
- const float scale = 1. / (k0 * k1);
+ const float scale = 1. / (args.k0 * args.k1);
float res = 0;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
- float cur = i_ptr[i * IW + j];
+ float cur = i_ptr[i * args.IW + j];
res += cur * scale;
}
}
- o_ptr[cur_oh * OW + cur_ow] = res;
+ o_ptr[cur_oh * args.OW + cur_ow] = res;
}