const ggml_tensor * src0 = dst->src[0];
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
+ assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
const int ith = params->ith;
const int nth = params->nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- for (int i1 = ir0; i1 < ir1; i1++) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int i3 = ir/(ne02*ne01);
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
ggml_vec_gelu_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_tensor * src0 = dst->src[0];
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
+ assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
const int ith = params->ith;
const int nth = params->nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- for (int i1 = ir0; i1 < ir1; i1++) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int i3 = ir/(ne02*ne01);
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
ggml_vec_gelu_f16(nc,
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_tensor * src0 = dst->src[0];
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
+ assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
const int ith = params->ith;
const int nth = params->nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- for (int i1 = ir0; i1 < ir1; i1++) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int i3 = ir/(ne02*ne01);
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
ggml_vec_gelu_erf_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_tensor * src0 = dst->src[0];
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
+ assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
const int ith = params->ith;
const int nth = params->nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- for (int i1 = ir0; i1 < ir1; i1++) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int i3 = ir/(ne02*ne01);
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
ggml_vec_gelu_erf_f16(nc,
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_tensor * src0 = dst->src[0];
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
+ assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
const int ith = params->ith;
const int nth = params->nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- for (int i1 = ir0; i1 < ir1; i1++) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int i3 = ir/(ne02*ne01);
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
ggml_vec_gelu_quick_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_tensor * src0 = dst->src[0];
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
+ assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
const int ith = params->ith;
const int nth = params->nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- for (int i1 = ir0; i1 < ir1; i1++) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int i3 = ir/(ne02*ne01);
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
ggml_vec_gelu_quick_f16(nc,
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_tensor * src0 = dst->src[0];
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
+ assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
const int ith = params->ith;
const int nth = params->nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- for (int i1 = ir0; i1 < ir1; i1++) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int i3 = ir/(ne02*ne01);
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
ggml_vec_silu_f32(nc,
- (float *) ((char *) dst->data + i1*( dst->nb[1])),
- (float *) ((char *) src0->data + i1*(src0->nb[1])));
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_tensor * src0 = dst->src[0];
- assert(ggml_is_contiguous_1(src0));
- assert(ggml_is_contiguous_1(dst));
+ assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
const int ith = params->ith;
const int nth = params->nth;
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- for (int i1 = ir0; i1 < ir1; i1++) {
+ for (int ir = ir0; ir < ir1; ++ir) {
+ const int i3 = ir/(ne02*ne01);
+ const int i2 = (ir - i3*ne02*ne01)/ne01;
+ const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
ggml_vec_silu_f16(nc,
- (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+ (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
+ (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
constant float a5_erf = 1.061405429f;
template<typename T>
-T erf_approx(T x) {
+inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
return sign_x * y;
}
+// ELU(x) = x for x > 0, exp(x) - 1 otherwise; declared generically,
+// defined via explicit specializations for scalar and vector types below.
+template<typename T> T elu_approx(T x);
+
+// scalar specialization
+template<> inline float elu_approx<float>(float x) {
+    return (x > 0.f) ? x : (exp(x) - 1);
+}
+
+// vector specialization: ELU applied component-wise to each lane of float4
+template<> inline float4 elu_approx<float4>(float4 x) {
+    float4 res;
+
+    res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
+    res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
+    res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
+    res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
+
+    return res;
+}
+
+
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
-template <typename T0, typename T>
+template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
}
}
- device const T0 & x = src0_ptr[i0];
+ const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
- dst_ptr[i0] = args.scale * x + args.bias;
+ dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
- dst_ptr[i0] = args.val;
+ dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
- dst_ptr[i0] = clamp(x, args.min, args.max);
+ dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
- dst_ptr[i0] = x * x;
+ dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
- dst_ptr[i0] = sqrt(x);
+ dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
- dst_ptr[i0] = sin(x);
+ dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
- dst_ptr[i0] = cos(x);
+ dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
- dst_ptr[i0] = log(x);
+ dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
- dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope);
+ dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
- dst_ptr[i0] = precise::tanh(x);
+ dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
- dst_ptr[i0] = fmax(0.0f, x);
+ dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
- dst_ptr[i0] = 1.0f / (1.0f + exp(-x));
+ dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
- dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+ dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
- dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x));
+ dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
- dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x)));
+ dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
- dst_ptr[i0] = x / (1.0f + exp(-x));
+ dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
- dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f);
+ dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
- dst_ptr[i0] = -x;
+ dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
- dst_ptr[i0] = fabs(x);
+ dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
- dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f);
+ dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
- dst_ptr[i0] = T(x > 0.0f);
+ dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
- dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
+ dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
- dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f));
+ dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
- dst_ptr[i0] = exp(x);
+ dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
- dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f);
+ dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
- dst_ptr[i0] = exp(x) - 1.0f;
+ dst_ptr[i0] = (T) (exp(x) - 1);
}
}
#undef FC_CNT
}
-typedef decltype(kernel_unary_impl<float, float>) kernel_unary_t;
-
-template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float>;
-template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4>;
+typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
+// instantiations: <src type, dst type, compute type>;
+// the f16 variants compute in float (TC = float/float4) for precision,
+// converting back to half only on the final store
+template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
+template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
+template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
+template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];