return t->view_src != NULL;
}
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
- if (a->type != b->type) {
- return false;
- }
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- if (a->ne[i] != b->ne[i]) {
- return false;
- }
- if (a->nb[i] != b->nb[i]) {
- return false;
- }
- }
- return true;
-}
-
// ops for which this function returns true must not use restrict pointers in their backend implementations
static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
// backend copy
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
- if (a->type != b->type) {
- return false;
- }
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- if (a->ne[i] != b->ne[i]) {
- return false;
- }
- if (a->nb[i] != b->nb[i]) {
- return false;
- }
- }
- return true;
-}
-
void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
return (n + m - 1) & ~(m - 1);
}
+// TODO: move to ggml.h?
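+// checks type, shape (ne) and strides (nb); used by tensor copy and by backends to validate fusion candidates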
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+ if (a->type != b->type) {
+ return false;
+ }
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (a->ne[i] != b->ne[i]) {
+ return false;
+ }
+ if (a->nb[i] != b->nb[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
//
// logging
//
uint64_t nb2;
uint64_t nb3;
uint64_t offs;
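+ // byte offsets of each fused op's src1 within the shared Metal buffer (up to 8 fused ops)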
+ uint64_t o1[8];
} ggml_metal_kargs_bin;
typedef struct {
float max_bias;
float m0;
float m1;
- uint16_t n_head_log2;
+ int32_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
typedef struct {
int32_t ne00;
int32_t ne00_4;
- uint64_t nb01;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
float eps;
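+ // index 0: the norm input; indices 1, 2: the fused MUL and ADD operands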
+ int32_t nef1[3];
+ int32_t nef2[3];
+ int32_t nef3[3];
+ uint64_t nbf1[3];
+ uint64_t nbf2[3];
+ uint64_t nbf3[3];
} ggml_metal_kargs_rms_norm;
typedef struct {
float max_bias;
float m0;
float m1;
- uint32_t n_head_log2;
+ int32_t n_head_log2;
} ggml_metal_kargs_soft_max;
typedef struct {
bool has_residency_sets;
bool has_bfloat;
bool use_bfloat;
+ bool use_fusion;
+
+ int debug_fusion;
+
+ // how many times a given op was fused
+ uint64_t fuse_cnt[GGML_OP_COUNT];
size_t max_size;
/*.has_residency_sets =*/ false,
/*.has_bfloat =*/ false,
/*.use_bfloat =*/ false,
+ /*.use_fusion =*/ true,
+ /*.debug_fusion =*/ 0,
+ /*.fuse_cnt =*/ { 0 },
/*.max_size =*/ 0,
/*.name =*/ "",
};
if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
- }
- if (ctx->mtl_device) {
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
- ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
+ ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
#endif
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
#else
ctx->use_bfloat = false;
#endif
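+ // fusion is on by default; set GGML_METAL_FUSION_DISABLE to turn it off and
+ // GGML_METAL_FUSION_DEBUG=1 to print fusion stats on shutdown (=2 also logs each fused sequence)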
+ ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+ {
+ const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+ ctx->debug_fusion = val ? atoi(val) : 0;
+ }
+
+ memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
ctx->max_size = ctx->mtl_device.maxBufferLength;
ctx->mtl_device_ref_count--;
if (ctx->mtl_device_ref_count == 0) {
+ if (ctx->debug_fusion > 0) {
+ fprintf(stderr, "%s: fusion stats:\n", __func__);
+ for (int i = 0; i < GGML_OP_COUNT; i++) {
+ if (ctx->fuse_cnt[i] == 0) {
+ continue;
+ }
+
+ // note: cannot use ggml_log here
+ fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
+ }
+ }
+
if (ctx->mtl_lock) {
[ctx->mtl_lock release];
ctx->mtl_lock = nil;
enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD,
- GGML_METAL_KERNEL_TYPE_ADD_ROW,
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
GGML_METAL_KERNEL_TYPE_SUB,
- GGML_METAL_KERNEL_TYPE_SUB_ROW,
+ GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
GGML_METAL_KERNEL_TYPE_MUL,
- GGML_METAL_KERNEL_TYPE_MUL_ROW,
+ GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
GGML_METAL_KERNEL_TYPE_DIV,
- GGML_METAL_KERNEL_TYPE_DIV_ROW,
+ GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
+ GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
GGML_METAL_KERNEL_TYPE_L2_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_KERNEL_TYPE_NORM,
// simd_sum and simd_max require MTLGPUFamilyApple7
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
}
}
-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
ggml_backend_t backend,
int idx,
id<MTLComputeCommandEncoder> encoder,
struct ggml_cgraph * gf = ctx->gf;
- struct ggml_tensor * node = ggml_graph_node(gf, idx);
+ enum ggml_op ops[8];
+
+ struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
+ struct ggml_tensor * node = nodes[0];
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
struct ggml_tensor * dst = node;
if (ggml_is_empty(dst)) {
- return true;
+ return 1;
}
switch (dst->op) {
case GGML_OP_PERMUTE:
{
// noop -> next node
- } return true;
+ } return 1;
default:
{
} break;
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
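+ // number of consecutive graph nodes encoded by this call (1 == no fusion)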
+ int n_fuse = 1;
+
#if 0
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
if (src0) {
GGML_ASSERT(src0t == GGML_TYPE_F32);
GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous_rows(src0));
+ GGML_ASSERT(ggml_is_contiguous_rows(src1));
+
const size_t offs = 0;
bool bcast_row = false;
id<MTLComputePipelineState> pipeline = nil;
- if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
- GGML_ASSERT(ggml_is_contiguous(src0));
-
- // src1 is a row
- GGML_ASSERT(ne11 == 1);
-
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
-
- bcast_row = true;
- } else {
- switch (dst->op) {
- case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
- case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
- case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
- case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
- default: GGML_ABORT("fatal error");
- }
- }
-
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.offs =*/ offs,
+ /*.o1 =*/ { offs_src1 },
};
+ // c[0] = add(a, b[0])
+ // c[1] = add(c[0], b[1])
+ // c[2] = add(c[1], b[2])
+ // ...
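+ // fuse up to 8 consecutive ADDs, as long as each result feeds the next node's src0,
+ // all b[i] share the same layout, and all b[i] live in the same Metal buffer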
+ if (ctx_dev->use_fusion) {
+ ops[0] = GGML_OP_ADD;
+ ops[1] = GGML_OP_ADD;
+ ops[2] = GGML_OP_ADD;
+ ops[3] = GGML_OP_ADD;
+ ops[4] = GGML_OP_ADD;
+ ops[5] = GGML_OP_ADD;
+ ops[6] = GGML_OP_ADD;
+ ops[7] = GGML_OP_ADD;
+
+ size_t offs_fuse;
+ id<MTLBuffer> id_fuse;
+
+ for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+ if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+ break;
+ }
+
+ if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+ break;
+ }
+
+ // b[0], b[1], ... must all share the same layout
+ if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
+ break;
+ }
+
+ // only fuse nodes if src1 is in the same Metal buffer
+ id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
+ if (id_fuse != id_src1) {
+ break;
+ }
+
+ ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+ args.o1[n_fuse + 1] = offs_fuse;
+ }
+
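+ // n_fuse counted successful fusion steps; +1 converts it to the number of nodes encoded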
+ ++n_fuse;
+
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+ GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+ }
+ }
+
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ // src1 is a row
+ GGML_ASSERT(ne11 == 1);
+
+ switch (dst->op) {
+ case GGML_OP_ADD:
+ {
+ switch (n_fuse) {
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
+ case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
+ case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
+ case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ }
+ } break;
+ case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ }
+
+ bcast_row = true;
+ } else {
+ switch (dst->op) {
+ case GGML_OP_ADD:
+ {
+ switch (n_fuse) {
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
+ case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+ case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
+ case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ }
+ } break;
+ case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+ case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+ case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+ default: GGML_ABORT("fatal error");
+ }
+ }
+
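+ // when fusing, redirect the output to the last node of the fused chain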
+ if (n_fuse > 1) {
+ id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+ }
+
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src1 offset:0 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
if (bcast_row) {
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
- const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+ int nth = 32;
+
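+ // start from one simdgroup and double until each thread covers at most ~16 elements of the row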
+ while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
/*.nb2 =*/ pnb2,
/*.nb3 =*/ pnb3,
/*.offs =*/ offs,
+ /*.o1 =*/ { offs_src1 },
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src1 offset:0 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
id<MTLBuffer> h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
if (!h_src0) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
- return false;
+ return 0;
}
offs_src0 = 0;
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
if (!h_src1) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
- return false;
+ return 0;
}
const int64_t neh0 = ne0;
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
if (!h_dst) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
- return false;
+ return 0;
}
// tokens per expert
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
if (!h_tpe) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
- return false;
+ return 0;
}
// id map
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
- return false;
+ return 0;
}
{
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
- GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_rows(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+ ggml_metal_kargs_rms_norm args = {
+ /*.ne00 =*/ ne00,
+ /*.ne00_4 =*/ ne00/4,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ /*.eps =*/ eps,
+ /*.nef1 =*/ { ne01 },
+ /*.nef2 =*/ { ne02 },
+ /*.nef3 =*/ { ne03 },
+ /*.nbf1 =*/ { nb01 },
+ /*.nbf2 =*/ { nb02 },
+ /*.nbf3 =*/ { nb03 },
+ };
+
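+ // bind src0 for the unused fused-operand slots so the encoder always gets valid buffers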
+ size_t offs_fuse[2] = { 0, 0 };
+ id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
+
+ // d[0] = rms_norm(a)
+ // d[1] = mul(d[0], b)
+ // d[2] = add(d[1], c)
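+ // fuse RMS_NORM with a following MUL and optional ADD: each result must feed the next
+ // node's src0, and the fused operand must have contiguous rows with ne[0] == ne00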
+ if (ctx_dev->use_fusion) {
+ ops[0] = GGML_OP_RMS_NORM;
+ ops[1] = GGML_OP_MUL;
+ ops[2] = GGML_OP_ADD;
+
+ for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
+ if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+ break;
+ }
+
+ if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+ break;
+ }
+
+ if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
+ break;
+ }
+
+ if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
+ break;
+ }
+
+ if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
+ break;
+ }
+
+ ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+ id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
+
+ args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
+ args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
+ args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
+
+ args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
+ args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
+ args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
+ }
+
+ ++n_fuse;
+
+ if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+ if (n_fuse == 2) {
+ GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
+ }
+ if (n_fuse == 3) {
+ GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
+ }
+ }
+ }
+
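+ // as in the ADD path, the fused chain's output tensor is its last node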
+ if (n_fuse > 1) {
+ id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+ }
+
+ id<MTLComputePipelineState> pipeline;
+
+ switch (n_fuse) {
+ case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
+ default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
+ }
int nth = 32; // SIMD width
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);
- ggml_metal_kargs_rms_norm args = {
- /*.ne00 =*/ ne00,
- /*.ne00_4 =*/ ne00/4,
- /*.nb01 =*/ nb01,
- /*.eps =*/ eps,
- };
-
[encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
+ [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
- const int64_t nrows = ggml_nrows(src0);
-
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_L2_NORM:
{
}
}
- return true;
+ return n_fuse;
}
static enum ggml_status ggml_metal_graph_compute(
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
ggml_metal_mem_pool_reset(mem_pool);
- for (int idx = node_start; idx < node_end; ++idx) {
+ for (int idx = node_start; idx < node_end;) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
- const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+ const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
if (should_capture) {
[encoder popDebugGroup];
}
- if (!res) {
+ if (res == 0) {
break;
}
+
+ idx += res;
}
[encoder endEncoding];
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
// pros: works for non-contiguous tensors, supports broadcast across all dims
// cons: not very efficient
-kernel void kernel_add(
+template <int F>
+kernel void kernel_add_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
+ device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+ device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
+
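+ // one src1 base pointer per fused operand, each offset by args.o1[j]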
+ device const float * src1_ptr[F];
+ for (short j = 0; j < F; ++j) {
+ src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+ }
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
- *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+
+ float res = src0_ptr[i0];
+
+#pragma unroll
+ for (short j = 0; j < F; ++j) {
+ res += src1_ptr[j][i10];
+ }
+
+ dst_ptr[i0] = res;
}
}
+typedef decltype(kernel_add_fuse_impl<1>) kernel_add_fuse_t;
+
+template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
+template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
+template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
+template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
+template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
+template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
+template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
+template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+
kernel void kernel_sub(
constant ggml_metal_kargs_bin & args,
device const char * src0,
const int i11 = i01%args.ne11;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
- device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
- device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
- device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
// assumption: src1 is a row
// broadcast src1 into src0
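+// the "_c4" variants process float4 chunks, which requires ne00 % 4 == 0 (checked by the host)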
-kernel void kernel_add_row(
+template <short F>
+kernel void kernel_add_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
uint tpig[[thread_position_in_grid]]) {
+
const uint nb = args.ne00/4;
- dst[tpig] = src0[tpig] + src1[tpig % nb];
+ const uint i = tpig % nb;
+
+ device const float4 * src0_row = (device const float4 *) (src0);
+ device float4 * dst_row = (device float4 *) (dst);
+
+ device const float4 * src1_row[F];
+ for (short j = 0; j < F; ++j) {
+ src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+ }
+
+ float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+ for (short j = 0; j < F; ++j) {
+ res += src1_row[j][i];
+ }
+
+ dst_row[tpig] = res;
}
-kernel void kernel_sub_row(
+typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
+
+template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
+template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
+template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
+template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
+template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
+template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
+template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
+template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
+
+template <short F>
+kernel void kernel_sub_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
uint tpig[[thread_position_in_grid]]) {
+
const uint nb = args.ne00/4;
- dst[tpig] = src0[tpig] - src1[tpig % nb];
+ const uint i = tpig % nb;
+
+ device const float4 * src0_row = (device const float4 *) (src0);
+ device float4 * dst_row = (device float4 *) (dst);
+
+ device const float4 * src1_row[F];
+ for (short j = 0; j < F; ++j) {
+ src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+ }
+
+ float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+ for (short j = 0; j < F; ++j) {
+ res -= src1_row[j][i];
+ }
+
+ dst_row[tpig] = res;
}
-kernel void kernel_mul_row(
+typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
+
+template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_mul_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
uint tpig[[thread_position_in_grid]]) {
+
const uint nb = args.ne00/4;
- dst[tpig] = src0[tpig] * src1[tpig % nb];
+ const uint i = tpig % nb;
+
+ device const float4 * src0_row = (device const float4 *) (src0);
+ device float4 * dst_row = (device float4 *) (dst);
+
+ device const float4 * src1_row[F];
+ for (short j = 0; j < F; ++j) {
+ src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+ }
+
+ float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+ for (short j = 0; j < F; ++j) {
+ res *= src1_row[j][i];
+ }
+
+ dst_row[tpig] = res;
}
-kernel void kernel_div_row(
+typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
+
+template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_div_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
uint tpig[[thread_position_in_grid]]) {
+
const uint nb = args.ne00/4;
- dst[tpig] = src0[tpig] / src1[tpig % nb];
+ const uint i = tpig % nb;
+
+ device const float4 * src0_row = (device const float4 *) (src0);
+ device float4 * dst_row = (device float4 *) (dst);
+
+ device const float4 * src1_row[F];
+ for (short j = 0; j < F; ++j) {
+ src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+ }
+
+ float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+ for (short j = 0; j < F; ++j) {
+ res /= src1_row[j][i];
+ }
+
+ dst_row[tpig] = res;
}
+typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
+
+template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
+
kernel void kernel_scale(
device const float * src0,
device float * dst,
}
}
-kernel void kernel_rms_norm(
+// F == 1 : rms_norm (no fuse)
+// F == 2 : rms_norm + mul
+// F == 3 : rms_norm + mul + add
+template <short F>
+kernel void kernel_rms_norm_fuse_impl(
constant ggml_metal_kargs_rms_norm & args,
device const char * src0,
+ device const char * src1_0,
+ device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- ushort tpitg[[thread_position_in_threadgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort ntg[[threads_per_threadgroup]]) {
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
- device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+ const int i01 = tgpig.x;
+ const int i02 = tgpig.y;
+ const int i03 = tgpig.z;
+
+ device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
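+ // fused MUL/ADD operands, broadcast over dims 1-3 via modulo indexing (ignored when F == 1)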
+ device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+ device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
- device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
- for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
- y[i00] = x[i00] * scale;
+ device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+ for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+ if (F == 1) {
+ y[i00] = (x[i00]*scale);
+ }
+ if (F == 2) {
+ y[i00] = (x[i00]*scale)*f0[i00];
+ }
+ if (F == 3) {
+ y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
+ }
}
}
+typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
+
+template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
+template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
+template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
+
kernel void kernel_l2_norm(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
const ggml_type type;
const std::array<int64_t, 4> ne;
const std::array<int, 4> nr;
+ const int nf; // number of fused ops, nf == 1 -> single op (no fusion)
+
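+ // fusion is decided at graph level, so evaluate the whole graph rather than node by node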
+ bool run_whole_graph() override { return true; }
std::string vars() override {
- return VARS_TO_STR3(type, ne, nr);
+ return VARS_TO_STR4(type, ne, nr, nf);
}
size_t op_size(ggml_tensor * t) override {
test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 1, 1},
- std::array<int, 4> nr = {1, 2, 1, 1})
- : op(op), type(type), ne(ne), nr(nr) {}
+ std::array<int, 4> nr = {1, 2, 1, 1},
+ int nf = 1)
+ : op(op), type(type), ne(ne), nr(nr), nf(nf) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
+ GGML_ASSERT(nf <= 8);
+
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
ggml_set_name(a, "a");
- ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
- ggml_set_name(b, "b");
+ ggml_tensor * b[8];
+ for (int i = 0; i < nf; ++i) {
+ b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
+ }
// The backward pass supports broadcasting only for GGML_ADD:
- const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
+ const bool grad_supported = (op == ggml_add || ggml_are_same_shape(a, b[0])) && nf == 1;
if (grad_supported) {
ggml_set_param(a);
- ggml_set_param(b);
+ ggml_set_param(b[0]);
+ }
+
+ ggml_tensor * out = a;
+
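+ // chain nf applications of the op: out = op(...op(op(a, b[0]), b[1])..., b[nf-1])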
+ for (int i = 0; i < nf; ++i) {
+ out = op(ctx, out, b[i]);
}
- ggml_tensor * out = op(ctx, a, b);
ggml_set_name(out, "out");
return out;
}
};
-// GGML_OP_RMS_NORM + GGML_OP_MUL
-struct test_rms_norm_mul : public test_case {
+// GGML_OP_RMS_NORM + GGML_OP_MUL + GGML_OP_ADD
+struct test_rms_norm_mul_add : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const float eps;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
- return "RMS_NORM_MUL";
+ return "RMS_NORM_MUL_ADD";
}
bool run_whole_graph() override { return true; }
return VARS_TO_STR3(type, ne, eps);
}
- test_rms_norm_mul(ggml_type type = GGML_TYPE_F32,
+ test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_set_param(b);
ggml_set_name(b, "b");
+ ggml_set_param(c);
+ ggml_set_name(c, "c");
- // Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
- a = ggml_add(ctx, a, b);
- ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
+ // Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul
+ a = ggml_add(ctx, ggml_add(ctx, a, b), c);
+ ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c);
ggml_set_name(out, "out");
return out;
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
}
+ // fusion
+ test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
+ test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 2, 1, 1}, 3));
+ test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}, 4));
+ test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 2}, 5));
+ test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2}, 6));
+ test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2}, 7));
+ test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {2, 2, 2, 2}, 8));
+
test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
- test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+ test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));