]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fix out-of-bounds access + inc concurrency nodes (#2416)
authorGeorgi Gerganov <redacted>
Mon, 7 Aug 2023 07:52:57 +0000 (10:52 +0300)
committerGitHub <redacted>
Mon, 7 Aug 2023 07:52:57 +0000 (10:52 +0300)
* metal : fix out-of-bounds access + style changes

* metal : increase concurrency nodes to 2*GGML_MAX_NODES

ggml-metal.m

index 3f098d39677a0ee49c0952243c82abdc63f5ad0f..b47a98e214b613fb0f022f8606047e6771201ddd 100644 (file)
@@ -7,6 +7,11 @@
 #import <Metal/Metal.h>
 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
 #else
@@ -15,6 +20,8 @@
 
 #define UNUSED(x) (void)(x)
 
+#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+
 struct ggml_metal_buffer {
     const char * name;
 
@@ -36,7 +43,7 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
-    int concur_list[GGML_MAX_NODES];
+    int concur_list[GGML_MAX_CONCUR];
     int concur_list_len;
 
     // custom kernels
@@ -370,15 +377,15 @@ void ggml_metal_graph_find_concurrency(
         struct ggml_metal_context * ctx,
         struct ggml_cgraph * gf) {
     int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
-    int nodes_unused[GGML_MAX_NODES];
+    int nodes_unused[GGML_MAX_CONCUR];
 
-    for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
-    for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+    for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
+    for (int i = 0; i < gf->n_nodes;     i++) { nodes_unused[i]     = 1; }
     ctx->concur_list_len = 0;
 
-    int n_left = gf->n_nodes;
-    int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
-    int level_pos = 0;  // at ctx->concur_list, the last layer (level) ends at level_pos
+    int n_left    = gf->n_nodes;
+    int n_start   = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+    int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
 
     while (n_left > 0) {
         // number of nodes at a layer (that can be issued concurrently)
@@ -386,28 +393,40 @@ void ggml_metal_graph_find_concurrency(
         for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
             if (nodes_unused[i]) {
                 // if the requirements for gf->nodes[i] are satisfied
-                int exe_flag=1;
+                int exe_flag = 1;
+
                 // scan all srcs
                 for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
                     struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
                     if (src_cur) {
                         // if is leaf nodes it's satisfied.
-                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+                        // TODO: ggml_is_leaf()
+                        if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
+                            continue;
+                        }
 
                         // otherwise this src should be the output from previous nodes.
                         int is_found = 0;
+
                         // scan 2*search_depth back because we inserted barrier.
-                        for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
-                            if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+                        //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+                        for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
+                            if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
+                                is_found = 1;
+                                break;
+                            }
+                        }
+                        if (is_found == 0) {
+                            exe_flag = 0;
+                            break;
                         }
-                        if (is_found == 0) {exe_flag = 0; break;}
                     }
                 }
                 if (exe_flag) {
                     // check if nodes[i]'s data will be overwritten by a node before nodes[i].
                     // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
                     int64_t data_start = (int64_t) gf->nodes[i]->data;
-                    int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+                    int64_t length     = (int64_t) ggml_nbytes(gf->nodes[i]);
                     for (int j = n_start; j < i; j++) {
                         if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
                                             && gf->nodes[j]->op != GGML_OP_VIEW \
@@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency(
                             if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
                                 ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
                                 continue;
-                            } else {
-                                exe_flag = 0;
                             }
+
+                            exe_flag = 0;
                         }
                     }
                 }
@@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency(
         ctx->concur_list[level_pos + concurrency] = -1;
         ctx->concur_list_len++;
         // jump all sorted nodes at nodes_bak
-        while (!nodes_unused[n_start]) {n_start++;}
+        while (!nodes_unused[n_start]) {
+            n_start++;
+        }
         level_pos += concurrency + 1;
     }
 
-    if (ctx->concur_list_len > GGML_MAX_NODES) {
+    if (ctx->concur_list_len > GGML_MAX_CONCUR) {
         fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
@@ -453,7 +474,7 @@ void ggml_metal_graph_compute(
     // else fallback to serial dispatch
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
 
-    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+    const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
 
     const int n_nodes  = has_concur ? ctx->concur_list_len      : gf->n_nodes;
     edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;