ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
+
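+    // two-stage argsort:
+    // - each threadgroup bitonic-sorts a chunk of up to nth elements of a row
+    // - successive merge passes then combine the sorted chunks into a fully sorted row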
// bitonic sort requires the number of elements to be power of 2
- int64_t ne00_padded = 1;
- while (ne00_padded < ne00) {
- ne00_padded *= 2;
+ int nth = 1;
+ while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
}
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
-
- const int64_t nrows = ggml_nrows(op->src[0]);
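+    // number of threadgroups (sorted chunks) per row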
+ const int nptg = (ne00 + nth - 1)/nth;
// Metal kernels require the buffer size to be multiple of 16 bytes
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
- const size_t smem = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+ const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
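+    // the merge passes ping-pong between dst and a temporary buffer located right after the dst data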
+ ggml_metal_buffer_id bid_tmp = bid_dst;
+ bid_tmp.offs += ggml_nbytes(op);
+
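+    // the bitonic stage writes into bid_dst and each merge pass swaps the two buffers,
+    // so if the number of merge passes is odd, start swapped so that the last pass writes
+    // its result into the real dst buffer (e.g. nptg = 5 -> 3 merge passes -> swap)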
+    if ((int) std::ceil(std::log2(nptg)) % 2 == 1) {
+ std::swap(bid_dst, bid_tmp);
+ }
ggml_metal_kargs_argsort args = {
- /*.ncols =*/ ne00,
- /*.ncols_pad =*/ ne00_padded
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
- ggml_metal_encoder_dispatch_threadgroups(enc, 1, nrows, 1, ne00_padded, 1, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
+
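+    // merge pairs of adjacent sorted runs of length len, doubling len each pass, until the row is fully sorted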
+ ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+
+ int len = nth;
+
+ while (len < ne00) {
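+        // each merge pass reads the result of the previous one, so the passes cannot run concurrently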
+ ggml_metal_op_concurrency_reset(ctx);
+
+ ggml_metal_kargs_argsort_merge args_merge = {
+ .ne00 = ne00,
+ .ne01 = ne01,
+ .ne02 = ne02,
+ .ne03 = ne03,
+ .nb00 = nb00,
+ .nb01 = nb01,
+ .nb02 = nb02,
+ .nb03 = nb03,
+ .len = len,
+ };
+
+ // merges per row
+ const int nm = (ne00 + 2*len - 1) / (2*len);
+
+        const int nth_merge = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
+ ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
+ ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
+
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, 0, 0);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth_merge, 1, 1);
+
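+        // the output of this pass becomes the input of the next one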
+ std::swap(bid_dst, bid_tmp);
+
+ len <<= 1;
+ }
return 1;
}
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
- device const float * x,
+ device const char * src0,
device int32_t * dst,
- threadgroup int32_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]]);
+ threadgroup int32_t * smem_i32 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
- device const float * x,
+ device const char * src0,
device int32_t * dst,
- threadgroup int32_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]]) {
+ threadgroup int32_t * smem_i32 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
- int col = tpitg[0];
- int row = tgpig[1];
+ const int col = tpitg[0];
- if (col >= args.ncols_pad) return;
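+    // each threadgroup sorts a chunk of ntg.x elements of the row, starting at column i00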
+ const int i00 = (tgpig[0]/args.ne01)*ntg.x;
+ const int i01 = tgpig[0]%args.ne01;
+ const int i02 = tgpig[1];
+ const int i03 = tgpig[2];
- device const float * x_row = x + row * args.ncols;
- threadgroup int32_t * dst_row = shared_values;
+ device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
- dst_row[col] = col;
+ smem_i32[col] = i00 + col;
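+    // out-of-range indices (>= ne00) act as padding and always sort to the end of the chunk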
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int k = 2; k <= args.ncols_pad; k *= 2) {
+ for (int k = 2; k <= ntg.x; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
- if (dst_row[col] >= args.ncols ||
- (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ if (smem_i32[col] >= args.ne00 ||
+ (smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+ x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
+ x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
) {
- SWAP(dst_row[col], dst_row[ixj]);
+ SWAP(smem_i32[col], smem_i32[ixj]);
}
} else {
- if (dst_row[ixj] >= args.ncols ||
- (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ if (smem_i32[ixj] >= args.ne00 ||
+ (smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+ x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
+ x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
) {
- SWAP(dst_row[col], dst_row[ixj]);
+ SWAP(smem_i32[col], smem_i32[ixj]);
}
}
}
+
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
// copy the result to dst without the padding
- if (col < args.ncols) {
- dst[row * args.ncols + col] = dst_row[col];
+ if (i00 + col < args.ne00) {
+ dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
+
+ dst[col] = smem_i32[col];
}
}
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
+typedef void (argsort_merge_t)(
+ constant ggml_metal_kargs_argsort_merge & args,
+ device const char * src0,
+ device const int32_t * tmp,
+ device int32_t * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_merge_f32_i32(
+ constant ggml_metal_kargs_argsort_merge & args,
+ device const char * src0,
+ device const int32_t * tmp,
+ device int32_t * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+    const int im  = tgpig[0] / args.ne01;
+    const int i01 = tgpig[0] % args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
+
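+    // this threadgroup merges the adjacent sorted runs starting at start and start + args.len (each at most args.len elements)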
+ const int start = im * (2*args.len);
+
+ const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
+ const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
+
+ const int total = len0 + len1;
+
+ device const int32_t * tmp0 = tmp + start
+ + i01*args.ne00
+ + i02*args.ne00*args.ne01
+ + i03*args.ne00*args.ne01*args.ne02;
+
+ device const int32_t * tmp1 = tmp0 + args.len;
+
+ dst += start
+ + i01*args.ne00
+ + i02*args.ne00*args.ne01
+ + i03*args.ne00*args.ne01*args.ne02;
+
+ device const float * src0_row = (device const float *)(src0
+ + args.nb01*i01
+ + args.nb02*i02
+ + args.nb03*i03);
+
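+    // each thread computes one output element k by binary-searching the partition (i, j), i + j = k, of the two runs (merge path)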
+ for (int k = tpitg.x; k < (int) total; k += ntg.x) {
+ // find partition (i,j) such that i+j = k
+ int low = k > len1 ? k - len1 : 0;
+ int high = MIN(k, len0);
+
+ while (low < high) {
+ const int mid = (low + high) >> 1;
+
+ const int32_t idx0 = tmp0[mid];
+ const int32_t idx1 = tmp1[k - mid - 1];
+
+ const float val0 = src0_row[idx0];
+ const float val1 = src0_row[idx1];
+
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (val0 <= val1) {
+ low = mid + 1;
+ } else {
+ high = mid;
+ }
+ } else {
+ if (val0 >= val1) {
+ low = mid + 1;
+ } else {
+ high = mid;
+ }
+ }
+ }
+
+ const int i = low;
+ const int j = k - i;
+
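+        // take the next element from whichever run is not exhausted; if both have elements left, pick according to the sort order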
+ int32_t out_idx;
+
+ if (i >= len0) {
+ out_idx = tmp1[j];
+ } else if (j >= len1) {
+ out_idx = tmp0[i];
+ } else {
+ const int32_t idx0 = tmp0[i];
+ const int32_t idx1 = tmp1[j];
+
+ const float val0 = src0_row[idx0];
+ const float val1 = src0_row[idx1];
+
+ out_idx = (order == GGML_SORT_ORDER_ASC)
+ ? (val0 <= val1 ? idx0 : idx1)
+ : (val0 >= val1 ? idx0 : idx1);
+ }
+
+ dst[k] = out_idx;
+ }
+}
+
+template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
+
kernel void kernel_leaky_relu_f32(
constant ggml_metal_kargs_leaky_relu & args,
device const float * src0,