]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : allocate graphs in a context (#2392)
authorslaren <redacted>
Wed, 26 Jul 2023 13:56:53 +0000 (15:56 +0200)
committerGitHub <redacted>
Wed, 26 Jul 2023 13:56:53 +0000 (15:56 +0200)
* ggml : graph allocation in contexts

* allocate work buffer as a ggml_object in ggml_graph_compute_with_ctx

* llama.cpp : allocate graph in the context

* add GGML_PAD

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml.c
ggml.h
llama.cpp

diff --git a/ggml.c b/ggml.c
index 35c56151b8f7c7bbbe9ca5f4b7e4372a78cc173d..33459f263657e5473313a644fefca9ee6e7988f1 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4071,8 +4071,8 @@ bool ggml_is_numa(void) {
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_print_object(const struct ggml_object * obj) {
-    GGML_PRINT(" - ggml_object: offset = %zu, size = %zu, next = %p\n",
-            obj->offs, obj->size, (const void *) obj->next);
+    GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
+            obj->type, obj->offs, obj->size, (const void *) obj->next);
 }
 
 void ggml_print_objects(const struct ggml_context * ctx) {
@@ -4212,7 +4212,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
 }
 
 size_t ggml_tensor_overhead(void) {
-    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE + 16;
+    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;
 }
 
 bool ggml_is_transposed(const struct ggml_tensor * tensor) {
@@ -4383,7 +4383,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         return NULL;
     }
 
-    const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
+    const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
 
     *ctx = (struct ggml_context) {
         /*.mem_size           =*/ mem_size,
@@ -4472,12 +4472,14 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
     struct ggml_object * obj = ctx->objects_begin;
 
     while (obj != NULL) {
-        struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
+        if (obj->type == GGML_OBJECT_TENSOR) {
+            struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs);
 
-        const size_t size = ggml_nbytes(tensor);
+            const size_t size = ggml_nbytes(tensor);
 
-        if (max_size < size) {
-            max_size = size;
+            if (max_size < size) {
+                max_size = size;
+            }
         }
 
         obj = obj->next;
@@ -4509,12 +4511,7 @@ static void ggml_scratch_load(struct ggml_context * ctx) {
 
 ////////////////////////////////////////////////////////////////////////////////
 
-static struct ggml_tensor * ggml_new_tensor_impl(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int    n_dims,
-        const int64_t* ne,
-        void*  data) {
+static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
     // always insert objects at the end of the context's memory pool
     struct ggml_object * obj_cur = ctx->objects_end;
 
@@ -4522,77 +4519,79 @@ static struct ggml_tensor * ggml_new_tensor_impl(
     const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
     const size_t cur_end  = cur_offs + cur_size;
 
-    size_t size_needed = 0;
-
-    if (data == NULL && !ctx->no_alloc) {
-        size_needed += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
-        for (int i = 1; i < n_dims; i++) {
-            size_needed *= ne[i];
-        }
-        // align to GGML_MEM_ALIGN
-        size_needed = ((size_needed + GGML_MEM_ALIGN - 1)/GGML_MEM_ALIGN)*GGML_MEM_ALIGN;
-    }
+    // align to GGML_MEM_ALIGN
+    size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
 
     char * const mem_buffer = ctx->mem_buffer;
     struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
 
-    if (ctx->scratch.data == NULL || data != NULL) {
-        size_needed += GGML_TENSOR_SIZE;
+    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+        GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                __func__, cur_end + size_needed, ctx->mem_size);
+        assert(false);
+        return NULL;
+    }
 
-        if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
-            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
-                    __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
-            assert(false);
-            return NULL;
-        }
+    *obj_new = (struct ggml_object) {
+        .offs = cur_end + GGML_OBJECT_SIZE,
+        .size = size_needed,
+        .next = NULL,
+        .type = type,
+    };
 
-        *obj_new = (struct ggml_object) {
-            .offs = cur_end + GGML_OBJECT_SIZE,
-            .size = size_needed,
-            .next = NULL,
-        };
+    ggml_assert_aligned(mem_buffer + obj_new->offs);
+
+    if (obj_cur != NULL) {
+        obj_cur->next = obj_new;
     } else {
-        if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
-            GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
-                    __func__, ctx->scratch.offs + size_needed, ctx->scratch.size);
-            assert(false);
-            return NULL;
+        // this is the first object in this context
+        ctx->objects_begin = obj_new;
+    }
+
+    ctx->objects_end = obj_new;
+
+    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
+
+    return obj_new;
+}
+
+static struct ggml_tensor * ggml_new_tensor_impl(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int    n_dims,
+        const int64_t* ne,
+        void*  data) {
+
+    size_t data_size = 0;
+
+    if (data == NULL && !ctx->no_alloc) {
+        data_size += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
+        for (int i = 1; i < n_dims; i++) {
+            data_size *= ne[i];
         }
+    }
 
-        if (cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE > ctx->mem_size) {
-            GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
-                    __func__, cur_end + GGML_TENSOR_SIZE + GGML_OBJECT_SIZE, ctx->mem_size);
+    if (ctx->scratch.data != NULL && data == NULL) {
+        // allocate tensor data in the scratch buffer
+        if (ctx->scratch.offs + data_size > ctx->scratch.size) {
+            GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
+                    __func__, ctx->scratch.offs + data_size, ctx->scratch.size);
             assert(false);
             return NULL;
         }
 
         data = (char * const) ctx->scratch.data + ctx->scratch.offs;
 
-        *obj_new = (struct ggml_object) {
-            .offs = cur_end + GGML_OBJECT_SIZE,
-            .size = GGML_TENSOR_SIZE,
-            .next = NULL,
-        };
-
-        //printf("scratch offs = %zu, size_needed = %zu\n", ctx->scratch.offs, size_needed);
+        ctx->scratch.offs += data_size;
 
-        ctx->scratch.offs += size_needed;
+        data_size = 0;
     }
 
-    if (obj_cur != NULL) {
-        obj_cur->next = obj_new;
-    } else {
-        // this is the first object in this context
-        ctx->objects_begin = obj_new;
-    }
-
-    ctx->objects_end = obj_new;
-
-    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
+    struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + data_size);
 
-    struct ggml_tensor * const result = (struct ggml_tensor *)(mem_buffer + obj_new->offs);
+    // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
 
-    ggml_assert_aligned(result);
+    struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
 
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
@@ -5026,9 +5025,11 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam
     char * const mem_buffer = ctx->mem_buffer;
 
     while (obj != NULL) {
-        struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
-        if (strcmp(cur->name, name) == 0) {
-            return cur;
+        if (obj->type == GGML_OBJECT_TENSOR) {
+            struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
+            if (strcmp(cur->name, name) == 0) {
+                return cur;
+            }
         }
 
         obj = obj->next;
@@ -15829,6 +15830,35 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
     return result;
 }
 
+struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE);
+    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
+
+    *cgraph = (struct ggml_cgraph) {
+        /*.n_nodes      =*/ 0,
+        /*.n_leafs      =*/ 0,
+        /*.nodes        =*/ { NULL },
+        /*.grads        =*/ { NULL },
+        /*.leafs        =*/ { NULL },
+        /*.hash_table   =*/ { NULL },
+        /*.perf_runs    =*/ 0,
+        /*.perf_cycles  =*/ 0,
+        /*.perf_time_us =*/ 0,
+    };
+
+    return cgraph;
+}
+
+struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) {
+    struct ggml_cgraph * cgraph = ggml_new_graph(ctx);
+    ggml_build_forward_impl(cgraph, tensor, false);
+    return cgraph;
+}
+
+size_t ggml_graph_overhead(void) {
+    return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN);
+}
+
 //
 // thread data
 //
