nth *= 2;
}
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
nth *= 2;
}
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);
ggml_metal_kargs_rms_norm args = {
nth *= 2;
}
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);
ggml_metal_kargs_l2_norm args = {
nth *= 2;
}
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);
ggml_metal_kargs_norm args = {
default: GGML_ABORT("not implemented");
}
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+ // TODO: support
+ //const int32_t nk00 = ne00/ggml_blck_size(dst->type);
+ const int32_t nk00 = ne00;
+
+ int nth = 32; // SIMD width
+
+ while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
+
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+ // when rows are small, we can batch them together in a single threadgroup
+ int nrptg = 1;
+
+ // TODO: relax this constraint in the future
+ if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
+ if (nth > nk00) {
+ nrptg = (nth + nk00 - 1)/nk00;
+ nth = nk00;
+
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nrptg--;
+ }
+ }
+ }
+
+ nth = MIN(nth, nk00);
+
ggml_metal_kargs_cpy args = {
- /*.ne00 =*/ ne00,
+ /*.ne00 =*/ nk00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
} break;
case GGML_OP_SET:
{
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]]) {
+ ushort3 tptg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
- const int i01 = tgpig[0];
+ const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
+
+ if (i01 >= args.ne01) {
+ return;
+ }
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
+ for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
}