uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
- int im = tgpig[0] / args.ne01;
- int i01 = tgpig[0] % args.ne01;
- int i02 = tgpig[1];
- int i03 = tgpig[2];
- const int start = im * (2*args.len);
+ const int im = tgpig[0] / args.ne01;
+ const int i01 = tgpig[0] % args.ne01;
+ const int i02 = tgpig[1];
+ const int i03 = tgpig[2];
+
+ const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
+ args.nb02*i02
+ args.nb03*i03);
- for (int k = tpitg.x; k < (int) total; k += ntg.x) {
- // find partition (i,j) such that i+j = k
- int low = k > len1 ? k - len1 : 0;
- int high = MIN(k, len0);
+ if (total == 0) {
+ return;
+ }
- while (low < high) {
- const int mid = (low + high) >> 1;
+ const int chunk = (total + ntg.x - 1) / ntg.x;
- const int32_t idx0 = tmp0[mid];
- const int32_t idx1 = tmp1[k - mid - 1];
+ const int k0 = tpitg.x * chunk;
+ const int k1 = min(k0 + chunk, total);
- const float val0 = src0_row[idx0];
- const float val1 = src0_row[idx1];
+ if (k0 >= total) {
+ return;
+ }
- if (order == GGML_SORT_ORDER_ASC) {
- if (val0 <= val1) {
- low = mid + 1;
- } else {
- high = mid;
- }
- } else {
- if (val0 >= val1) {
- low = mid + 1;
- } else {
- high = mid;
- }
- }
+ int low = k0 > len1 ? k0 - len1 : 0;
+ int high = MIN(k0, len0);
+
+ // binary-search partition (i, j) such that i + j = k
+ while (low < high) {
+ const int mid = (low + high) >> 1;
+
+ const int32_t idx0 = tmp0[mid];
+ const int32_t idx1 = tmp1[k0 - mid - 1];
+
+ const float val0 = src0_row[idx0];
+ const float val1 = src0_row[idx1];
+
+ bool take_left;
+ if (order == GGML_SORT_ORDER_ASC) {
+ take_left = (val0 <= val1);
+ } else {
+ take_left = (val0 >= val1);
}
- const int i = low;
- const int j = k - i;
+ if (take_left) {
+ low = mid + 1;
+ } else {
+ high = mid;
+ }
+ }
+
+ int i = low;
+ int j = k0 - i;
+
+ // keep the merge fronts into registers
+ int32_t idx0 = 0;
+ float val0 = 0.0f;
+ if (i < len0) {
+ idx0 = tmp0[i];
+ val0 = src0_row[idx0];
+ }
+
+ int32_t idx1 = 0;
+ float val1 = 0.0f;
+ if (j < len1) {
+ idx1 = tmp1[j];
+ val1 = src0_row[idx1];
+ }
+ for (int k = k0; k < k1; ++k) {
int32_t out_idx;
if (i >= len0) {
- out_idx = tmp1[j];
+ while (k < k1) {
+ dst[k++] = tmp1[j++];
+ }
+ break;
} else if (j >= len1) {
- out_idx = tmp0[i];
+ while (k < k1) {
+ dst[k++] = tmp0[i++];
+ }
+ break;
} else {
- const int32_t idx0 = tmp0[i];
- const int32_t idx1 = tmp1[j];
+ bool take_left;
- const float val0 = src0_row[idx0];
- const float val1 = src0_row[idx1];
+ if (order == GGML_SORT_ORDER_ASC) {
+ take_left = (val0 <= val1);
+ } else {
+ take_left = (val0 >= val1);
+ }
- out_idx = (order == GGML_SORT_ORDER_ASC)
- ? (val0 <= val1 ? idx0 : idx1)
- : (val0 >= val1 ? idx0 : idx1);
+ if (take_left) {
+ out_idx = idx0;
+ ++i;
+ if (i < len0) {
+ idx0 = tmp0[i];
+ val0 = src0_row[idx0];
+ }
+ } else {
+ out_idx = idx1;
+ ++j;
+ if (j < len1) {
+ idx1 = tmp1[j];
+ val1 = src0_row[idx1];
+ }
+ }
}
dst[k] = out_idx;