@@ -16544,10 +16574,9 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
 void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
 
-    struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
-    GGML_ASSERT(buf);
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
 
-    cplan.work_data = buf->data;
+    cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
 
     ggml_graph_compute(cgraph, &cplan);
 }
diff --git a/ggml.h b/ggml.h
index c309f1361c6f6c5494369b19fb611754f46b7dda..9919cce7c263f5f32eeb8052be88593602458876 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 
 #define GGML_UNUSED(x) (void)(x)
 
+#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
 
 #define GGML_ASSERT(x) \
     do { \
@@ -396,6 +397,12 @@ extern "C" {
         GGML_UNARY_OP_SILU,
     };
 
+    enum ggml_object_type {
+        GGML_OBJECT_TENSOR,
+        GGML_OBJECT_GRAPH,
+        GGML_OBJECT_WORK_BUFFER
+    };
+
     // ggml object
     struct ggml_object {
         size_t offs;
@@ -403,7 +410,9 @@ extern "C" {
 
         struct ggml_object * next;
 
-        char padding[8];
+        enum ggml_object_type type;
+
+        char padding[4];
     };
 
     static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
@@ -424,7 +433,7 @@ extern "C" {
         enum ggml_op op;
 
         // op params - allocated as int32_t for alignment
-        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)];
+        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 
         bool is_param;
 
@@ -485,6 +494,8 @@ extern "C" {
         int64_t perf_time_us;
     };
 
+    static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
+
     // scratch buffer
     struct ggml_scratch {
         size_t offs;
@@ -1391,11 +1402,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * tensor);
 
+
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
 
+    // graph allocation in a context
+    GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx);
+    GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API size_t ggml_graph_overhead(void);
+
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
index 30d4b0a6e6c7c8c79e3b1ee3beef3106bd040946..024af99a563ac2f4a75259a0516d7ac46a98c8b0 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1424,7 +1424,7 @@ static bool llama_eval_internal(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    ggml_cgraph gf = {};
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
@@ -1541,8 +1541,8 @@ static bool llama_eval_internal(
                 ggml_set_name(v, "v");
 
                 // important: storing RoPE-ed version of K in the KV cache!
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             struct ggml_tensor * Q =
@@ -1712,21 +1712,21 @@ static bool llama_eval_internal(
     //cur = ggml_soft_max_inplace(ctx0, cur);
 
     // run the computation
-    ggml_build_forward_expand(&gf, cur);
+    ggml_build_forward_expand(gf, cur);
 
     // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
 
 #if GGML_USE_MPI
-    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
+    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
 #endif
 
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
         if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
-            ggml_metal_graph_find_concurrency(lctx.ctx_metal,&gf);
+            ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
         }
         ggml_metal_set_n_cb     (lctx.ctx_metal, n_threads);
-        ggml_metal_graph_compute(lctx.ctx_metal, &gf);
+        ggml_metal_graph_compute(lctx.ctx_metal, gf);
         ggml_metal_get_tensor   (lctx.ctx_metal, cur);
     } else {
         // IMPORTANT:
@@ -1745,34 +1745,34 @@ static bool llama_eval_internal(
             ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
         }
 
-        ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
+        ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
     }
 #else
-    ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
+    ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
 #endif
 
 #if GGML_USE_MPI
-    ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);
+    ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
 #endif
 
     // update kv token count
     lctx.kv_self.n = n_past + N;
 
-    struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1];
+    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
 
     if (cgraph_fname) {
-        ggml_graph_export(&gf, cgraph_fname);
+        ggml_graph_export(gf, cgraph_fname);
     }
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
     // requires GGML_PERF to be defined
-    ggml_graph_print(&gf);
+    ggml_graph_print(gf);
 #endif
 
     // plot the computation graph in dot format (for debugging purposes)
     //if (n_past%100 == 0) {
-    //    ggml_graph_dump_dot(&gf, NULL, "llama.dot");
+    //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}
 
     // extract logits
@@ -3177,7 +3177,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }
 
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
+        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
 
         ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));