}
// Select (and lazily compile) the Metal pipeline implementing GGML_OP_SUM_ROWS /
// GGML_OP_MEAN. The kernel variant is chosen by:
//   - src0/dst element types (baked into the kernel base name),
//   - a "_4" suffix when the row length is divisible by 4 (vectorized path),
//   - a function constant (FC_SUM_ROWS + 0) selecting sum vs. mean at
//     pipeline-specialization time.
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
    // rows must be contiguous; higher dims may be strided (handled via nb01/02/03)
    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));

    char base[256];
    char name[256];

    // function-constant value distinguishing SUM_ROWS from MEAN in the shared kernel
    int op_num = -1;

    switch (op->op) {
        case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
        case GGML_OP_MEAN:     op_num = OP_SUM_ROWS_NUM_MEAN;     break;
        default: GGML_ABORT("fatal error");
    };

    const char * t0_str = ggml_type_name(op->src[0]->type);
    const char * t_str  = ggml_type_name(op->type);

    // row length divisible by 4 -> use the vectorized kernel variant
    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;

    snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
    snprintf(name, 256, "%s_op=%d", base, op_num);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        // not cached yet - specialize the base kernel with the op selector
        // passed as a function constant
        ggml_metal_cv_t cv = ggml_metal_cv_init();

        ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);

        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

        ggml_metal_cv_free(cv);
    }

    // one partial sum per simdgroup lane: 32 floats, x4 for the vectorized variant
    res.smem = 32*sizeof(float);
    if (is_c4) {
        res.smem *= 4;
    }

    res.c4 = is_c4;

    return res;
}
// base IDs for kernel function constants — each op family owns a 100-wide range
#define FC_COUNT_EQUAL 1100
#define FC_UNARY       1200
#define FC_BIN         1300
#define FC_SUM_ROWS    1400

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8

#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1    116

// selector values for the shared sum_rows kernel (sum vs. mean)
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN     11
// kernel argument structs
//
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
+ if (pipeline.c4) {
+ args.ne00 = ne00/4;
+ args.ne0 = ne0/4;
+ }
+
int nth = 32; // SIMD width
- while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+ while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
- nth = std::min(nth, ne00);
+ nth = std::min(nth, (int) args.ne00);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
return x*y;
}
// Fold an accumulator down to a scalar.
// Identity for a plain float accumulator...
static inline float sum(float x) {
    return x;
}

// ...horizontal add for the vectorized float4 accumulator used by the
// "_4" sum_rows kernel variant.
static inline float sum(float4 x) {
    return x[0] + x[1] + x[2] + x[3];
}
+
// NOTE: this is not dequantizing - we are simply fitting the template
// f32 data needs no decoding; this overload exists only so that f32 can be
// plugged into the generic dequantize-based kernel templates.
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
}
}
// selects the op at pipeline-specialization time:
// OP_SUM_ROWS_NUM_SUM_ROWS or OP_SUM_ROWS_NUM_MEAN
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];

// Reduce each row of src0 into a single dst element (sum or mean).
// One threadgroup per row (tgpig.x = i1, .y = i2, .z = i3).
// T0 is the load/accumulator type (float or float4 — for the float4 variant the
// host pre-divides args.ne00 by 4), T is the destination element type.
// Threadgroup memory requirement: 32 elements of T0 (one partial per simdgroup).
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
        constant ggml_metal_kargs_sum_rows & args,
        device const char * src0,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg  [[threads_per_threadgroup]]) {
#define FC_OP FC_sum_rows_op

    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;

    // zero the per-simdgroup partials so unused slots contribute nothing
    if (sgitg == 0) {
        shmem_t[tiisg] = 0.0f;
    }

    device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       T  * dst_row = (device       T  *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    T0 sumf = T0(0.0f);

    // strided per-thread accumulation across the row
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        sumf += src_row[i0];
    }

    // NOTE(review): per-simdgroup reduction reconstructed from the surrounding
    // structure (each lane holds a strided partial; lane 0 publishes the
    // simdgroup total below) — confirm against the full file
    sumf = simd_sum(sumf);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_t[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // first simdgroup combines the per-simdgroup partials
    sumf = shmem_t[tiisg];
    sumf = simd_sum(sumf);

    if (tpitg.x == 0) {
        if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
            if (is_same<float4, T0>::value) {
                // float4 packs 4 elements per T0: the true row length is 4*args.ne00
                dst_row[0] = sum(sumf) / (4*args.ne00);
            } else {
                dst_row[0] = sum(sumf) / args.ne00;
            }
        } else {
            dst_row[0] = sum(sumf);
        }
    }

#undef FC_OP
}
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;

// scalar and vectorized (ne00 % 4 == 0) f32 variants; sum vs. mean is selected
// via the FC_sum_rows_op function constant, so no separate "mean" entry points
template [[host_name("kernel_sum_rows_f32_f32")]]   kernel kernel_sum_rows_t kernel_sum_rows_impl<float,  float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
template<typename T>
kernel void kernel_cumsum_blk(
const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW);
- const int32_t ne02 = args.ne02;
- const int32_t ne03 = args.ne03;
-
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg;
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
- const short T = PK + NSG*SH; // shared memory size per query in (half)
+ //const short T = PK + NSG*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
+#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
- const short dx = sx;
- const short dy = sy;
+ //const short dx = sx;
+ //const short dy = sy;
const short ly = (tiitg/NL1)%8;
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
+#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
- const short dx = sx;
- const short dy = sy;
+ //const short dx = sx;
+ //const short dy = sy;
const short ly = (tiitg/NL1)%8;
}
test_cases.emplace_back(new test_sum());
- test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
- test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
- test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
- test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
test_cases.emplace_back(new test_mean());
- test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
- test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+ test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+ test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
+ test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 }));
+ test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 }));
+ test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 }));
+ test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
- test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
+ test_cases.emplace_back(new test_sum_rows());
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false));
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true));
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true));
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+ test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
- test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
- test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));