]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add graph tensor allocator (#2411)
authorslaren <redacted>
Sun, 30 Jul 2023 13:58:01 +0000 (15:58 +0200)
committerGitHub <redacted>
Sun, 30 Jul 2023 13:58:01 +0000 (15:58 +0200)
* ggml : add graph tensor allocator

* ggml : don't calculate data pointer of unallocated tensors when creating a view with an offset

* ggml : refactor ggml_view_Nd into ggml_view_tensor_offset

CMakeLists.txt
Makefile
ggml-alloc.c [new file with mode: 0644]
ggml-alloc.h [new file with mode: 0644]
ggml.c
ggml.h
llama.cpp

index 6e1abeaa1205a37223131b4d688905aaed6052b7..57678a302d0977203a62f1637081017240b15f01 100644 (file)
@@ -503,6 +503,8 @@ endif()
 add_library(ggml OBJECT
             ggml.c
             ggml.h
+            ggml-alloc.c
+            ggml-alloc.h
             ${GGML_SOURCES_CUDA}
             ${GGML_SOURCES_OPENCL}
             ${GGML_SOURCES_METAL}
index 3d1fff8e55eee60fc2f7f62e38915f0cd9ec5721..616c2d9b88be9d82000f0287fba43fc246a7afe2 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -329,7 +329,12 @@ $(info )
 ggml.o: ggml.c ggml.h ggml-cuda.h
        $(CC)  $(CFLAGS)   -c $< -o $@
 
-llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
+ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
+       $(CC)  $(CFLAGS)   -c $< -o $@
+
+OBJS += ggml-alloc.o
+
+llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
        $(CXX) $(CXXFLAGS) -c $< -o $@
 
 common.o: examples/common.cpp examples/common.h
diff --git a/ggml-alloc.c b/ggml-alloc.c
new file mode 100644 (file)
index 0000000..5e1be61
--- /dev/null
@@ -0,0 +1,541 @@
+#include "ggml-alloc.h"
+#include "ggml.h"
+#include <assert.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define UNUSED(x) (void)(x)
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+//#define GGML_ALLOCATOR_DEBUG
+
+//#define AT_PRINTF printf
+#define AT_PRINTF(...) ((void)0)
+
+struct hash_node {
+    struct ggml_tensor * t;
+    int n_children;
+    int n_views;
+};
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
+    size_t h = hash(t);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i].t != NULL) {
+        if (hash_table[i].t == t) {
+            return &hash_table[i];
+        }
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            GGML_ASSERT(false);
+        }
+    }
+
+    hash_table[i].t = t;
+    return &hash_table[i];
+}
+
+// TODO: GGML_PAD ?
+static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
+    assert(alignment && !(alignment & (alignment - 1))); // power of 2
+    size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
+    return offset + align;
+}
+
+struct free_block {
+    void * addr;
+    size_t size;
+};
+
+#define MAX_FREE_BLOCKS 128
+
+struct ggml_allocr {
+    void * data;
+    size_t size;
+    size_t alignment;
+    int n_free_blocks;
+    struct free_block free_blocks[MAX_FREE_BLOCKS];
+    struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+    size_t max_size;
+    bool measure;
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    struct ggml_tensor * allocated_tensors[1024];
+#endif
+};
+
+#ifdef GGML_ALLOCATOR_DEBUG
+static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+    for (int i = 0; i < 1024; i++) {
+        if (alloc->allocated_tensors[i] == NULL) {
+            alloc->allocated_tensors[i] = tensor;
+            return;
+        }
+    }
+    GGML_ASSERT(!"out of allocated_tensors");
+}
+static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
+    for (int i = 0; i < 1024; i++) {
+        if (alloc->allocated_tensors[i] == tensor ||
+            (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
+            alloc->allocated_tensors[i] = NULL;
+            return;
+        }
+    }
+    printf("tried to free tensor %s not found\n", tensor->name);
+    GGML_ASSERT(!"tensor not found");
+}
+#endif
+
+
+static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+    return ggml_nbytes(tensor);
+
+    UNUSED(alloc);
+}
+
+void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+    size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
+    size = aligned_offset(NULL, size, alloc->alignment);
+
+    AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
+
+    size_t max_avail = 0;
+
+    // find the best fitting free block
+    int best_fit_block = -1;
+    size_t best_fit_size = SIZE_MAX;
+    for (int i = 0; i < alloc->n_free_blocks; i++) {
+        struct free_block * block = &alloc->free_blocks[i];
+        max_avail = MAX(max_avail, block->size);
+        if (block->size >= size && block->size <= best_fit_size) {
+            best_fit_block = i;
+            best_fit_size = block->size;
+        }
+    }
+
+    AT_PRINTF("block %d\n", best_fit_block);
+
+    if (best_fit_block == -1) {
+        fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
+                __func__, size, max_avail);
+        GGML_ASSERT(!"not enough space in the buffer");
+        return;
+    }
+    struct free_block * block = &alloc->free_blocks[best_fit_block];
+    void * addr = block->addr;
+    block->addr = (char*)block->addr + size;
+    block->size -= size;
+    if (block->size == 0) {
+        // remove block if empty
+        alloc->n_free_blocks--;
+        for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
+            alloc->free_blocks[j] = alloc->free_blocks[j+1];
+        }
+    }
+
+    tensor->data = addr;
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    add_allocated_tensor(alloc, tensor);
+    size_t cur_max = (char*)addr - (char*)alloc->data + size;
+    if (cur_max > alloc->max_size) {
+        printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
+        for (int i = 0; i < 1024; i++) {
+            if (alloc->allocated_tensors[i]) {
+                printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
+            }
+        }
+        printf("\n");
+    }
+#endif
+
+    alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
+}
+
+// this is a very naive implementation, but for our case the number of free blocks should be very small
+static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+    void * ptr = tensor->data;
+
+    if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
+        // the tensor was not allocated in this buffer
+        // this can happen because the graph allocator will try to free weights and other tensors from different buffers
+        // the easiest way to deal with this is just to ignore it
+        return;
+    }
+
+    size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
+    size = aligned_offset(NULL, size, alloc->alignment);
+    AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    remove_allocated_tensor(alloc, tensor);
+#endif
+
+    // see if we can merge with an existing block
+    for (int i = 0; i < alloc->n_free_blocks; i++) {
+        struct free_block * block = &alloc->free_blocks[i];
+        // check if ptr is at the end of the block
+        if ((char*)block->addr + block->size == ptr) {
+            block->size += size;
+            // check if we can merge with the next block
+            if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) {
+                block->size += alloc->free_blocks[i+1].size;
+                alloc->n_free_blocks--;
+                for (int j = i+1; j < alloc->n_free_blocks; j++) {
+                    alloc->free_blocks[j] = alloc->free_blocks[j+1];
+                }
+            }
+            return;
+        }
+        // check if ptr is at the beginning of the block
+        if ((char*)ptr + size == block->addr) {
+            block->addr = ptr;
+            block->size += size;
+            // check if we can merge with the previous block
+            if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) {
+                alloc->free_blocks[i-1].size += block->size;
+                alloc->n_free_blocks--;
+                for (int j = i; j < alloc->n_free_blocks; j++) {
+                    alloc->free_blocks[j] = alloc->free_blocks[j+1];
+                }
+            }
+            return;
+        }
+    }
+    // otherwise, add a new block
+    GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
+    // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
+    int insert_pos = 0;
+    while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) {
+        insert_pos++;
+    }
+    // shift all blocks from insert_pos onward to make room for the new block
+    for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
+        alloc->free_blocks[i] = alloc->free_blocks[i-1];
+    }
+    // insert the new block
+    alloc->free_blocks[insert_pos].addr = ptr;
+    alloc->free_blocks[insert_pos].size = size;
+    alloc->n_free_blocks++;
+}
+
+void ggml_allocr_reset(struct ggml_allocr * alloc) {
+    alloc->n_free_blocks = 1;
+    size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
+    alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
+    alloc->free_blocks[0].size = alloc->size - align_offset;
+}
+
+struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
+    struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
+
+    *alloc = (struct ggml_allocr){
+        /*.data          = */ data,
+        /*.size          = */ size,
+        /*.alignment     = */ alignment,
+        /*.n_free_blocks = */ 0,
+        /*.free_blocks   = */ {{0}},
+        /*.hash_table    = */ {{0}},
+        /*.max_size      = */ 0,
+        /*.measure       = */ false,
+#ifdef GGML_ALLOCATOR_DEBUG
+        /*.allocated_tensors = */ = {0},
+#endif
+    };
+
+    ggml_allocr_reset(alloc);
+
+    return alloc;
+}
+
+// address and size of the buffer when measuring
+// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
+static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
+static const size_t MEASURE_MAX_SIZE  = 1ULL<<40; // 1 TB
+
+struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
+    struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
+
+    *alloc = (struct ggml_allocr){
+        /*.data          = */ MEASURE_BASE_ADDR,
+        /*.size          = */ MEASURE_MAX_SIZE,
+        /*.alignment     = */ alignment,
+        /*.n_free_blocks = */ 0,
+        /*.free_blocks   = */ {{0}},
+        /*.hash_table    = */ {{0}},
+        /*.max_size      = */ 0,
+        /*.measure       = */ true,
+#ifdef GGML_ALLOCATOR_DEBUG
+        /*.allocated_tensors = */ = {0},
+#endif
+    };
+
+    ggml_allocr_reset(alloc);
+
+    return alloc;
+}
+
+void ggml_allocr_free(struct ggml_allocr * alloc) {
+    free(alloc);
+}
+
+bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
+    return alloc->measure;
+}
+
+//////////// compute graph allocator
+
+static bool ggml_is_view(struct ggml_tensor * t) {
+    return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
+           t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
+}
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
+    switch (t->op) {
+        case GGML_OP_PERMUTE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_VIEW:
+            return t->src[0];
+        case GGML_OP_CPY:
+            return t->src[1];
+        default:
+            return NULL;
+    }
+}
+
+static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
+    struct ggml_tensor * parent = t;
+    do {
+        parent = get_view_parent(parent);
+    } while (ggml_is_view(parent));
+    return parent;
+}
+
+static bool ggml_op_can_inplace(enum ggml_op op) {
+    switch (op) {
+        case GGML_OP_SCALE:
+        case GGML_OP_DIAG_MASK_ZERO:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_ACC:
+        case GGML_OP_SUB:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_LOG:
+        case GGML_OP_UNARY:
+        case GGML_OP_ROPE:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_SET:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_CONT:
+            return true;
+
+        default:
+            return false;
+    }
+}
+
+static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
+    struct hash_node * ht = alloc->hash_table;
+    if (node->data == NULL) {
+        if (ggml_is_view(node)) {
+            size_t offset;
+            switch(node->op) {
+                case GGML_OP_VIEW:
+                    memcpy(&offset, node->op_params, sizeof(size_t));
+                    node->data = (char *) node->src[0]->data + offset;
+                    break;
+                case GGML_OP_PERMUTE:
+                case GGML_OP_RESHAPE:
+                case GGML_OP_TRANSPOSE:
+                    node->data = node->src[0]->data;
+                    break;
+                case GGML_OP_CPY:
+                    node->data = node->src[1]->data;
+                    break;
+                default:
+                    GGML_ASSERT(!"unknown view op");
+                    break;
+            }
+        } else {
+            // see if we can reuse a parent's buffer (inplace)
+            if (ggml_op_can_inplace(node->op)) {
+                for (int i = 0; i < GGML_MAX_SRC; i++) {
+                    struct ggml_tensor * parent = node->src[i];
+                    if (parent == NULL) {
+                        break;
+                    }
+                    struct hash_node * p_hn = hash_get(ht, parent);
+                    if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
+                        if (ggml_is_view(parent)) {
+                            struct ggml_tensor * view_src = get_view_source(parent);
+                            struct hash_node * view_src_hn = hash_get(ht, view_src);
+                            if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
+                                // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
+                                // the parent's data that it will need later (same layout requirement). the problem is that then
+                                // we cannot free the tensor because the original address of the allocation is lost.
+                                // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
+                                // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
+                                AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
+                                node->data = parent->data;
+                                return;
+                            }
+                        }
+                        else {
+                            AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
+                            node->data = parent->data;
+                        }
+                        return;
+                    }
+                }
+            }
+            ggml_allocr_alloc(alloc, node);
+        }
+    }
+}
+
+static size_t ggml_allocator_alloc_graph_tensors_n(
+    struct ggml_allocr * alloc,
+    struct ggml_cgraph ** graphs, int n_graphs,
+    struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
+
+    // reset hash table
+    struct hash_node * ht = alloc->hash_table;
+    memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
+
+    // count number of children and views
+    for (int g = 0; g < n_graphs; g++) {
+        struct ggml_cgraph * gf = graphs[g];
+        for (int i = 0; i < gf->n_nodes; i++) {
+            struct ggml_tensor * node = gf->nodes[i];
+
+            if (ggml_is_view(node)) {
+                struct ggml_tensor * view_src = get_view_source(node);
+                hash_get(ht, view_src)->n_views += 1;
+            }
+
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * parent = node->src[j];
+                if (parent == NULL) {
+                    break;
+                }
+                hash_get(ht, parent)->n_children += 1;
+            }
+        }
+    }
+
+    // allocate tensors
+    for (int g = 0; g < n_graphs; g++) {
+        struct ggml_cgraph * gf = graphs[g];
+        AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
+        // graph inputs are allocated first to ensure that they are not overwritten by each other
+        if (inputs != NULL && inputs[g] != NULL) {
+            for (int i = 0; inputs[g][i] != NULL; i++) {
+                struct ggml_tensor * input = inputs[g][i];
+                AT_PRINTF("input: %s\n", input->name);
+                allocate_node(alloc, input);
+            }
+        }
+        for (int i = 0; i < gf->n_nodes; i++) {
+            struct ggml_tensor * node = gf->nodes[i];
+
+            // allocate parents (leafs)
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * parent = node->src[j];
+                if (parent == NULL) {
+                    break;
+                }
+                allocate_node(alloc, parent);
+            }
+
+            // allocate node
+            allocate_node(alloc, node);
+
+            AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * parent = node->src[j];
+                if (parent == NULL) {
+                    break;
+                }
+                AT_PRINTF("%s", parent->name);
+                if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+                    AT_PRINTF(", ");
+                }
+            }
+            AT_PRINTF("\n");
+
+            // update parents
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * parent = node->src[j];
+                if (parent == NULL) {
+                    break;
+                }
+                struct hash_node * p_hn = hash_get(ht, parent);
+                p_hn->n_children -= 1;
+
+                //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
+
+                if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+                    if (ggml_is_view(parent)) {
+                        struct ggml_tensor * view_src = get_view_source(parent);
+                        struct hash_node * view_src_hn = hash_get(ht, view_src);
+                        view_src_hn->n_views -= 1;
+                        AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
+                        if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
+                            ggml_allocator_free_tensor(alloc, view_src);
+                        }
+                    }
+                    else {
+                        if (parent->data != node->data) {
+                            ggml_allocator_free_tensor(alloc, parent);
+                        }
+                    }
+                }
+            }
+            AT_PRINTF("\n");
+        }
+        // free graph outputs here that wouldn't be freed otherwise because they have no children
+        if (outputs != NULL && outputs[g] != NULL) {
+            for (int i = 0; outputs[g][i] != NULL; i++) {
+                struct ggml_tensor * output = outputs[g][i];
+                AT_PRINTF("output: %s\n", output->name);
+                ggml_allocator_free_tensor(alloc, output);
+            }
+        }
+    }
+
+    return alloc->max_size;
+}
+
+size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
+    return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
+}
diff --git a/ggml-alloc.h b/ggml-alloc.h
new file mode 100644 (file)
index 0000000..a5ec8f8
--- /dev/null
@@ -0,0 +1,22 @@
+#pragma once
+
+#include "ggml.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+
+GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
+GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
+
+GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);
+GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
+GGML_API void   ggml_allocr_reset(struct ggml_allocr * alloc);
+GGML_API void   ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
+GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
+
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ggml.c b/ggml.c
index b77f9926754ed1cd57369ab1de48db3bb2cda234..fa0f98aa09df202e63dd89cb8ccf3538f3de38eb 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4557,10 +4557,12 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml
 
 static struct ggml_tensor * ggml_new_tensor_impl(
         struct ggml_context * ctx,
-        enum   ggml_type type,
-        int    n_dims,
-        const int64_t* ne,
-        void*  data) {
+        enum   ggml_type      type,
+        int                   n_dims,
+        const int64_t       * ne,
+        void                * data) {
+
+    assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
 
     size_t data_size = 0;
 
@@ -4648,22 +4650,22 @@ static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int3
 
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
-        enum   ggml_type type,
-        int    n_dims,
-        const int64_t * ne) {
+        enum   ggml_type      type,
+        int                   n_dims,
+        const int64_t       * ne) {
     return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
 }
 
 struct ggml_tensor * ggml_new_tensor_1d(
         struct ggml_context * ctx,
-        enum   ggml_type type,
+        enum   ggml_type      type,
         int64_t ne0) {
     return ggml_new_tensor(ctx, type, 1, &ne0);
 }
 
 struct ggml_tensor * ggml_new_tensor_2d(
         struct ggml_context * ctx,
-        enum   ggml_type type,
+        enum   ggml_type      type,
         int64_t ne0,
         int64_t ne1) {
     const int64_t ne[2] = { ne0, ne1 };
@@ -4672,7 +4674,7 @@ struct ggml_tensor * ggml_new_tensor_2d(
 
 struct ggml_tensor * ggml_new_tensor_3d(
         struct ggml_context * ctx,
-        enum   ggml_type type,
+        enum   ggml_type      type,
         int64_t ne0,
         int64_t ne1,
         int64_t ne2) {
@@ -6238,6 +6240,27 @@ struct ggml_tensor * ggml_reshape_4d(
 
 // ggml_view_1d
 
+static struct ggml_tensor * ggml_view_tensor_offset(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_dims,
+        const int64_t       * ne,
+        size_t                offset) {
+    // don't calculate an offset from an unallocated tensor
+    void * data = NULL;
+    if (a->data != NULL) {
+        data = (char *) a->data + offset;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, data);
+
+    ggml_format_name(result, "%s (view)", a->name);
+
+    ggml_set_op_params(result, &offset, sizeof(offset));
+
+    return result;
+}
+
 struct ggml_tensor * ggml_view_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -6250,10 +6273,7 @@ struct ggml_tensor * ggml_view_1d(
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset);
-    ggml_format_name(result, "%s (view)", a->name);
-
-    ggml_set_op_params(result, &offset, sizeof(offset));
+    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 1, &ne0, offset);
 
     result->op   = GGML_OP_VIEW;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6280,10 +6300,7 @@ struct ggml_tensor * ggml_view_2d(
 
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset);
-    ggml_format_name(result, "%s (view)", a->name);
-
-    ggml_set_op_params(result, &offset, sizeof(offset));
+    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 2, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
@@ -6316,10 +6333,7 @@ struct ggml_tensor * ggml_view_3d(
 
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset);
-    ggml_format_name(result, "%s (view)", a->name);
-
-    ggml_set_op_params(result, &offset, sizeof(offset));
+    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 3, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -6354,10 +6368,7 @@ struct ggml_tensor * ggml_view_4d(
 
     const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 };
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset);
-    ggml_format_name(result, "%s (view)", a->name);
-
-    ggml_set_op_params(result, &offset, sizeof(offset));
+    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 4, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
@@ -6741,6 +6752,18 @@ struct ggml_tensor * ggml_rope_inplace(
     return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
 }
 
+struct ggml_tensor * ggml_rope_custom(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx,
+        float                 freq_base,
+        float                 freq_scale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
+}
+
 struct ggml_tensor * ggml_rope_custom_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
diff --git a/ggml.h b/ggml.h
index 9919cce7c263f5f32eeb8052be88593602458876..aba92480c833c8394cfae8c4affde523d23401e6 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1170,7 +1170,18 @@ extern "C" {
             int                   mode,
             int                   n_ctx);
 
-    // custom RoPE, in-place, returns view(a)
+    // custom RoPE
+    GGML_API struct ggml_tensor * ggml_rope_custom(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale);
+
+    // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index a35c690ea9f68ccbe97c0b78652300fe6653c335..6f381f30f8fe1a33df0a8b4c90008143d1a991cd 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL)
+#include "ggml-alloc.h"
+#define LLAMA_USE_ALLOCATOR
+#else
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
+#endif
+
 
 // available llama models
 enum e_model {
@@ -327,13 +333,22 @@ struct llama_model {
 
 struct llama_context {
     llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
-#ifdef GGML_USE_METAL
     ~llama_context() {
+        if (model_owner) {
+            delete &model;
+        }
+#ifdef GGML_USE_METAL
         if (ctx_metal) {
             ggml_metal_free(ctx_metal);
         }
-    }
 #endif
+#ifdef LLAMA_USE_ALLOCATOR
+        if (alloc) {
+            ggml_allocr_free(alloc);
+        }
+#endif
+    }
+
     std::mt19937 rng;
 
     bool has_evaluated_once = false;
@@ -371,7 +386,17 @@ struct llama_context {
     // memory buffers used to evaluate the model
     // TODO: move in llama_state
     llama_ctx_buffer buf_compute;
+
+#ifdef LLAMA_USE_ALLOCATOR
+    llama_ctx_buffer buf_alloc;
+    ggml_allocr * alloc = NULL;
+#endif
+
+#ifdef LLAMA_USE_SCRATCH
     llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
+    int    buf_last = 0;
+    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
+#endif
 
 #ifdef GGML_USE_METAL
     ggml_metal_context * ctx_metal = NULL;
@@ -381,9 +406,6 @@ struct llama_context {
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
 
-    int    buf_last = 0;
-    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
-
     void use_buf(struct ggml_context * ctx, int i) {
 #if defined(LLAMA_USE_SCRATCH)
         size_t last_size = 0;
@@ -1230,12 +1252,16 @@ static void llama_model_load_internal(
         const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
 
         // this is the total memory required to run the inference
-        const size_t mem_required =
+        size_t mem_required =
             ctx_size +
-            mmapped_size - vram_weights + // weights in VRAM not in memory
+            mmapped_size - vram_weights; // weights in VRAM not in memory
+
+#ifndef LLAMA_USE_ALLOCATOR
+        mem_required +=
             MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
             MEM_REQ_SCRATCH1().at(model.type) +
             MEM_REQ_EVAL().at(model.type);
+#endif
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
@@ -1360,32 +1386,15 @@ static bool llama_model_load(
     }
 }
 
-// evaluate the transformer
-//
-//   - lctx:      llama context
-//   - tokens:    new batch of tokens to process
-//   - embd       embeddings input
-//   - n_tokens   number of tokens
-//   - n_past:    the context size so far
-//   - n_threads: number of threads to use
-//
-static bool llama_eval_internal(
+static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_token * tokens,
            const float * embd,
                    int   n_tokens,
-                   int   n_past,
-                   int   n_threads,
-            const char * cgraph_fname) {
+                   int   n_past) {
 
     LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
 
-#ifdef GGML_USE_MPI
-    ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
-#endif
-
-    const int64_t t_start_us = ggml_time_us();
-
     const int N = n_tokens;
 
     const auto & model   = lctx.model;
@@ -1401,10 +1410,8 @@ static bool llama_eval_internal(
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_vocab     = hparams.n_vocab;
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
-
     LLAMA_ASSERT(n_embd_head == hparams.n_rot);
 
     const float freq_base  = hparams.rope_freq_base;
@@ -1416,26 +1423,35 @@ static bool llama_eval_internal(
     auto & mem_per_token = lctx.mem_per_token;
     auto & buf_compute   = lctx.buf_compute;
 
+
     struct ggml_init_params params = {
         /*.mem_size   =*/ buf_compute.size,
         /*.mem_buffer =*/ buf_compute.addr,
         /*.no_alloc   =*/ false,
     };
 
+#ifdef LLAMA_USE_ALLOCATOR
+    params.no_alloc = true;
+#endif
+
     struct ggml_context * ctx0 = ggml_init(params);
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    // for big prompts, if BLAS is enabled, it is better to use only one thread
-    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-    n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
-
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
     if (tokens) {
         struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+
+#ifdef LLAMA_USE_ALLOCATOR
+        ggml_allocr_alloc(lctx.alloc, inp_tokens);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+        }
+#else
         memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+#endif
         ggml_set_name(inp_tokens, "inp_tokens");
 
         inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
@@ -1445,7 +1461,15 @@ static bool llama_eval_internal(
 #endif
 
         inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+
+#ifdef LLAMA_USE_ALLOCATOR
+        ggml_allocr_alloc(lctx.alloc, inpL);
+        if (!ggml_allocr_is_measure(lctx.alloc)) {
+            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+        }
+#else
         memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+#endif
     }
 
     const int i_gpu_start = n_layer - n_gpu_layers;
@@ -1472,6 +1496,17 @@ static bool llama_eval_internal(
     }
 #endif // GGML_USE_CUBLAS
 
+    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+#ifdef LLAMA_USE_ALLOCATOR
+    ggml_allocr_alloc(lctx.alloc, KQ_scale);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+    }
+#else
+    ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+#endif
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+
     for (int il = 0; il < n_layer; ++il) {
         ggml_format_name(inpL, "layer_inp_%d", il);
 
@@ -1567,9 +1602,6 @@ static bool llama_eval_internal(
             ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd_head)
-            struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
-            ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
-
             // KQ_scaled shape [n_past + N, N, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
             offload_func_kq(KQ_scaled);
@@ -1685,9 +1717,6 @@ static bool llama_eval_internal(
 
     lctx.use_buf(ctx0, 0);
 
-    // used at the end to optionally extract the embeddings
-    struct ggml_tensor * embeddings = NULL;
-
     // norm
     {
         cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
@@ -1698,8 +1727,6 @@ static bool llama_eval_internal(
         cur = ggml_mul(ctx0, cur, model.norm);
         // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
         ggml_set_name(cur, "result_norm");
-
-        embeddings = cur;
     }
 
     // lm_head
@@ -1711,12 +1738,82 @@ static bool llama_eval_internal(
     // logits -> probs
     //cur = ggml_soft_max_inplace(ctx0, cur);
 
-    // run the computation
     ggml_build_forward_expand(gf, cur);
 
-    // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+
+#if 0
+    printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
+            ggml_used_mem(ctx0)/1024.0/1024.0,
+            lctx.get_buf_max_mem(0)/1024.0/1024.0,
+            lctx.get_buf_max_mem(1)/1024.0/1024.0,
+            lctx.work_buffer.size()/1024.0/1024.0,
+            n_past, N);
+#endif
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// evaluate the transformer
+//
+//   - lctx:      llama context
+//   - tokens:    new batch of tokens to process
+//   - embd       embeddings input
+//   - n_tokens   number of tokens
+//   - n_past:    the context size so far
+//   - n_threads: number of threads to use
+//
+static bool llama_eval_internal(
+         llama_context & lctx,
+     const llama_token * tokens,
+           const float * embd,
+                   int   n_tokens,
+                   int   n_past,
+                   int   n_threads,
+            const char * cgraph_fname) {
+
+    LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
+
+    const int64_t t_start_us = ggml_time_us();
+
+#ifdef GGML_USE_MPI
+    ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
+#endif
+
+    const int N = n_tokens;
+
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+
+    const auto & kv_self = lctx.kv_self;
+
+    LLAMA_ASSERT(!!kv_self.ctx);
+
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_vocab     = hparams.n_vocab;
+
+#ifdef LLAMA_USE_ALLOCATOR
+    ggml_allocr_reset(lctx.alloc);
+#endif
+
+    ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
+
+#ifdef LLAMA_USE_ALLOCATOR
+    ggml_allocr_alloc_graph(lctx.alloc, gf);
+#endif
+
+    // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
+    // for big prompts, if BLAS is enabled, it is better to use only one thread
+    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
+    n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
 
 #if GGML_USE_MPI
+    const int64_t n_layer = hparams.n_layer;
     ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
 #endif
 
@@ -1760,6 +1857,10 @@ static bool llama_eval_internal(
     lctx.kv_self.n = n_past + N;
 
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+
+    LLAMA_ASSERT(strcmp(res->name, "result_output") == 0);
+    LLAMA_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
 
     if (cgraph_fname) {
         ggml_graph_export(gf, cgraph_fname);
@@ -1798,21 +1899,6 @@ static bool llama_eval_internal(
         memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
     }
 
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
-    }
-
-#if 0
-    printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
-            ggml_used_mem(ctx0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0,
-            lctx.work_buffer.size()/1024.0/1024.0,
-            n_past, N);
-#endif
-
-    ggml_free(ctx0);
-
     // measure the performance only for the single-token evals
     if (N == 1) {
         lctx.t_eval_us += ggml_time_us() - t_start_us;
@@ -3180,10 +3266,47 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }
 
+#ifdef LLAMA_USE_ALLOCATOR
+        {
+            static const size_t tensor_alignment = 32;
+            // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
+            ctx->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+            // create measure allocator
+            ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
+
+            // build worst-case graph
+            int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
+            int n_past = hparams.n_ctx - n_tokens;
+            llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
+
+            // measure memory requirements for the graph
+            size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
+
+            fprintf(stderr, "%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
+
+            // debug - for comparison with scratch buffer
+            //size_t prev_req =
+            //    MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
+            //    MEM_REQ_SCRATCH1().at(ctx->model.type) +
+            //    MEM_REQ_EVAL().at(ctx->model.type);
+            //fprintf(stderr, "%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
+
+            // recreate allocator with exact memory requirements
+            ggml_allocr_free(ctx->alloc);
+
+            ctx->buf_alloc.resize(alloc_size);
+            ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment);
+        }
+#else
         ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
+#endif
 
+#ifdef LLAMA_USE_SCRATCH
         ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
+#endif
     }
 
 #ifdef GGML_USE_METAL
@@ -3253,9 +3376,6 @@ struct llama_context * llama_init_from_file(
 }
 
 void llama_free(struct llama_context * ctx) {
-    if (ctx->model_owner) {
-        delete &ctx->model;
-    }
     delete ctx;
 }