static int ggml_metal_encode_node(
ggml_backend_t backend,
int idx,
+ int idx_end,
id<MTLComputeCommandEncoder> encoder,
struct ggml_metal_mem_pool * mem_pool) {
struct ggml_backend_metal_context * ctx = backend->context;
size_t offs_fuse;
id<MTLBuffer> id_fuse;
- for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
+ // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
+ // across splits. idx_end indicates the last node in the current split
+ for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
break;
}
ops[1] = GGML_OP_MUL;
ops[2] = GGML_OP_ADD;
- for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
+ for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
break;
}
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
- const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+ const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+ if (idx + res > node_end) {
+ GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
+ "https://github.com/ggml-org/llama.cpp/pull/14849");
+ }
if (should_capture) {
[encoder popDebugGroup];