// fuse only ops that start with these operations
// can be expanded when needed
if (node.op() == GGML_OP_ADD ||
+ node.op() == GGML_OP_NORM ||
node.op() == GGML_OP_RMS_NORM) {
ops[0] = node.op();
// can be expanded when needed
if (gf->nodes[f]->op != GGML_OP_ADD &&
gf->nodes[f]->op != GGML_OP_MUL &&
+ gf->nodes[f]->op != GGML_OP_NORM &&
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
break;
}
return res;
}
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
- assert(op->op == GGML_OP_RMS_NORM);
-
- GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
-
- char base[256];
- char name[256];
-
- switch (n_fuse) {
- case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break;
- case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break;
- case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
- default: GGML_ABORT("fatal error");
- }
-
- snprintf(name, 256, "%s", base);
-
- ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
- if (res) {
- return res;
- }
-
- res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
- ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
-
- return res;
-}
-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
return res;
}
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
- assert(op->op == GGML_OP_NORM);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
+ assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
- GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256];
char name[256];
- snprintf(base, 256, "kernel_norm_f32");
+ const char * suffix = "";
+ if (op->ne[0] % 4 == 0) {
+ suffix = "_4";
+ }
+
+ switch (op->op) {
+ case GGML_OP_NORM:
+ switch (n_fuse) {
+ case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
+ case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
+ case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
+ default: GGML_ABORT("fatal error");
+ } break;
+ case GGML_OP_RMS_NORM:
+ switch (n_fuse) {
+ case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
+ case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
+ case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
+ default: GGML_ABORT("fatal error");
+ } break;
+ default: GGML_ABORT("fatal error");
+ }
+
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
- case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
- return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+ case GGML_OP_RMS_NORM:
+ return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
return true;
case GGML_OP_IM2COL:
uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
+// NORM
+// RMS_NORM
typedef struct {
int32_t ne00;
- int32_t ne00_4;
- uint64_t nb01;
- float eps;
-} ggml_metal_kargs_norm;
-
-typedef struct {
- int32_t ne00;
- int32_t ne00_4;
+ int32_t ne00_t;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
uint64_t nbf1[3];
uint64_t nbf2[3];
uint64_t nbf3[3];
-} ggml_metal_kargs_rms_norm;
+} ggml_metal_kargs_norm;
typedef struct {
int32_t ne00;
{
n_fuse = ggml_metal_op_set_rows(ctx, idx);
} break;
- case GGML_OP_RMS_NORM:
- {
- n_fuse = ggml_metal_op_rms_norm(ctx, idx);
- } break;
case GGML_OP_L2_NORM:
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
n_fuse = ggml_metal_op_group_norm(ctx, idx);
} break;
case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
{
n_fuse = ggml_metal_op_norm(ctx, idx);
} break;
return n_fuse;
}
-int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
- ggml_cgraph * gf = ctx->gf;
- ggml_tensor * op = ggml_graph_node(gf, idx);
-
- ggml_metal_library_t lib = ctx->lib;
- ggml_metal_encoder_t enc = ctx->enc;
-
- const int idx_end = ctx->idx_end;
-
- const bool use_fusion = ctx->use_fusion;
-
- const int debug_fusion = ctx->debug_fusion;
-
- ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
-
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
-
- float eps;
- memcpy(&eps, op->op_params, sizeof(float));
-
- ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
- ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
-
- ggml_metal_kargs_rms_norm args = {
- /*.ne00 =*/ ne00,
- /*.ne00_4 =*/ ne00/4,
- /*.nb1 =*/ nb1,
- /*.nb2 =*/ nb2,
- /*.nb3 =*/ nb3,
- /*.eps =*/ eps,
- /*.nef1 =*/ { ne01 },
- /*.nef2 =*/ { ne02 },
- /*.nef3 =*/ { ne03 },
- /*.nbf1 =*/ { nb01 },
- /*.nbf2 =*/ { nb02 },
- /*.nbf3 =*/ { nb03 },
- };
-
- ggml_op fops[8];
-
- int n_fuse = 1;
-
- ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
-
- // d[0] = rms_norm(a)
- // d[1] = mul(d[0], b)
- // d[2] = add(d[1], c)
- if (use_fusion) {
- fops[0] = GGML_OP_RMS_NORM;
- fops[1] = GGML_OP_MUL;
- fops[2] = GGML_OP_ADD;
-
- for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
- if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
- break;
- }
-
- if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
- break;
- }
-
- if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
- break;
- }
-
- if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
- break;
- }
-
- if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
- break;
- }
-
- //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
-
- bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
-
- args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
- args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
- args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
-
- args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
- args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
- args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
- }
-
- ++n_fuse;
-
- if (debug_fusion > 1 && n_fuse > 1) {
- if (n_fuse == 2) {
- GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
- }
- if (n_fuse == 3) {
- GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
- }
- }
- }
-
- if (n_fuse > 1) {
- bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
-
- for (int i = 1; i < n_fuse; ++i) {
- if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
- ggml_metal_op_concurrency_reset(ctx);
-
- break;
- }
- }
- }
-
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse);
-
- int nth = 32; // SIMD width
-
- while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
- nth *= 2;
- }
-
- nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
- nth = std::min(nth, ne00/4);
-
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
-
- ggml_metal_encoder_set_pipeline(enc, pipeline);
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
- ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
- ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
- ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
-
- ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
-
- ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
-
- return n_fuse;
-}
-
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
+ const int idx_end = ctx->idx_end;
+
+ const bool use_fusion = ctx->use_fusion;
+
+ const int debug_fusion = ctx->debug_fusion;
+
+ ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
+
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
ggml_metal_kargs_norm args = {
/*.ne00 =*/ ne00,
- /*.ne00_4 =*/ ne00/4,
- /*.nb01 =*/ nb01,
+ /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
/*.eps =*/ eps,
+ /*.nef1 =*/ { ne01 },
+ /*.nef2 =*/ { ne02 },
+ /*.nef3 =*/ { ne03 },
+ /*.nbf1 =*/ { nb01 },
+ /*.nbf2 =*/ { nb02 },
+ /*.nbf3 =*/ { nb03 },
};
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op);
+ ggml_op fops[8];
+
+ int n_fuse = 1;
+
+ ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
+
+ // d[0] = norm(a)
+ // d[1] = mul(d[0], b)
+ // d[2] = add(d[1], c)
+ if (use_fusion) {
+ fops[0] = op->op;
+ fops[1] = GGML_OP_MUL;
+ fops[2] = GGML_OP_ADD;
+
+ for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
+ if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
+ break;
+ }
+
+ if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
+ break;
+ }
+
+ if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
+ break;
+ }
+
+ if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
+ break;
+ }
+
+ if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
+ break;
+ }
+
+ //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
+
+ bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
+
+ args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
+ args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
+ args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
+
+ args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
+ args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
+ args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
+ }
+
+ ++n_fuse;
+
+ if (debug_fusion > 1 && n_fuse > 1) {
+ if (n_fuse == 2) {
+ GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
+ }
+ if (n_fuse == 3) {
+ GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
+ }
+ }
+ }
+
+ if (n_fuse > 1) {
+ bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
+
+ for (int i = 1; i < n_fuse; ++i) {
+ if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
+ ggml_metal_op_concurrency_reset(ctx);
+
+ break;
+ }
+ }
+ }
+
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
int nth = 32; // SIMD width
- while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+
+ while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
- nth = std::min(nth, ne00/4);
+ nth = std::min(nth, args.ne00_t);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
- const int64_t nrows = ggml_nrows(op->src[0]);
-
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
+ ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
+ ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
- ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
- return 1;
+ return n_fuse;
}
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
return as_type<float>(bits);
}
+static inline float dot(float x, float y) {
+ return x*y;
+}
+
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
dst_i32[tgpig] = arg_val;
}
-kernel void kernel_norm_f32(
+// F == 1 : norm (no fuse)
+// F == 2 : norm + mul
+// F == 3 : norm + mul + add
+template <typename T, short F>
+kernel void kernel_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
+ device const char * src1_0,
+ device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- ushort tpitg[[thread_position_in_threadgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort ntg[[threads_per_threadgroup]]) {
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
- device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+ const int i01 = tgpig.x;
+ const int i02 = tgpig.y;
+ const int i03 = tgpig.z;
- float4 sumf4(0.0f);
+ device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+ device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+ device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
+
+ T sumft(0.0f);
float sumf = 0.0f;
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
- sumf4 += x[i00];
+ for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
+ sumft += x[i00];
}
- sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
+ sumf = dot(sumft, T(1.0f));
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
const float mean = sumf/args.ne00;
- device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+ device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
sumf = 0.0f;
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
y[i00] = x[i00] - mean;
sumf += dot(y[i00], y[i00]);
}
const float variance = sumf/args.ne00;
const float scale = 1.0f/sqrt(variance + args.eps);
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
- y[i00] = y[i00] * scale;
+ for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
+ if (F == 1) {
+ y[i00] = (y[i00]*scale);
+ }
+ if (F == 2) {
+ y[i00] = (y[i00]*scale)*f0[i00];
+ }
+ if (F == 3) {
+ y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
+ }
}
}
+typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
+
+template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
+template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
+template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
+
+template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
+template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
+template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
+
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
-template <short F>
+template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
- constant ggml_metal_kargs_rms_norm & args,
+ constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
const int i02 = tgpig.y;
const int i03 = tgpig.z;
- device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+ device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
- device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
- device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
+ device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+ device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
- for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+ for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
- device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
- for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+ device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+ for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (x[i00]*scale);
}
}
}
-typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
+typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
+
+template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
+template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
+template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
-template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
-template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
-template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
+template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
+template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
+template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
kernel void kernel_l2_norm_f32(
constant ggml_metal_kargs_l2_norm & args,
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
- test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+ test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));