]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : concurrently dispatch commands (#2358)
authorShouzheng Liu <redacted>
Tue, 25 Jul 2023 12:00:19 +0000 (08:00 -0400)
committerGitHub <redacted>
Tue, 25 Jul 2023 12:00:19 +0000 (15:00 +0300)
* metal: concurrently dispatch commands

Function `ggml_metal_graph_find_concurrency` will run and write
commands that can be issued concurrently to metal context `concur_list`
array, when `ggml_metal_graph_compute` is called for the first time.

* metal: don't call find_concurrency automatically.

* metal : code style changes

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml-metal.h
ggml-metal.m
llama.cpp

index 928f1705c381cf710468968b3f5a66e19e5a0c47..16f1a0caacfac483cf2c33e0715ca8d1c61ea7ce 100644 (file)
@@ -61,6 +61,13 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// if the graph has been optimized for concurrently dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
index c1db3d16573be8cff5229eceb4da314b36197a3b..74a6bff40411784f2b13ca4c1a7bf607bfc400c4 100644 (file)
@@ -36,6 +36,9 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    int concur_list[GGML_MAX_NODES];
+    int concur_list_len;
+
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction>             function_##name; \
@@ -98,6 +101,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
+    ctx->concur_list_len = 0;
 
     // determine if we can use MPS
     if (MPSSupportsMTLDevice(ctx->device)) {
@@ -217,6 +221,13 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
     ctx->n_cb = n_cb;
 }
 
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+    if (ctx->concur_list_len) {
+        return true;
+    }
+    return false;
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -355,11 +366,98 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
+void ggml_metal_graph_find_concurrency(
+        struct ggml_metal_context * ctx,
+        struct ggml_cgraph * gf) {
+    int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
+    int nodes_unused[GGML_MAX_NODES];
+
+    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    ctx->concur_list_len = 0;
+
+    int n_left = gf->n_nodes;
+    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0;  // at ctx->concur_list, the last layer (level) ends at level_pos
+
+    while (n_left > 0) {
+        // number of nodes at a layer (that can be issued concurrently)
+        int concurrency = 0;
+        for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+            if (nodes_unused[i]) {
+                // if the requirements for gf->nodes[i] are satisfied
+                int exe_flag=1;
+                // scan all srcs
+                for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+                    struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+                    if (src_cur) {
+                        // if is leaf nodes it's satisfied.
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+                        // otherwise this src should be the output from previous nodes.
+                        int is_found = 0;
+                        // scan 2*search_depth back because we inserted barrier.
+                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        }
+                        if (is_found == 0) {exe_flag = 0; break;}
+                    }
+                }
+                if (exe_flag) {
+                    // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+                    // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+                    int64_t data_start = (int64_t) gf->nodes[i]->data;
+                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    for (int j = n_start; j < i; j++) {
+                        if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+                                            && gf->nodes[j]->op != GGML_OP_VIEW \
+                                            && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+                                            && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                            if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+                                ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+                                continue;
+                            } else {
+                                exe_flag = 0;
+                            }
+                        }
+                    }
+                }
+                if (exe_flag) {
+                    ctx->concur_list[level_pos + concurrency] = i;
+                    nodes_unused[i] = 0;
+                    concurrency++;
+                    ctx->concur_list_len++;
+                }
+            }
+        }
+        n_left -= concurrency;
+        // adding a barrier different layer
+        ctx->concur_list[level_pos + concurrency] = -1;
+        ctx->concur_list_len++;
+        // jump all sorted nodes at nodes_bak
+        while (!nodes_unused[n_start]) {n_start++;}
+        level_pos += concurrency + 1;
+    }
+
+    if (ctx->concur_list_len > GGML_MAX_NODES) {
+        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+    }
+}
+
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
+    // if there is ctx->concur_list, dispatch concurrently
+    // else fallback to serial dispatch
+    MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+    const int n_nodes  = has_concur ? ctx->concur_list_len      : gf->n_nodes;
+    edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
@@ -378,7 +476,7 @@ void ggml_metal_graph_compute(
     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+        const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
         dispatch_async(queue, ^{
             size_t offs_src0 = 0;
@@ -389,10 +487,21 @@ void ggml_metal_graph_compute(
 
             id<MTLComputeCommandEncoder> encoder = nil;
 
-            const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+            const int node_start =                                  (cb_idx + 0) * n_nodes_per_cb;
+            const int node_end   = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+            for (int ind = node_start; ind < node_end; ++ind) {
+                const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+                if (i == -1) {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+                        continue;
+                    }
+                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    continue;
+                }
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
@@ -463,7 +572,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ADD:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             if (ggml_nelements(src1) == ne10) {
@@ -484,7 +593,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             if (ggml_nelements(src1) == ne10) {
@@ -505,7 +614,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_SCALE:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const float scale = *(const float *) src1->data;
@@ -524,7 +633,7 @@ void ggml_metal_graph_compute(
                             case GGML_UNARY_OP_SILU:
                                 {
                                     if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoder];
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                     }
 
                                     [encoder setComputePipelineState:ctx->pipeline_silu];
@@ -538,7 +647,7 @@ void ggml_metal_graph_compute(
                             case GGML_UNARY_OP_RELU:
                                 {
                                     if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoder];
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                     }
 
                                     [encoder setComputePipelineState:ctx->pipeline_relu];
@@ -552,7 +661,7 @@ void ggml_metal_graph_compute(
                             case GGML_UNARY_OP_GELU:
                                 {
                                     if (encoder == nil) {
-                                        encoder = [command_buffer computeCommandEncoder];
+                                        encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                     }
 
                                     [encoder setComputePipelineState:ctx->pipeline_gelu];
@@ -572,7 +681,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_SOFT_MAX:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int nth = 32;
@@ -590,7 +699,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_DIAG_MASK_INF:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int n_past = ((int32_t *)(dst->op_params))[0];
@@ -653,7 +762,7 @@ void ggml_metal_graph_compute(
                                 }
                             } else {
                                 if (encoder == nil) {
-                                    encoder = [command_buffer computeCommandEncoder];
+                                    encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                                 }
 
                                 int nth0 = 32;
@@ -780,7 +889,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             switch (src0->type) {
@@ -809,7 +918,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_RMS_NORM:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             float eps;
@@ -832,7 +941,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_NORM:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const float eps = 1e-5f;
@@ -854,7 +963,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ALIBI:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
@@ -897,7 +1006,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ROPE:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int n_past = ((int32_t *) dst->op_params)[0];
@@ -941,7 +1050,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_CONT:
                         {
                             if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
+                                encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
                             }
 
                             const int nth = 32;
index b42b41008654422e213eac660398024b8caee33a..2d737bbcebc2f8ba0f57bebcc975a24af02ca997 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1720,6 +1720,9 @@ static bool llama_eval_internal(
 
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
+        if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
+            ggml_metal_graph_find_concurrency(lctx.ctx_metal,&gf);
+        }
         ggml_metal_set_n_cb     (lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, &gf);
         ggml_metal_get_tensor   (lctx.ctx_metal, cur);