ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
- GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
-
char base[256];
char name[256];
- snprintf(base, 256, "kernel_l2_norm_f32");
+ const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+ const char * t0_str = ggml_type_name(op->src[0]->type);
+ const char * t_str = ggml_type_name(op->type);
+
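+    // the base name encodes the src/dst types, plus a "_4" suffix when the row length allows float4 vectorization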
+ snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);
    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) { // assumed field name: a null pipeline marks a cache miss, so compile on first use
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }
+ res.c4 = is_c4;
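+    // 32 floats of threadgroup memory: one partial sum per simdgroup for the cross-simdgroup reduction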
res.smem = 32*sizeof(float);
return res;
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
- return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
- return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+ return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
typedef struct {
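+    // ne0x/nb0x: src0 shape and byte strides; ne/nb: dst shape and byte strides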
int32_t ne00;
- int32_t ne00_4;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne0;
+ int32_t ne1;
+ int32_t ne2;
+ int32_t ne3;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
float eps;
} ggml_metal_kargs_l2_norm;
    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
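+    // eps is stored in the op's op_params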
float eps;
memcpy(&eps, op->op_params, sizeof(float));
- int nth = 32; // SIMD width
-
ggml_metal_kargs_l2_norm args = {
- /*.ne00 =*/ ne00,
- /*.ne00_4 =*/ ne00/4,
- /*.nb01 =*/ nb01,
- /*.eps =*/ eps,
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.eps =*/ eps,
};
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
- while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
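+    // the "_4" kernels index float4 elements, so divide the dim-0 counts by 4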
+ if (pipeline.c4) {
+ args.ne00 = ne00/4;
+ args.ne0 = ne0/4;
+ }
+
+ int nth = 32; // SIMD width
+
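+    // choose a power-of-two threadgroup size that covers the row, capped at the pipeline's max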
+ while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
- nth = std::min(nth, ne00/4);
const size_t smem = pipeline.smem;
- const int64_t nrows = ggml_nrows(op->src[0]);
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
- ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
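+    // one threadgroup per row, across all rows in dims 1..3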
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
}
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
-kernel void kernel_l2_norm_f32(
+template <typename T0, typename T>
+kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- ushort tpitg[[thread_position_in_threadgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort ntg[[threads_per_threadgroup]]) {
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
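+    // one threadgroup per src0 row; the grid is (ne01, ne02, ne03)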
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
+
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
- device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
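+    // byte strides allow non-contiguous src0/dst; T0/T are float or float4 depending on the variant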
+ device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+ device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
    sumf = simd_sum(sumf);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    // reciprocal L2 norm, with eps guarding against division by zero
    const float scale = 1.0f/sqrt(max(sumf, args.eps));
- device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
+typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
+
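+// f32 instantiations: scalar, and float4 for rows with ne00 % 4 == 0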
+template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
+template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
+
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,