ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
+ int nth = 32; // SIMD width
+
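+ // grow the threadgroup size until it covers the input or hits the pipeline limit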
+ while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ nth *= 2;
+ }
+
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ nth = std::min(nth, (int) n);
+
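+ // one partial sum per simdgroup is staged in threadgroup memory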
+ const int nsg = (nth + 31) / 32;
+
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
- ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
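+ // reserve threadgroup memory for the per-simdgroup partial sums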
+ ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
return 1;
}
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
- ushort tiitg[[thread_index_in_threadgroup]]) {
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
- if (tiitg != 0) {
+ if (args.np == 0) {
return;
}
- float acc = 0.0f;
- for (ulong i = 0; i < args.np; ++i) {
- acc += src0[i];
+ const uint nsg = (ntg.x + 31) / 32;
+
+ float sumf = 0;
+
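+ // each thread accumulates a strided subset of the input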
+ for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+ sumf += src0[i0];
}
- dst[0] = acc;
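+ // reduce the per-thread sums within each simdgroup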
+ sumf = simd_sum(sumf);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
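+ // make the per-simdgroup partial sums visible to the whole threadgroup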
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float total = 0;
+
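+ // the first simdgroup reduces the partial sums and writes the final result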
+ if (sgitg == 0) {
+ float v = 0;
+
+ if (tpitg.x < nsg) {
+ v = shmem_f32[tpitg.x];
+ }
+
+ total = simd_sum(v);
+
+ if (tpitg.x == 0) {
+ dst[0] = total;
+ }
+ }
}
template <bool norm>
struct test_sum : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
+ const std::array<int64_t, 4> permute;
+ bool _use_permute;
std::string vars() override {
- return VARS_TO_STR2(type, ne);
+ std::string v = VARS_TO_STR2(type, ne);
+ if (_use_permute) v += "," + VAR_TO_STR(permute);
+ return v;
}
test_sum(ggml_type type = GGML_TYPE_F32,
- std::array<int64_t, 4> ne = {10, 5, 4, 3})
- : type(type), ne(ne) {}
+ std::array<int64_t, 4> ne = {10, 5, 4, 3},
+ std::array<int64_t, 4> permute = {0, 0, 0, 0})
+ : type(type), ne(ne), permute(permute),
+ _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
+ if (_use_permute) {
+ a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
+ ggml_set_name(a, "a_permuted");
+ }
+
ggml_tensor * out = ggml_sum(ctx, a);
ggml_set_name(out, "out");
test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum src not-contiguous
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));