--- /dev/null
+#include "ggml-metal-common.h"
+
+#include "ggml-impl.h"
+
+#include <vector>
+
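+ // a single tracked memory range: the interval [p0, p1) inside buffer pb that an op reads (SRC) or writes (DST)
+ // ggml_mem_ranges below accumulates these ranges for the ops encoded so far in order to detect data
+ // dependencies between ops, so that independent ops can be dispatched concurrently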
+struct ggml_mem_range {
+ uint64_t pb; // buffer id
+
+ uint64_t p0; // begin
+ uint64_t p1; // end
+
+ ggml_mem_range_type pt;
+};
+
+struct ggml_mem_ranges {
+ std::vector<ggml_mem_range> ranges;
+
+ int debug = 0;
+};
+
+struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
+ auto * res = new ggml_mem_ranges;
+
+ res->ranges.reserve(256);
+ res->debug = debug;
+
+ return res;
+}
+
+void ggml_mem_ranges_free(ggml_mem_ranges * mrs) {
+ delete mrs;
+}
+
+void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
+ mrs->ranges.clear();
+}
+
+static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) {
+ mrs->ranges.push_back(mrp);
+
+ return true;
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
+ // always use the base tensor
+ tensor = tensor->view_src ? tensor->view_src : tensor;
+
+ GGML_ASSERT(!tensor->view_src);
+
+ ggml_mem_range mrp;
+
+ if (tensor->buffer) {
+ // when the tensor is allocated, use the actual address range of its data within the buffer
+ mrp = {
+ /*.pb =*/ (uint64_t) tensor->buffer,
+ /*.p0 =*/ (uint64_t) tensor->data,
+ /*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor),
+ /*.pt =*/ pt,
+ };
+ } else {
+ // otherwise, the tensor ptr is used as a unique id of the memory ranges
+ // that the tensor will be using when it is allocated
+ mrp = {
+ /*.pb =*/ (uint64_t) tensor,
+ /*.p0 =*/ 0,
+ /*.p1 =*/ 1024, // dummy [0, 1024) range - only the buffer id (pb) matters in this case
+ /*.pt =*/ pt,
+ };
+ }
+
+ return mrp;
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
+ return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
+ return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
+}
+
+static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+ GGML_ASSERT(tensor);
+
+ ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
+
+ if (mrs->debug > 2) {
+ GGML_LOG_DEBUG("%s: add src range buf=%llu, [%llu, %llu)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
+ }
+
+ return ggml_mem_ranges_add(mrs, mrp);
+}
+
+static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+ GGML_ASSERT(tensor);
+
+ ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
+
+ if (mrs->debug > 2) {
+ GGML_LOG_DEBUG("%s: add dst range buf=%llu, [%llu, %llu)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
+ }
+
+ return ggml_mem_ranges_add(mrs, mrp);
+}
+
+bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (tensor->src[i]) {
+ ggml_mem_ranges_add_src(mrs, tensor->src[i]);
+ }
+ }
+
+ return ggml_mem_ranges_add_dst(mrs, tensor);
+}
+
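+ // return true if the given range has no conflict with the ranges recorded so far:
+ // - ranges in different buffers never conflict
+ // - two SRC (read) ranges may overlap freely
+ // - an overlap involving a DST (write) range is a conflict (RAW, WAR or WAW dependency)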
+static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) {
+ for (size_t i = 0; i < mrs->ranges.size(); i++) {
+ const auto & cmp = mrs->ranges[i];
+
+ if (mrp.pb != cmp.pb) {
+ continue;
+ }
+
+ if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
+ continue;
+ }
+
+ if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) {
+ if (mrs->debug > 2) {
+ GGML_LOG_DEBUG("%s: the %s range buf=%llu, [%llu, %llu) overlaps with a previous %s range buf=%llu, [%llu, %llu)\n",
+ __func__,
+ mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+ mrp.pb, mrp.p0, mrp.p1,
+ cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+ cmp.pb, cmp.p0, cmp.p1);
+ }
+
+ return false;
+ }
+ }
+
+ return true;
+}
+
+static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+ GGML_ASSERT(tensor);
+
+ ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
+
+ const bool res = ggml_mem_ranges_check(mrs, mrp);
+
+ return res;
+}
+
+static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+ GGML_ASSERT(tensor);
+
+ ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
+
+ const bool res = ggml_mem_ranges_check(mrs, mrp);
+
+ return res;
+}
+
+bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (tensor->src[i]) {
+ if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
+ return false;
+ }
+ }
+ }
+
+ return ggml_mem_ranges_check_dst(mrs, tensor);
+}
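+
+ // illustrative usage of the public add/check API over an arbitrary cgraph gf (a sketch - the actual
+ // integration is ggml_metal_encode_node in ggml-metal.m):
+ //
+ // ggml_mem_ranges * mrs = ggml_mem_ranges_init(0);
+ //
+ // for (int i = 0; i < gf->n_nodes; i++) {
+ // ggml_tensor * node = gf->nodes[i];
+ //
+ // if (!ggml_mem_ranges_check(mrs, node)) {
+ // // conflict with an earlier node -> emit a memory barrier and start a new concurrent set
+ // ggml_mem_ranges_reset(mrs);
+ // }
+ //
+ // ggml_mem_ranges_add(mrs, node);
+ // }
+ //
+ // ggml_mem_ranges_free(mrs);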
+
+// TODO: move to ggml.h?
+static bool is_empty(ggml_op op) {
+ switch (op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ return true;
+ default:
+ return false;
+ }
+}
+
+struct node_info {
+ ggml_tensor * node;
+
+ std::vector<ggml_tensor *> fused;
+
+ ggml_op op() const {
+ return node->op;
+ }
+
+ const ggml_tensor * dst() const {
+ return fused.empty() ? node : fused.back();
+ }
+
+ bool is_empty() const {
+ return ::is_empty(node->op);
+ }
+
+ void add_fused(ggml_tensor * t) {
+ fused.push_back(t);
+ }
+};
+
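+ // greedily build "concurrent sets" of nodes - nodes with no data dependencies between them that can be
+ // dispatched without a memory barrier in between. when a node conflicts with the current set, look a few
+ // nodes ahead and pull forward nodes that are independent of both the current set and of the
+ // not-yet-processed nodes before them, then close the set and start a new one.
+ //
+ // small illustrative example (hypothetical nodes, where B depends on A and C is independent of both):
+ //
+ // input order: A, B, C
+ // output order: A, C, B (A and C can execute concurrently, then a barrier, then B)
+ //
+ // returns the node indices in their new order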
+static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
+ // helper to add node src and dst ranges
+ const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (node.node->src[i]) {
+ if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
+ return false;
+ }
+ }
+ }
+
+ for (const auto * fused : node.fused) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (fused->src[i]) {
+ if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
+ return false;
+ }
+ }
+ }
+ }
+
+ return ggml_mem_ranges_add_dst(mrs, node.dst());
+ };
+
+ // helper to check if a node can run concurrently with the existing set of nodes
+ const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (node.node->src[i]) {
+ if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
+ return false;
+ }
+ }
+ }
+
+ for (const auto * fused : node.fused) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (fused->src[i]) {
+ if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
+ return false;
+ }
+ }
+ }
+ }
+
+ return ggml_mem_ranges_check_dst(mrs, node.dst());
+ };
+
+ // perform reorders only across these types of ops
+ // can be expanded when needed
+ // IMPORTANT: do not add ops such as GGML_OP_CPY or GGML_OP_SET_ROWS
+ // the dependencies from such ops are not always represented in the graph
+ const auto & h_safe = [](ggml_op op) {
+ switch (op) {
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ case GGML_OP_ROPE:
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MUL:
+ case GGML_OP_ADD:
+ case GGML_OP_DIV:
+ case GGML_OP_GLU:
+ case GGML_OP_SCALE:
+ case GGML_OP_GET_ROWS:
+ return true;
+ default:
+ return is_empty(op);
+ }
+ };
+
+ const int n = nodes.size();
+
+ std::vector<int> res;
+ res.reserve(n);
+
+ std::vector<bool> used(n, false);
+
+ ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
+ ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
+
+ for (int i0 = 0; i0 < n; i0++) {
+ if (used[i0]) {
+ continue;
+ }
+
+ const auto & node0 = nodes[i0];
+
+ // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e. reset mrs0)
+ // but before we do that, look forward for other nodes that can still be added to the concurrent set mrs0
+ //
+ // note: empty nodes can always be added to the concurrent set since they neither read nor write anything
+ if (!node0.is_empty() && !h_check(mrs0, node0)) {
+ // this will hold the set of memory ranges from the nodes that haven't been processed yet
+ // if a node is not concurrent with this set, we cannot reorder it
+ ggml_mem_ranges_reset(mrs1);
+
+ // initialize it with the current node
+ h_add(mrs1, node0);
+
+ // search up to this many nodes ahead for candidates to add to the concurrent set
+ constexpr int N_FORWARD = 8;
+
+ for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
+ if (used[i1]) {
+ continue;
+ }
+
+ const auto & node1 = nodes[i1];
+
+ // disallow reordering of certain ops
+ if (!h_safe(node1.op())) {
+ break;
+ }
+
+ const bool is_empty = node1.is_empty();
+
+ // to add a concurrent node, it has to be:
+ // + empty or concurrent with all nodes in the existing concurrent set (mrs0)
+ // + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
+ if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
+ // add the node to the existing concurrent set (i.e. reorder it for early execution)
+ h_add(mrs0, node1);
+ res.push_back(i1);
+
+ // mark as used, so we skip re-processing it later
+ used[i1] = true;
+ } else {
+ // expand the set of nodes that haven't been processed yet
+ h_add(mrs1, node1);
+ }
+ }
+
+ // finalize the concurrent set and begin a new one
+ ggml_mem_ranges_reset(mrs0);
+ }
+
+ // expand the concurrent set with the current node
+ {
+ h_add(mrs0, node0);
+ res.push_back(i0);
+ }
+ }
+
+ ggml_mem_ranges_free(mrs0);
+ ggml_mem_ranges_free(mrs1);
+
+ return res;
+}
+
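+ // entry point of the optimization pass: temporarily pack fusable ops into single nodes, reorder the packed
+ // nodes to improve concurrency and then unpack them back into gf->nodes
+ // called via the backend's optimize_graph callback (see ggml_backend_metal_graph_optimize in ggml-metal.m)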
+void ggml_metal_graph_optimize(ggml_cgraph * gf) {
+ constexpr int MAX_FUSE = 16;
+
+ const int n = gf->n_nodes;
+
+ enum ggml_op ops[MAX_FUSE];
+
+ std::vector<node_info> nodes;
+ nodes.reserve(gf->n_nodes);
+
+ // fuse nodes:
+ // we don't want the reordering to break fusion, so we first pack all fusable tensors
+ // into single nodes and perform the reorder over the fused nodes. after the reorder is done, we unfuse
+ for (int i = 0; i < n; i++) {
+ node_info node = {
+ /*.node =*/ gf->nodes[i],
+ /*.fused =*/ {},
+ };
+
+ // start a fusion sequence only at these ops
+ // can be expanded when needed
+ if (node.op() == GGML_OP_ADD ||
+ node.op() == GGML_OP_RMS_NORM) {
+ ops[0] = node.op();
+
+ int f = i + 1;
+ while (f < n && f < i + MAX_FUSE) {
+ // conservatively allow fusing only these ops
+ // can be expanded when needed
+ if (gf->nodes[f]->op != GGML_OP_ADD &&
+ gf->nodes[f]->op != GGML_OP_MUL &&
+ gf->nodes[f]->op != GGML_OP_RMS_NORM) {
+ break;
+ }
+ ops[f - i] = gf->nodes[f]->op;
+ f++;
+ }
+
+ f -= i;
+ for (; f > 1; f--) {
+ if (ggml_can_fuse(gf, i, ops, f)) {
+ break;
+ }
+ }
+
+ // add the fused tensors into the node info so we can unfuse them later
+ for (int k = 1; k < f; k++) {
+ ++i;
+
+ // the .dst() becomes the last fused tensor
+ node.add_fused(gf->nodes[i]);
+ }
+ }
+
+ nodes.push_back(std::move(node));
+ }
+
+ // reorder to improve concurrency
+#if 1
+ const auto order = ggml_metal_graph_optimize_reorder(nodes);
+#else
+ std::vector<int> order(nodes.size());
+ for (size_t i = 0; i < nodes.size(); i++) {
+ order[i] = i;
+ }
+#endif
+
+ // unfuse
+ {
+ int j = 0;
+ for (const auto i : order) {
+ const auto & node = nodes[i];
+
+ gf->nodes[j++] = node.node;
+
+ for (auto * fused : node.fused) {
+ gf->nodes[j++] = fused;
+ }
+ }
+ }
+}
#import "ggml-impl.h"
#import "ggml-backend-impl.h"
#import "ggml-metal-impl.h"
+#import "ggml-metal-common.h"
#import <Foundation/Foundation.h>
bool has_bfloat;
bool use_bfloat;
bool use_fusion;
+ bool use_concurrency;
bool use_shared_buffers;
+ bool use_graph_optimize;
+ int debug_graph;
int debug_fusion;
// how many times a given op was fused
/*.has_bfloat =*/ false,
/*.use_bfloat =*/ false,
/*.use_fusion =*/ true,
+ /*.use_concurrency =*/ true,
/*.use_shared_buffers =*/ true,
+ /*.use_graph_optimize =*/ true,
+ /*.debug_graph =*/ 0,
/*.debug_fusion =*/ 0,
/*.fuse_cnt =*/ { 0 },
/*.max_size =*/ 0,
#else
ctx->use_bfloat = false;
#endif
- ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+ ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+ ctx->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
+
+ {
+ const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
+ ctx->debug_graph = val ? atoi(val) : 0;
+ }
{
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
ctx->use_shared_buffers = false;
}
+ ctx->use_graph_optimize = true;
+
+ if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
+ ctx->use_graph_optimize = false;
+ }
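+
+ // environment toggles for the concurrency features:
+ // GGML_METAL_CONCURRENCY_DISABLE - disable concurrent dispatch of ops within a command encoder
+ // GGML_METAL_GRAPH_OPTIMIZE_DISABLE - disable the fusion-aware graph reordering pass
+ // GGML_METAL_GRAPH_DEBUG - verbosity level of the concurrency/graph debug logging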
+
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
ctx->max_size = ctx->mtl_device.maxBufferLength;
@end
//
-// ggml_metal_mem_pool
+// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
//
struct ggml_metal_mem_pool {
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
struct ggml_metal_mem_pool * mem_pool;
+
+ // used to enable concurrent execution of ops in the command buffers
+ struct ggml_mem_ranges * mem_ranges;
};
struct ggml_backend_metal_context {
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
GGML_LOG_INFO("%s: use fusion = %s\n", __func__, ctx_dev->use_fusion ? "true" : "false");
+ GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, ctx_dev->use_concurrency ? "true" : "false");
GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false");
+ GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, ctx_dev->use_graph_optimize ? "true" : "false");
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
ctx->capture_next_compute = false;
ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
ctx->cmd_bufs[i].mem_pool->device = device;
+
+ if (ctx_dev->use_concurrency) {
+ ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
+ }
}
ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
}
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+
+ if (ctx->cmd_bufs[i].mem_ranges) {
+ ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
+ }
}
[ctx->cmd_bufs_ext removeAllObjects];
}
}
-static int ggml_metal_encode_node(
- ggml_backend_t backend,
- int idx,
- int idx_end,
- id<MTLComputeCommandEncoder> encoder,
- struct ggml_metal_mem_pool * mem_pool) {
+struct ggml_metal_encode_context {
+ ggml_backend_t backend;
+
+ id<MTLComputeCommandEncoder> encoder;
+
+ struct ggml_metal_mem_pool * mem_pool;
+
+ struct ggml_mem_ranges * mem_ranges;
+};
+
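+ // concurrency helpers used during encoding:
+ // - reset: insert a memory barrier into the encoder and clear all tracked ranges (no-op if concurrency is disabled)
+ // - check: return true if the node has no data dependency on the already-encoded nodes (always false if concurrency is disabled)
+ // - add: record the node's src/dst memory ranges so that later nodes can be checked against them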
+static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) {
+ if (!ctx->mem_ranges) {
+ return true;
+ }
+
+ [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+
+ ggml_mem_ranges_reset(ctx->mem_ranges);
+
+ return true;
+}
+
+static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+ if (!ctx->mem_ranges) {
+ return false;
+ }
+
+ return ggml_mem_ranges_check(ctx->mem_ranges, node);
+}
+
+static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+ if (!ctx->mem_ranges) {
+ return true;
+ }
+
+ return ggml_mem_ranges_add(ctx->mem_ranges, node);
+}
+
+static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
+ ggml_backend_t backend = ctx_enc->backend;
+
+ id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
+
+ struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
+
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
const uint64_t nb2 = dst ? dst->nb[2] : 0;
const uint64_t nb3 = dst ? dst->nb[3] : 0;
+ size_t offs_src[GGML_MAX_SRC];
+
+ id<MTLBuffer> id_src[GGML_MAX_SRC];
+
+ enum ggml_type srct[GGML_MAX_SRC];
+
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ offs_src[i] = 0;
+ id_src[i] = node->src[i] ? ggml_metal_get_buffer(node->src[i], &offs_src[i]) : nil;
+ srct[i] = node->src[i] ? node->src[i]->type : GGML_TYPE_COUNT;
+ }
+
+ // TODO: tmp shorthands - remove
+ size_t offs_src0 = offs_src[0];
+ size_t offs_src1 = offs_src[1];
+ size_t offs_src2 = offs_src[2];
+
+ id<MTLBuffer> id_src0 = id_src[0];
+ id<MTLBuffer> id_src1 = id_src[1];
+ id<MTLBuffer> id_src2 = id_src[2];
+
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
- size_t offs_src0 = 0;
- size_t offs_src1 = 0;
- size_t offs_src2 = 0;
- size_t offs_dst = 0;
+ size_t offs_dst = 0;
- id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
- id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
- id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
- id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
+ id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
int n_fuse = 1;
-#if 0
- GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
- if (src0) {
- GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
- ggml_is_contiguous(src0), src0->name);
- }
- if (src1) {
- GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
- ggml_is_contiguous(src1), src1->name);
- }
- if (dst) {
- GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
- dst->name);
+ // check if the current node can run concurrently with the nodes encoded before it
+ // the conditions are:
+ // - the current node must not write to any previously recorded src or dst range
+ // - the current node must not read from any previously recorded dst range
+ //
+ // if the conditions are not satisfied, we insert a memory barrier and clear all recorded ranges
+ // otherwise, the node is processed concurrently and its ranges are added to the encoding context
+ //
+ {
+ const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node);
+
+ if (!is_concurrent) {
+ ggml_metal_encode_concurrency_reset(ctx_enc);
+ }
+
+ if (ctx_dev->debug_graph > 0) {
+ GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(dst->op), is_concurrent ? "(concurrent)" : "");
+ }
+ if (ctx_dev->debug_graph > 1) {
+ if (src0) {
+ GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+ ggml_is_contiguous(src0), src0->name);
+ }
+ if (src1) {
+ GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+ ggml_is_contiguous(src1), src1->name);
+ }
+ if (dst) {
+ GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+ dst->name);
+ }
+ }
}
-#endif
id<MTLDevice> device = ctx_dev->mtl_device;
if (n_fuse > 1) {
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+
+ for (int i = 1; i < n_fuse; ++i) {
+ if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+ ggml_metal_encode_concurrency_reset(ctx_enc);
+
+ break;
+ }
+ }
}
[encoder setComputePipelineState:pipeline];
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
+ ggml_metal_encode_concurrency_reset(ctx_enc);
}
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
default: break;
}
+ // TODO: using mem pool allocations with concurrency enabled is not safe because the mem pool
+ // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
+ // so we add this extra barrier to prevent the race.
+ // the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
+ ggml_metal_encode_concurrency_reset(ctx_enc);
+
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}
+ // this barrier is always needed because the next kernel has to wait for the id maps to be computed
+ ggml_metal_encode_concurrency_reset(ctx_enc);
+
{
id<MTLComputePipelineState> pipeline = nil;
if (n_fuse > 1) {
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+
+ for (int i = 1; i < n_fuse; ++i) {
+ if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+ ggml_metal_encode_concurrency_reset(ctx_enc);
+
+ break;
+ }
+ }
}
id<MTLComputePipelineState> pipeline;
} break;
case GGML_OP_ROPE:
{
-
// make sure we have one or more position id(ne10) per token(ne02)
GGML_ASSERT(ne10 % ne02 == 0);
GGML_ASSERT(ne10 >= ne02);
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
+ // using mem pool allocations with concurrency enabled is not safe [TAG_MEM_POOL_REMOVE]
+ // for now, we assume that concurrent FA ops won't occur before the mem pool refactor is done
+ //ggml_metal_encode_concurrency_reset(ctx_enc);
+
const int32_t nrows = ne1*ne2*ne3;
// temp buffer for writing the results from each workgroup
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ ggml_metal_encode_concurrency_reset(ctx_enc);
+
// reduce the results from the workgroups
{
ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
} break;
- case GGML_OP_ARGMAX:
+ case GGML_OP_ARGMAX:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous_1(src0));
}
}
+ if (ctx_dev->debug_graph > 0) {
+ if (n_fuse > 1) {
+ GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
+ }
+ }
+
+ // update the mem ranges in the encoding context
+ for (int i = 0; i < n_fuse; ++i) {
+ if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) {
+ ggml_metal_encode_concurrency_reset(ctx_enc);
+ }
+ }
+
return n_fuse;
}
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
// number of nodes encoded by the main thread (empirically determined)
- const int n_main = 128;
+ const int n_main = 64;
// number of threads in addition to the main thread
const int n_cb = ctx->n_cb;
// cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
// TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
+ // [TAG_MEM_POOL_REMOVE]
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
[cmd_buf retain];
return ggml_metal_graph_compute(backend, cgraph);
}
+static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+ //const int64_t t_start = ggml_time_us();
+
+ if (ctx_dev->use_graph_optimize) {
+ ggml_metal_graph_optimize(cgraph);
+ }
+
+ //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
+}
+
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
GGML_ASSERT(ggml_backend_is_metal(backend));
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
- id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
- struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+ struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+ struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
ggml_metal_mem_pool_reset(mem_pool);
- id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
+ if (mem_ranges) {
+ ggml_mem_ranges_reset(mem_ranges);
+ }
+
+ id<MTLComputeCommandEncoder> encoder;
+
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+ if (ctx_dev->use_concurrency) {
+ encoder = [cmd_buf computeCommandEncoderWithDispatchType:MTLDispatchTypeConcurrent];
+ } else {
+ encoder = [cmd_buf computeCommandEncoder];
+ }
int node_start = 0;
int node_end = n_nodes_0;
const bool should_capture = ctx->capture_next_compute;
+ struct ggml_metal_encode_context ctx_enc = {
+ /*.backend =*/ backend,
+ /*.encoder =*/ encoder,
+ /*.mem_pool =*/ mem_pool,
+ /*.mem_ranges =*/ mem_ranges,
+ };
+
for (int idx = node_start; idx < node_end;) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
- const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+ const int res = ggml_metal_encode_node(&ctx_enc, idx, node_end);
if (idx + res > node_end) {
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
"https://github.com/ggml-org/llama.cpp/pull/14849");
// https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
- /* .optimize_graph = */ NULL,
+ /* .optimize_graph = */ ggml_backend_metal_graph_optimize,
};
static ggml_guid_t ggml_backend_metal_guid(void) {