state->opt = new struct ggml_opt_context;
state->opt->ctx = NULL;
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
+ state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
state->opt->loss_after = 0.0f;
return state;
#include "ggml.h"
#include "llama.h"
+#define LLAMA_TRAIN_MAX_NODES 16384
+
typedef std::string mt19937_state;
struct train_state {
struct ggml_tensor * m11xm2 = ggml_mul_mat(ctx, m11, m2);
// printf("Creating compute graph\n");
- struct ggml_cgraph gf = ggml_build_forward(m11xm2);
+ struct ggml_cgraph * gf = ggml_new_graph(ctx);
+ ggml_build_forward_expand(gf, m11xm2);
printf("n_threads=%i\n", benchmark_params.n_threads);
std::vector<uint8_t> work_buffer;
- ggml_graph_compute_helper(work_buffer, &gf, benchmark_params.n_threads);
+ ggml_graph_compute_helper(work_buffer, gf, benchmark_params.n_threads);
- TENSOR_DUMP(gf.nodes[0]);
+ TENSOR_DUMP(gf->nodes[0]);
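For reference, graphs are no longer stack values but objects allocated inside the ggml context. A minimal sketch of the new pattern outside this benchmark (ctx, a and b are hypothetical, and ggml_graph_compute_with_ctx stands in for the file's local compute helper):

    struct ggml_tensor * c  = ggml_mul_mat(ctx, a, b);        // result node
    struct ggml_cgraph * gf = ggml_new_graph(ctx);            // default capacity: GGML_DEFAULT_GRAPH_SIZE nodes
    ggml_build_forward_expand(gf, c);                         // replaces the removed ggml_build_forward()
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);
    float * out = (float *) ggml_get_data(gf->nodes[gf->n_nodes - 1]);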
printf("\n------ Test 2 - Matrix Mult via %s code\n", ggml_type_name(qtype));
struct ggml_tensor * q31 = ggml_mul_mat(ctx, q11, m2);
// printf("Creating compute graph\n");
- struct ggml_cgraph gf31 = ggml_build_forward(q31);
+ struct ggml_cgraph * gf31 = ggml_new_graph(ctx);
+ ggml_build_forward_expand(gf31, q31);
// Set up a second graph computation to make sure we override the CPU cache lines
// printf("Creating new tensor q12 & Running quantize\n");
struct ggml_tensor * q32 = ggml_mul_mat(ctx, q12, m2);
//printf("Creating compute graph\n");
- struct ggml_cgraph gf32 = ggml_build_forward(q32);
+ struct ggml_cgraph * gf32 = ggml_new_graph(ctx);
+ ggml_build_forward_expand(gf32, q32);
printf("n_threads=%i\n", benchmark_params.n_threads);
const int dimx = sizex;
// Let's use the F32 result from above as a reference for the quantized multiplication
- float sum_of_F32_reference = tensor_sum_elements(gf.nodes[0]);
+ float sum_of_F32_reference = tensor_sum_elements(gf->nodes[0]);
printf("Iteration;NThreads; SizeX; SizeY; SizeZ; Required_FLOPS; Elapsed_u_Seconds; gigaFLOPS\n");
printf("=====================================================================================\n");
long long int start = ggml_time_us();
//printf("Running ggml_graph_compute\n");
- ggml_graph_compute_helper(work_buffer, &gf31, benchmark_params.n_threads);
+ ggml_graph_compute_helper(work_buffer, gf31, benchmark_params.n_threads);
long long int stop = ggml_time_us();
long long int usec = stop-start;
// Check that the matrix multiplication result is in the right ballpark
// We cannot use the exact value from the F32 multiplication because the quantization will be slightly different
- float sum_of_Q4_result = tensor_sum_elements(gf31.nodes[0]);
+ float sum_of_Q4_result = tensor_sum_elements(gf31->nodes[0]);
float delta = std::abs(sum_of_Q4_result - sum_of_F32_reference);
float allowed_delta = (sum_of_F32_reference) / 1000 / 1000; // Let's accept an epsilon of 10^-6
}
// Running a different graph computation to make sure we override the CPU cache lines
- ggml_graph_compute_helper(work_buffer, &gf32, benchmark_params.n_threads);
+ ggml_graph_compute_helper(work_buffer, gf32, benchmark_params.n_threads);
}
printf("\n");
printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations));
}
struct ggml_init_params params_ggml;
- params_ggml.mem_size = ggml_tensor_overhead() * GGML_MAX_NODES;
+ params_ggml.mem_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE;
params_ggml.mem_buffer = NULL;
params_ggml.no_alloc = true;
result->ctx = ggml_init(params_ggml);
float scaling = lora->info.scale * (float)lora->lora_alpha / (float)lora->lora_r;
struct ggml_init_params params;
- params.mem_size = GGML_OBJECT_SIZE + GGML_GRAPH_SIZE + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
+ params.mem_size = GGML_OBJECT_SIZE + ggml_graph_overhead() + ggml_tensor_overhead()*4 + GGML_MEM_ALIGN*5;
params.mem_buffer = NULL;
params.no_alloc = true;
struct ggml_context * ctx = NULL;
if (enable_checkpointing) {
ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
} else {
- *gb = *gf;
+ ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx, gf, gb, true);
}
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
+ opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
opt->params.n_threads = params.common.n_threads;
opt->params.past = params.common.opt_past;
opt->params.delta = params.common.opt_delta;
ggml_allocr_free(alloc);
// context for compute tensors without their data
- size_t estimated_compute_size_wo_data = (
- ggml_tensor_overhead()*GGML_MAX_NODES*2
- + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
- params.common.use_checkpointing ? 3 : 2
- )
+ const size_t estimated_compute_size_wo_data = (
+ 2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+ (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
);
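In words: enough room for the tensor metadata of two graphs of up to LLAMA_TRAIN_MAX_NODES nodes, plus two (three with checkpointing) graph objects of that custom size. A hedged standalone sketch of the same overhead helpers sizing a single graph context (not taken from this patch):

    struct ggml_init_params gp = {
        /* .mem_size   = */ LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead()
                            + GGML_OBJECT_SIZE + ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, /*grads =*/ true),
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,   // tensor data lives elsewhere (ggml-alloc); only metadata goes here
    };
    struct ggml_context * ctx = ggml_init(gp);
    struct ggml_cgraph  * gf  = ggml_new_graph_custom(ctx, LLAMA_TRAIN_MAX_NODES, /*grads =*/ true);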
struct ggml_init_params ctx_compute_params = {
estimated_compute_size_wo_data, // mem_size
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new_measure(tensor_alignment);
- gf = ggml_new_graph(ctx_compute);
+ gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
- gb = ggml_new_graph(ctx_compute);
+ gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
- ? ggml_new_graph(ctx_compute)
+ ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_lora_finetune_graphs(
&model, &lora, alloc, ctx_compute,
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
- gf = ggml_new_graph(ctx_compute);
+ gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
- gb = ggml_new_graph(ctx_compute);
+ gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
- ? ggml_new_graph(ctx_compute)
+ ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_lora_finetune_graphs(
&model, &lora, alloc, ctx_compute,
// measure mem requirement and allocate
{
static const size_t tensor_alignment = 32;
- new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+ new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead());
new_clip->alloc = ggml_allocr_new_measure(tensor_alignment);
clip_image_f32_batch batch;
batch.size = 1;
struct ggml_context * ctx_data = NULL;
struct ggml_context * ctx_eval = NULL;
- struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
+ struct ggml_cgraph * gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
// this allocates all Metal resources and memory buffers
auto * ctx_metal = ggml_metal_init(1);
// main
{
- struct ggml_tensor * input = ggml_graph_get_tensor(&gf, "embd");
+ struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
*(int32_t *) input->data = 1; // BOS
ggml_metal_set_tensor(ctx_metal, input);
// warmup
- ggml_metal_graph_compute(ctx_metal, &gf);
+ ggml_metal_graph_compute(ctx_metal, gf);
const int n_iter = 16;
// the actual inference happens here
for (int i = 0; i < n_iter; ++i) {
- ggml_metal_graph_compute(ctx_metal, &gf);
+ ggml_metal_graph_compute(ctx_metal, gf);
}
const int64_t t1 = ggml_time_us();
// debug output
{
- struct ggml_tensor * logits = gf.nodes[gf.n_nodes - 1];
+ struct ggml_tensor * logits = gf->nodes[gf->n_nodes - 1];
ggml_metal_get_tensor(ctx_metal, logits);
float * ptr = (float *) ggml_get_data(logits);
if (enable_checkpointing) {
ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
} else {
- *gb = *gf;
+ ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx, gf, gb, true);
}
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
+ opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
opt->params.n_threads = params.common.n_threads;
opt->params.past = params.common.opt_past;
opt->params.delta = params.common.opt_delta;
ggml_allocr_free(alloc);
// context for compute tensors without their data
- size_t estimated_compute_size_wo_data = (
- ggml_tensor_overhead()*GGML_MAX_NODES*2
- + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
- params.common.use_checkpointing ? 3 : 2
- )
+ const size_t estimated_compute_size_wo_data = (
+ 2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
+ (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
);
struct ggml_init_params ctx_compute_params = {
estimated_compute_size_wo_data, // mem_size
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new_measure(tensor_alignment);
- gf = ggml_new_graph(ctx_compute);
+ gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
- gb = ggml_new_graph(ctx_compute);
+ gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
- ? ggml_new_graph(ctx_compute)
+ ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_train_graphs(
&model, alloc, ctx_compute,
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
- gf = ggml_new_graph(ctx_compute);
+ gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
- gb = ggml_new_graph(ctx_compute);
+ gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
- ? ggml_new_graph(ctx_compute)
+ ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_train_graphs(
&model, alloc, ctx_compute,
#include "ggml-alloc.h"
-#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
#include "ggml.h"
+#include "ggml-impl.h"
#include <assert.h>
+#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
-
-#define UNUSED(x) (void)(x)
#define MAX(a, b) ((a) > (b) ? (a) : (b))
-#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+#define MAX_FREE_BLOCKS 256
//#define GGML_ALLOCATOR_DEBUG
-//#define AT_PRINTF printf
-#define AT_PRINTF(...) ((void)0)
-
-struct hash_node {
- struct ggml_tensor * t;
- int n_children;
- int n_views;
-};
-
-static size_t hash(void * p) {
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
- size_t h = hash(t);
-
- // linear probing
- size_t i = h;
- while (hash_table[i].t != NULL) {
- if (hash_table[i].t == t) {
- return &hash_table[i];
- }
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
- if (i == h) {
- // hash table is full
- GGML_ASSERT(false);
- }
- }
-
- hash_table[i].t = t;
- return &hash_table[i];
-}
+//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+#define AT_PRINTF(...)
// TODO: GGML_PAD ?
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
size_t size;
};
-#define MAX_FREE_BLOCKS 256
-
-struct ggml_allocr {
+struct ggml_tallocr {
struct ggml_backend_buffer * buffer;
bool buffer_owned;
- void * data;
+ void * base;
size_t alignment;
+
int n_free_blocks;
struct free_block free_blocks[MAX_FREE_BLOCKS];
- struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
size_t max_size;
+
bool measure;
- int parse_seq[GGML_MAX_CONCUR];
- int parse_seq_len;
#ifdef GGML_ALLOCATOR_DEBUG
struct ggml_tensor * allocated_tensors[1024];
};
#ifdef GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+static void add_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == NULL) {
alloc->allocated_tensors[i] = tensor;
}
GGML_ASSERT(!"out of allocated_tensors");
}
-static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == tensor ||
(alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
#endif
// check if a tensor is allocated by this buffer
-static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
+static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
return tensor->buffer == alloc->buffer;
}
return t->view_src != NULL;
}
-void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
+void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
}
tensor->data = addr;
- AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data);
tensor->buffer = alloc->buffer;
- ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
+ if (!alloc->measure) {
+ ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
+ }
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
}
#endif
- alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
+ alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size);
}
// this is a very naive implementation, but for our case the number of free blocks should be very small
-static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
- if (ggml_allocr_is_own(alloc, tensor) == false) {
+static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
+ if (ggml_tallocr_is_own(alloc, tensor) == false) {
// the tensor was not allocated in this buffer
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
// the easiest way to deal with this is just to ignore it
- AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
+ // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
return;
}
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
- ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
+ if (!alloc->measure) {
+ ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
+ }
#ifdef GGML_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor);
alloc->n_free_blocks++;
}
-void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
- for (int i = 0; i < n; i++) {
- alloc->parse_seq[i] = list[i];
- }
- alloc->parse_seq_len = n;
-}
-
-void ggml_allocr_reset(struct ggml_allocr * alloc) {
+void ggml_tallocr_reset(ggml_tallocr_t alloc) {
alloc->n_free_blocks = 1;
- size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
- alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
- alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
+ size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment);
+ alloc->free_blocks[0].addr = (char *)alloc->base + align_offset;
+
+ if (alloc->measure) {
+ alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
+ } else {
+ alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
+ }
}
-struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
+ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
- struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr));
+ ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
- *alloc = (struct ggml_allocr){
+ *alloc = (struct ggml_tallocr) {
/*.buffer = */ buffer,
/*.buffer_owned = */ true,
/*.base = */ ggml_backend_buffer_get_base(buffer),
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
- /*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
- /*.parse_seq = */ {0},
- /*.parse_seq_len = */ 0,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ {0},
#endif
};
- ggml_allocr_reset(alloc);
+ ggml_tallocr_reset(alloc);
+
+ return alloc;
+}
+
+ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) {
+ ggml_tallocr_t alloc = ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment);
+ alloc->measure = true;
return alloc;
}
-struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
- struct ggml_allocr * alloc = ggml_allocr_new((void *)0x1000, (size_t)-0x1001, alignment);
+ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) {
+ // create a backend buffer to get the correct tensor allocation sizes
+ ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1);
+
+ // TODO: move alloc initialization to a common ggml_tallocr_new_impl function
+ ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
+ alloc->buffer_owned = true;
alloc->measure = true;
+ ggml_tallocr_reset(alloc);
+ return alloc;
+}
+ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) {
+ ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size);
+ ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
+ alloc->buffer_owned = true;
return alloc;
}
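A minimal sketch of the measure-then-allocate pattern these constructors enable (backend and t are hypothetical; the callers in this change re-create their tensors and graphs between the two passes, since the measure pass leaves placeholder data pointers behind):

    // pass 1: measure, nothing is written to real memory
    ggml_tallocr_t measure = ggml_tallocr_new_measure(/*alignment =*/ 32);
    ggml_tallocr_alloc(measure, t);                     // only records offsets and sizes
    size_t needed = ggml_tallocr_max_size(measure);
    ggml_tallocr_free(measure);

    // pass 2: allocate for real from a backend-owned buffer of the measured size
    ggml_tallocr_t talloc = ggml_tallocr_new_from_backend(backend, needed);
    ggml_tallocr_alloc(talloc, t);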
-struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
- struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr));
+ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
+ ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
- *alloc = (struct ggml_allocr){
+ *alloc = (struct ggml_tallocr) {
/*.buffer = */ buffer,
/*.buffer_owned = */ false,
/*.base = */ ggml_backend_buffer_get_base(buffer),
/*.alignment = */ ggml_backend_buffer_get_alignment(buffer),
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
- /*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
- /*.parse_seq = */ {0},
- /*.parse_seq_len = */ 0,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ {0},
#endif
};
- ggml_allocr_reset(alloc);
+ ggml_tallocr_reset(alloc);
return alloc;
}
-void ggml_allocr_free(struct ggml_allocr * alloc) {
+struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) {
+ return alloc->buffer;
+}
+
+void ggml_tallocr_free(ggml_tallocr_t alloc) {
+ if (alloc == NULL) {
+ return;
+ }
+
if (alloc->buffer_owned) {
ggml_backend_buffer_free(alloc->buffer);
}
free(alloc);
}
-bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
+bool ggml_tallocr_is_measure(ggml_tallocr_t alloc) {
return alloc->measure;
}
-//////////// compute graph allocator
+size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) {
+ return alloc->max_size;
+}
+
+// graph allocator
+
+struct hash_node {
+ int n_children;
+ int n_views;
+};
+
+struct ggml_gallocr {
+ ggml_tallocr_t talloc;
+ struct ggml_hash_set hash_set;
+ struct hash_node * hash_values;
+ size_t hash_values_size;
+ ggml_tallocr_t * hash_allocs;
+ int * parse_seq;
+ int parse_seq_len;
+};
+
+ggml_gallocr_t ggml_gallocr_new(void) {
+ ggml_gallocr_t galloc = (ggml_gallocr_t)malloc(sizeof(struct ggml_gallocr));
+
+ *galloc = (struct ggml_gallocr) {
+ /*.talloc = */ NULL,
+ /*.hash_set = */ {0},
+ /*.hash_values = */ NULL,
+ /*.hash_values_size = */ 0,
+ /*.hash_allocs = */ NULL,
+ /*.parse_seq = */ NULL,
+ /*.parse_seq_len = */ 0,
+ };
+
+ return galloc;
+}
+
+void ggml_gallocr_free(ggml_gallocr_t galloc) {
+ if (galloc == NULL) {
+ return;
+ }
+
+ if (galloc->hash_set.keys != NULL) {
+ free(galloc->hash_set.keys);
+ }
+ if (galloc->hash_values != NULL) {
+ free(galloc->hash_values);
+ }
+ if (galloc->hash_allocs != NULL) {
+ free(galloc->hash_allocs);
+ }
+ if (galloc->parse_seq != NULL) {
+ free(galloc->parse_seq);
+ }
+ free(galloc);
+}
+
+void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n) {
+ free(galloc->parse_seq);
+ galloc->parse_seq = malloc(sizeof(int) * n);
+
+ for (int i = 0; i < n; i++) {
+ galloc->parse_seq[i] = list[i];
+ }
+ galloc->parse_seq_len = n;
+}
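As an illustration of the expected input: the list indexes graph->nodes, and -1 entries act as the barriers described in ggml_tallocr_alloc_graph_impl below (the schedule here is hypothetical):

    // nodes 0 and 2 are independent and may run concurrently, so their parents are
    // only freed at the -1 barrier that follows them
    const int seq[] = { 0, 2, -1, 1, 3, -1 };
    ggml_gallocr_set_parse_seq(galloc, seq, (int) (sizeof(seq)/sizeof(seq[0])));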
+
+static struct hash_node * hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+ size_t i = ggml_hash_find_or_insert(galloc->hash_set, t);
+ return &galloc->hash_values[i];
+}
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
}
}
-static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view, bool update_backend) {
- assert(view->view_src != NULL && view->view_src->data != NULL);
+static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+ if (galloc->talloc != NULL) {
+ return galloc->talloc;
+ }
+
+ return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
+}
+
+static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
+ ggml_tallocr_t alloc = node_tallocr(galloc, view);
+ //printf("init_view: %s from src %s\n", view->name, view->view_src->name);
+ GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
if (update_backend) {
view->backend = view->view_src->backend;
}
-
view->buffer = view->view_src->buffer;
view->data = (char *)view->view_src->data + view->view_offs;
// FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
// due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
- assert(ggml_allocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
- ggml_backend_buffer_init_tensor(alloc->buffer, view);
+ assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
+
+ if (!alloc->measure) {
+ ggml_backend_buffer_init_tensor(alloc->buffer, view);
+ }
}
-static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
- struct hash_node * ht = alloc->hash_table;
+static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+ ggml_tallocr_t alloc = node_tallocr(galloc, node);
+
if (node->data == NULL) {
if (ggml_is_view(node)) {
- init_view(alloc, node, true);
+ init_view(galloc, node, true);
} else {
// see if we can reuse a parent's buffer (inplace)
if (ggml_op_can_inplace(node->op)) {
}
// if the node's data is external, then we cannot re-use it
- if (ggml_allocr_is_own(alloc, parent) == false) {
+ if (ggml_tallocr_is_own(alloc, parent) == false) {
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
continue;
}
- struct hash_node * p_hn = hash_get(ht, parent);
+ struct hash_node * p_hn = hash_get(galloc, parent);
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
- struct hash_node * view_src_hn = hash_get(ht, view_src);
+ struct hash_node * view_src_hn = hash_get(galloc, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
// the parent's data that it will need later (same layout requirement). the problem is that then
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
node->view_src = view_src;
view_src_hn->n_views += 1;
- init_view(alloc, node, false);
+ init_view(galloc, node, false);
return;
}
} else {
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
node->view_src = parent;
p_hn->n_views += 1;
- init_view(alloc, node, false);
+ init_view(galloc, node, false);
return;
}
}
}
}
- ggml_allocr_alloc(alloc, node);
+ ggml_tallocr_alloc(alloc, node);
}
}
}
-size_t ggml_allocr_alloc_graph_n(
- struct ggml_allocr * alloc,
- struct ggml_cgraph ** graphs, int n_graphs,
- struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
+static void free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+ ggml_tallocr_t alloc = node_tallocr(galloc, node);
- // reset hash table
- struct hash_node * ht = alloc->hash_table;
- memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
+ ggml_tallocr_free_tensor(alloc, node);
+}
+
+static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * gf) {
+ const int * parse_seq = galloc->parse_seq;
+ int parse_seq_len = galloc->parse_seq_len;
// count number of children and views
- for (int g = 0; g < n_graphs; g++) {
- struct ggml_cgraph * gf = graphs[g];
- for (int i = 0; i < gf->n_nodes; i++) {
+ for (int i = 0; i < gf->n_nodes; i++) {
+ struct ggml_tensor * node = gf->nodes[i];
+
+ if (ggml_is_view(node)) {
+ struct ggml_tensor * view_src = node->view_src;
+ hash_get(galloc, view_src)->n_views += 1;
+ if (node->buffer == NULL && node->data != NULL) {
+ // view of a pre-allocated tensor, didn't call init_view() yet
+ init_view(galloc, node, true);
+ }
+ }
+
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ break;
+ }
+ hash_get(galloc, parent)->n_children += 1;
+ if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
+ init_view(galloc, parent, true);
+ }
+ }
+ }
+
+ // allocate tensors
+ // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
+ int last_barrier_pos = 0;
+ int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes;
+
+ for (int ind = 0; ind < n_nodes; ind++) {
+ // allocate a node if there is no parse_seq or this is not a barrier
+ if (parse_seq_len == 0 || parse_seq[ind] != -1) {
+ int i = parse_seq_len ? parse_seq[ind] : ind;
struct ggml_tensor * node = gf->nodes[i];
- if (ggml_is_view(node)) {
- struct ggml_tensor * view_src = node->view_src;
- hash_get(ht, view_src)->n_views += 1;
- if (node->buffer == NULL && node->data != NULL) {
- // view of a pre-allocated tensor, didn't call init_view() yet
- init_view(alloc, node, true);
+ // allocate parents (leafs)
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ break;
}
+ allocate_node(galloc, parent);
}
+ // allocate node
+ allocate_node(galloc, node);
+
+ AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
- hash_get(ht, parent)->n_children += 1;
- if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
- init_view(alloc, parent, true);
+ AT_PRINTF("%s", parent->name);
+ if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+ AT_PRINTF(", ");
}
}
+ AT_PRINTF("\n");
}
- }
-
- // allocate tensors
- for (int g = 0; g < n_graphs; g++) {
- struct ggml_cgraph * gf = graphs[g];
- AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
- // graph inputs are allocated first to ensure that they are not overwritten by each other
- if (inputs != NULL && inputs[g] != NULL) {
- for (int i = 0; inputs[g][i] != NULL; i++) {
- struct ggml_tensor * input = inputs[g][i];
- AT_PRINTF("input: %s\n", input->name);
- allocate_node(alloc, input);
- }
- }
- // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
- int last_barrier_pos = 0;
- int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
- for (int ind = 0; ind < n_nodes; ind++) {
- // allocate a node if there is no parse_seq or this is not a barrier
- if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
- int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
- struct ggml_tensor * node = gf->nodes[i];
+ // update parents
+ // update immediately if there is no parse_seq
+ // update only at barriers if there is parse_seq
+ if ((parse_seq_len == 0) || parse_seq[ind] == -1) {
+ int update_start = parse_seq_len ? last_barrier_pos : ind;
+ int update_end = parse_seq_len ? ind : ind + 1;
+ for (int i = update_start; i < update_end; i++) {
+ int node_i = parse_seq_len ? parse_seq[i] : i;
+ struct ggml_tensor * node = gf->nodes[node_i];
- // allocate parents (leafs)
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
- allocate_node(alloc, parent);
- }
+ struct hash_node * p_hn = hash_get(galloc, parent);
+ p_hn->n_children -= 1;
- // allocate node
- allocate_node(alloc, node);
+ //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
- AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * parent = node->src[j];
- if (parent == NULL) {
- break;
- }
- AT_PRINTF("%s", parent->name);
- if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
- AT_PRINTF(", ");
- }
- }
- AT_PRINTF("\n");
- }
-
- // update parents
- // update immediately if there is no parse_seq
- // update only at barriers if there is parse_seq
- if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
- int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
- int update_end = alloc->parse_seq_len ? ind : ind + 1;
- for (int i = update_start; i < update_end; i++) {
- int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
- struct ggml_tensor * node = gf->nodes[node_i];
-
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * parent = node->src[j];
- if (parent == NULL) {
- break;
- }
- struct hash_node * p_hn = hash_get(ht, parent);
- p_hn->n_children -= 1;
-
- //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
-
- if (p_hn->n_children == 0 && p_hn->n_views == 0) {
- if (ggml_is_view(parent)) {
- struct ggml_tensor * view_src = parent->view_src;
- struct hash_node * view_src_hn = hash_get(ht, view_src);
- view_src_hn->n_views -= 1;
- AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
- if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
- ggml_allocr_free_tensor(alloc, view_src);
- }
- }
- else {
- if (parent->data != node->data) {
- ggml_allocr_free_tensor(alloc, parent);
- }
+ if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+ if (ggml_is_view(parent)) {
+ struct ggml_tensor * view_src = parent->view_src;
+ struct hash_node * view_src_hn = hash_get(galloc, view_src);
+ view_src_hn->n_views -= 1;
+ AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+ if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) {
+ free_node(galloc, view_src);
}
}
+ else {
+ free_node(galloc, parent);
+ }
}
}
- AT_PRINTF("\n");
- if (alloc->parse_seq_len) {
- last_barrier_pos = ind + 1;
- }
}
- }
- // free graph outputs here that wouldn't be freed otherwise because they have no children
- if (outputs != NULL && outputs[g] != NULL) {
- for (int i = 0; outputs[g][i] != NULL; i++) {
- struct ggml_tensor * output = outputs[g][i];
- AT_PRINTF("output: %s\n", output->name);
- ggml_allocr_free_tensor(alloc, output);
+ AT_PRINTF("\n");
+ if (parse_seq_len) {
+ last_barrier_pos = ind + 1;
}
}
}
+}
- return alloc->max_size;
+size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph) {
+ size_t hash_size = graph->visited_hash_table.size;
+
+ // check if the hash table is initialized and large enough
+ if (galloc->hash_set.size < hash_size) {
+ if (galloc->hash_set.keys != NULL) {
+ free(galloc->hash_set.keys);
+ }
+ if (galloc->hash_values != NULL) {
+ free(galloc->hash_values);
+ }
+ galloc->hash_set.keys = malloc(sizeof(struct ggml_tensor *) * hash_size);
+ galloc->hash_set.size = hash_size;
+ galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
+ }
+
+ // reset hash table
+ memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * hash_size);
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+
+ galloc->talloc = talloc;
+ ggml_tallocr_alloc_graph_impl(galloc, graph);
+ galloc->talloc = NULL;
+
+ size_t max_size = ggml_tallocr_max_size(talloc);
+
+ return max_size;
}
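Putting the tensor and graph allocators together, a rough sketch of how a whole graph is typically sized and then allocated (backend, galloc and build_graph are placeholders; the graph is rebuilt between the two passes because the measure pass leaves placeholder data pointers):

    ggml_tallocr_t measure  = ggml_tallocr_new_measure_from_backend(backend);
    size_t         mem_size = ggml_gallocr_alloc_graph(galloc, measure, build_graph());
    ggml_tallocr_free(measure);

    ggml_tallocr_t talloc = ggml_tallocr_new_from_backend(backend, mem_size);
    ggml_gallocr_alloc_graph(galloc, talloc, build_graph());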
-size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
- return ggml_allocr_alloc_graph_n(alloc, &graph, 1, NULL, NULL);
+void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) {
+ const size_t hash_size = hash_set.size;
+
+ GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
+
+ galloc->talloc = NULL;
+
+ // alloc hash_values if needed
+ if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) {
+ free(galloc->hash_values);
+ galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
+ galloc->hash_values_size = hash_size;
+ }
+
+ // free hash_set.keys if needed
+ if (galloc->hash_set.keys != NULL) {
+ free(galloc->hash_set.keys);
+ }
+ galloc->hash_set = hash_set;
+
+ // reset hash values
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+
+ galloc->hash_allocs = hash_node_talloc;
+
+ ggml_tallocr_alloc_graph_impl(galloc, graph);
+
+ // remove unowned resources
+ galloc->hash_set.keys = NULL;
+ galloc->hash_allocs = NULL;
}
-size_t ggml_allocr_max_size(struct ggml_allocr * alloc) {
- return alloc->max_size;
+// legacy API wrapper
+
+struct ggml_allocr {
+ ggml_tallocr_t talloc;
+ ggml_gallocr_t galloc;
+};
+
+static ggml_allocr_t ggml_allocr_new_impl(ggml_tallocr_t talloc) {
+ ggml_allocr_t alloc = (ggml_allocr_t)malloc(sizeof(struct ggml_allocr));
+ *alloc = (struct ggml_allocr) {
+ /*.talloc = */ talloc,
+ /*.galloc = */ ggml_gallocr_new(),
+ };
+ return alloc;
+}
+
+ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment) {
+ return ggml_allocr_new_impl(ggml_tallocr_new(data, size, alignment));
+}
+
+ggml_allocr_t ggml_allocr_new_measure(size_t alignment) {
+ return ggml_allocr_new_impl(ggml_tallocr_new_measure(alignment));
+}
+
+ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
+ return ggml_allocr_new_impl(ggml_tallocr_new_from_buffer(buffer));
+}
+
+ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size) {
+ return ggml_allocr_new_impl(ggml_tallocr_new_from_backend(backend, size));
+}
+
+ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend) {
+ return ggml_allocr_new_impl(ggml_tallocr_new_measure_from_backend(backend));
+}
+
+struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc) {
+ return ggml_tallocr_get_buffer(alloc->talloc);
+}
+
+void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
+ ggml_gallocr_set_parse_seq(alloc->galloc, list, n);
+}
+
+void ggml_allocr_free(ggml_allocr_t alloc) {
+ ggml_gallocr_free(alloc->galloc);
+ ggml_tallocr_free(alloc->talloc);
+ free(alloc);
+}
+
+bool ggml_allocr_is_measure(ggml_allocr_t alloc) {
+ return ggml_tallocr_is_measure(alloc->talloc);
+}
+
+void ggml_allocr_reset(ggml_allocr_t alloc) {
+ ggml_tallocr_reset(alloc->talloc);
+}
+
+void ggml_allocr_alloc(ggml_allocr_t alloc, struct ggml_tensor * tensor) {
+ ggml_tallocr_alloc(alloc->talloc, tensor);
+}
+
+size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
+ return ggml_tallocr_max_size(alloc->talloc);
+}
+
+size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
+ return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
}
extern "C" {
#endif
+struct ggml_backend;
struct ggml_backend_buffer;
-GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
-GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
-GGML_API struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
+//
+// Legacy API
+//
+
+typedef struct ggml_allocr * ggml_allocr_t;
+
+// initialize allocator for use with CPU backend only
+GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment);
+GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment);
+
+// initialize allocator for use with ggml-backend
+GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
+GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
+GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend);
+
+GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc);
// tell the allocator to parse nodes following the order described in the list
// you should call this if your graph is optimized to execute out-of-order
-GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
-
-GGML_API void ggml_allocr_free (struct ggml_allocr * alloc);
-GGML_API bool ggml_allocr_is_measure (struct ggml_allocr * alloc);
-GGML_API void ggml_allocr_reset (struct ggml_allocr * alloc);
-GGML_API void ggml_allocr_alloc (struct ggml_allocr * alloc, struct ggml_tensor * tensor);
-GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
-GGML_API size_t ggml_allocr_max_size (struct ggml_allocr * alloc);
-
-GGML_API size_t ggml_allocr_alloc_graph_n(
- struct ggml_allocr * alloc,
- struct ggml_cgraph ** graphs, int n_graphs,
- struct ggml_tensor *** inputs, struct ggml_tensor *** outputs);
+GGML_API void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n);
+
+GGML_API void ggml_allocr_free (ggml_allocr_t alloc);
+GGML_API bool ggml_allocr_is_measure (ggml_allocr_t alloc);
+GGML_API void ggml_allocr_reset (ggml_allocr_t alloc);
+GGML_API void ggml_allocr_alloc (ggml_allocr_t alloc, struct ggml_tensor * tensor);
+GGML_API size_t ggml_allocr_max_size (ggml_allocr_t alloc);
+
+GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph);
+
+//
+// ggml-backend v2 API
+//
+
+// Separate tensor and graph allocator objects
+// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
+// The original API is kept as a wrapper around the new API
+
+// Tensor allocator
+typedef struct ggml_tallocr * ggml_tallocr_t;
+
+GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
+GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
+GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
+GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
+GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
+
+GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
+
+GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc);
+GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc);
+GGML_API void ggml_tallocr_reset (ggml_tallocr_t talloc);
+GGML_API void ggml_tallocr_alloc (ggml_tallocr_t talloc, struct ggml_tensor * tensor);
+GGML_API size_t ggml_tallocr_max_size (ggml_tallocr_t talloc);
+
+
+// Graph allocator
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+GGML_API ggml_gallocr_t ggml_gallocr_new(void);
+GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
+
+GGML_API void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n);
+GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph);
+
+// Allocate tensors from the allocators given by the hash table
+GGML_API void ggml_gallocr_alloc_graph_n(
+ ggml_gallocr_t galloc,
+ struct ggml_cgraph * graph,
+ struct ggml_hash_set hash_set,
+ ggml_tallocr_t * hash_node_talloc);
#ifdef __cplusplus
}
--- /dev/null
+#pragma once
+
+// ggml-backend internal header
+
+#include "ggml-backend.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ //
+ // Backend buffer
+ //
+
+ typedef void * ggml_backend_buffer_context_t;
+
+ struct ggml_backend_buffer_i {
+ void (*free_buffer) (ggml_backend_buffer_t buffer);
+ void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
+ size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
+ void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
+ void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
+ };
+
+ struct ggml_backend_buffer {
+ struct ggml_backend_buffer_i iface;
+
+ ggml_backend_t backend;
+ ggml_backend_buffer_context_t context;
+
+ size_t size;
+ };
+
+ GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
+ struct ggml_backend * backend,
+ struct ggml_backend_buffer_i iface,
+ ggml_backend_buffer_context_t context,
+ size_t size);
+
+ //
+ // Backend
+ //
+
+ typedef void * ggml_backend_context_t;
+
+ struct ggml_backend_i {
+ const char * (*get_name)(ggml_backend_t backend);
+
+ void (*free)(ggml_backend_t backend);
+
+ // buffer allocation
+ ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
+
+ // get buffer alignment
+ size_t (*get_alignment)(ggml_backend_t backend);
+
+ // tensor data access
+ // these functions can be asynchronous; helper functions are provided for synchronous access that automatically call synchronize
+ void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+ void (*synchronize) (ggml_backend_t backend);
+
+ // (optional) copy tensor between different backends, allowing for single-copy transfers
+ void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+ void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+ // compute graph with a plan
+ ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+ void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+ // compute graph without a plan
+ void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+ // check if the backend supports an operation
+ bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+ };
+
+ struct ggml_backend {
+ struct ggml_backend_i iface;
+
+ ggml_backend_context_t context;
+ };
+
+#ifdef __cplusplus
+}
+#endif
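For orientation, the public wrappers in ggml-backend.c are thin dispatchers over this function table; a representative sketch of the shape they take (not a verbatim copy of any wrapper):

    void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
        backend->iface.graph_compute(backend, cgraph);
    }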
-#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
#include "ggml-alloc.h"
+#include "ggml-impl.h"
#include <assert.h>
+#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
}
void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
+ if (buffer == NULL) {
+ return;
+ }
+
if (buffer->iface.free_buffer != NULL) {
buffer->iface.free_buffer(buffer);
}
return ggml_backend_get_alignment(buffer->backend);
}
-void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
- return buffer->iface.get_base(buffer);
-}
-
size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
return buffer->size;
}
+void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+ void * base = buffer->iface.get_base(buffer);
+
+ GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
+
+ return base;
+}
+
size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ // get_alloc_size is optional, defaults to ggml_nbytes
if (buffer->iface.get_alloc_size) {
return buffer->iface.get_alloc_size(buffer, tensor);
}
}
void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ // init_tensor is optional
if (buffer->iface.init_tensor) {
buffer->iface.init_tensor(buffer, tensor);
}
}
void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ // free_tensor is optional
if (buffer->iface.free_tensor) {
buffer->iface.free_tensor(buffer, tensor);
}
// backend
ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
- return tensor->buffer->backend;
+ return tensor->buffer ? tensor->buffer->backend : NULL;
}
const char * ggml_backend_name(ggml_backend_t backend) {
+ if (backend == NULL) {
+ return "NULL";
+ }
return backend->iface.get_name(backend);
}
void ggml_backend_free(ggml_backend_t backend) {
+ if (backend == NULL) {
+ return;
+ }
+
backend->iface.free(backend);
}
}
void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
- ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor));
+ ggml_backend_t backend = ggml_get_backend(tensor);
+
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(backend != NULL && "tensor backend not set");
+
+ backend->iface.set_tensor_async(backend, tensor, data, offset, size);
+ backend->iface.synchronize(backend);
}
void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
- ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor));
+ ggml_backend_t backend = ggml_get_backend(tensor);
+
+ GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ GGML_ASSERT(backend != NULL && "tensor backend not set");
+
+ backend->iface.get_tensor_async(backend, tensor, data, offset, size);
+ backend->iface.synchronize(backend);
}
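A small usage sketch of these synchronous helpers (t is a hypothetical, already allocated F32 tensor with at least four elements):

    float src[4] = { 1.0f, 2.0f, 3.0f, 4.0f };
    ggml_backend_tensor_set(t, src, 0, sizeof(src));   // asserts that t is allocated and has a backend
    float dst[4];
    ggml_backend_tensor_get(t, dst, 0, sizeof(dst));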
void ggml_backend_synchronize(ggml_backend_t backend) {
//printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
- // printf("cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
+ // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
if (src == dst) {
return;
size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
+ GGML_ASSERT(data != NULL && "failed to allocate buffer");
+
return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
}
}
static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
- // for a backend such as CUDA that can queue async calls, it is ok to do this asynchronously, but it may not be the case for other backends
- ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+ ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
UNUSED(backend);
}
ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
}
+
+// scheduler
+
+#define GGML_MAX_BACKENDS 4
+#define GGML_MAX_SPLITS 256
+#define GGML_MAX_SPLIT_INPUTS 16
+
+struct ggml_backend_sched_split {
+ ggml_tallocr_t tallocr;
+ int i_start;
+ int i_end;
+ struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
+ int n_inputs;
+ struct ggml_cgraph * graph;
+};
+
+struct ggml_backend_sched {
+ int n_backends;
+ ggml_backend_t backends[GGML_MAX_BACKENDS];
+ ggml_tallocr_t tallocs[GGML_MAX_BACKENDS];
+
+ ggml_gallocr_t galloc;
+
+ struct ggml_hash_set hash_set;
+ ggml_tallocr_t * node_talloc; // [hash_set.size]
+ struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS]
+
+ struct ggml_cgraph * graph;
+ struct ggml_backend_sched_split splits[GGML_MAX_SPLITS];
+ int n_splits;
+
+ struct ggml_context * ctx;
+
+ // align context_buffer to GGML_MEM_ALIGN
+ #ifdef _MSC_VER
+ __declspec(align(GGML_MEM_ALIGN))
+ #else
+ __attribute__((aligned(GGML_MEM_ALIGN)))
+ #endif
+ char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
+};
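The intended flow for this scheduler, as a rough sketch assuming the public entry points that accompany it in ggml-backend.h (ggml_backend_sched_new, ggml_backend_sched_init_measure, ggml_backend_sched_graph_compute, ggml_backend_sched_free; the backends and graphs are placeholders):

    ggml_backend_t backends[2] = { gpu_backend, cpu_backend };  // in priority order, CPU last
    ggml_backend_sched_t sched = ggml_backend_sched_new(backends, 2);

    ggml_backend_sched_init_measure(sched, measure_graph);      // sizes one tallocr per backend

    ggml_backend_sched_graph_compute(sched, graph);             // splits the graph and runs each split on its backend

    ggml_backend_sched_free(sched);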
+
+#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
+#define node_allocr(node) sched->node_talloc[hash_id(node)]
+
+static bool ggml_is_view_op(enum ggml_op op) {
+ return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
+// returns the priority of the backend, lower is better
+static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ for (int i = 0; i < sched->n_backends; i++) {
+ if (sched->backends[i] == backend) {
+ return i;
+ }
+ }
+ return INT_MAX;
+}
+
+static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
+ for (int i = 0; i < sched->n_backends; i++) {
+ if (sched->tallocs[i] == allocr) {
+ return i;
+ }
+ }
+ return INT_MAX;
+}
+
+// returns the backend that should be used for the node based on the current locations
+char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
+static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+ // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
+ // ie. kv cache updates
+ // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
+ // dst
+ ggml_backend_t cur_backend = ggml_get_backend(node);
+ if (cur_backend != NULL) {
+ sprintf(causes[hash_id(node)], "1.dst");
+ return cur_backend;
+ }
+
+ // view_src
+ if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
+ sprintf(causes[hash_id(node)], "1.vsrc");
+ return ggml_get_backend(node->view_src);
+ }
+
+ // src
+ int cur_prio = INT_MAX;
+ size_t cur_size = 0;
+
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ const struct ggml_tensor * src = node->src[i];
+ if (src == NULL) {
+ break;
+ }
+ ggml_backend_t src_backend = ggml_get_backend(src);
+ if (src_backend != NULL) {
+ int src_prio = sched_backend_prio(sched, src_backend);
+ size_t src_size = ggml_nbytes(src);
+ if (src_prio < cur_prio && src_size >= cur_size) {
+ cur_prio = src_prio;
+ cur_size = src_size;
+ cur_backend = src_backend;
+ sprintf(causes[hash_id(node)], "1.src%d", i);
+ }
+ }
+ }
+ return cur_backend;
+}
+
+static char * fmt_size(size_t size) {
+ static char buffer[128];
+ if (size >= 1024*1024) {
+ sprintf(buffer, "%zuM", size/1024/1024);
+ } else {
+ sprintf(buffer, "%zuK", size/1024);
+ }
+ return buffer;
+}
+
+static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ int cur_split = 0;
+ for (int i = 0; i < graph->n_nodes; i++) {
+ if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
+ ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
+ fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
+ for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
+ fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+ }
+ fprintf(stderr, "\n");
+ cur_split++;
+ }
+ struct ggml_tensor * node = graph->nodes[i];
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+ ggml_tallocr_t node_allocr = node_allocr(node);
+ ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
+ fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ break;
+ }
+ ggml_tallocr_t src_allocr = node_allocr(src);
+ ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
+ fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
+ }
+ fprintf(stderr, "\n");
+ }
+}
+
+// creates a copy of the tensor with the same memory layout
+static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
+ struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ dup->nb[i] = tensor->nb[i];
+ }
+ return dup;
+}
+
+// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
+// TODO: merge passes
+static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ // reset state
+ size_t hash_size = sched->hash_set.size;
+ memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
+ memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
+ memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
+ sched->n_splits = 0;
+
+ struct ggml_init_params params = {
+ /*.mem_size = */ sizeof(sched->context_buffer),
+ /*.mem_buffer = */ sched->context_buffer,
+ /*.no_alloc = */ true
+ };
+
+ if (sched->ctx != NULL) {
+ ggml_free(sched->ctx);
+ }
+
+ sched->ctx = ggml_init(params);
+
+ // pass 1: assign backends to ops with allocated inputs
+ for (int i = 0; i < graph->n_leafs; i++) {
+ struct ggml_tensor * leaf = graph->leafs[i];
+ if (node_allocr(leaf) != NULL) {
+ // do not overwrite user assignments
+ continue;
+ }
+ ggml_backend_t leaf_backend = ggml_get_backend(leaf);
+ if (leaf_backend == NULL && leaf->view_src != NULL) {
+ leaf_backend = ggml_get_backend(leaf->view_src);
+ }
+ if (leaf_backend != NULL) {
+ node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
+ }
+ }
+
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ if (node_allocr(node) != NULL) {
+ // do not overwrite user assignments
+ continue;
+ }
+ ggml_backend_t node_backend = sched_backend_from_cur(sched, node);
+ if (node_backend != NULL) {
+ node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend);
+ }
+ }
+ //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+ // pass 2: assign backends to ops from current assignments
+ // TODO:
+ // - reuse sched_backend_from_cur
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ ggml_tallocr_t node_allocr = node_allocr(node);
+ if (node_allocr == NULL) {
+ int cur_prio = INT_MAX;
+ size_t cur_size = 0;
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ break;
+ }
+ ggml_tallocr_t src_allocr = node_allocr(src);
+ if (src_allocr != NULL) {
+ int src_prio = sched_allocr_prio(sched, src_allocr);
+ size_t src_size = ggml_nbytes(src);
+ if (src_prio < cur_prio && src_size >= cur_size) {
+ cur_prio = src_prio;
+ cur_size = src_size;
+ node_allocr = src_allocr;
+ sprintf(causes[hash_id(node)], "2.src%d", j);
+ }
+ }
+ }
+ if (node_allocr != NULL) {
+ node_allocr(node) = node_allocr;
+ }
+ }
+ }
+ //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+ // pass 3: assign backends to remaining src from dst (should only be leafs)
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ ggml_tallocr_t node_allocr = node_allocr(node);
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ break;
+ }
+ ggml_tallocr_t src_allocr = node_allocr(src);
+ if (src_allocr == NULL) {
+ node_allocr(src) = node_allocr;
+ }
+ }
+ }
+ //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
+
+ // pass 4: split graph, find tensors that need to be copied
+ // TODO:
+ // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
+ // find first backend
+ int cur_split = 0;
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ if (node->view_src == NULL) {
+ sched->splits[0].tallocr = node_allocr(node);
+ break;
+ }
+ }
+ sched->splits[0].i_start = 0;
+ sched->splits[0].n_inputs = 0;
+ memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
+ ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
+ size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+
+ if (ggml_is_view_op(node->op)) {
+ continue;
+ }
+
+ ggml_tallocr_t node_allocr = node_allocr(node);
+
+ if (node_allocr != cur_allocr) {
+ sched->splits[cur_split].i_end = i;
+ cur_split++;
+ GGML_ASSERT(cur_split < GGML_MAX_SPLITS);
+ sched->splits[cur_split].tallocr = node_allocr;
+ sched->splits[cur_split].i_start = i;
+ sched->splits[cur_split].n_inputs = 0;
+ memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK
+ cur_allocr = node_allocr;
+ cur_backend_id = sched_allocr_prio(sched, cur_allocr);
+ }
+
+ // find inputs that are not on the same backend
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ break;
+ }
+ ggml_tallocr_t src_allocr = node_allocr(src);
+ if (src_allocr != node_allocr) {
+ int n_inputs = sched->splits[cur_split].n_inputs++;
+ GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
+ sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
+
+ // create copies
+ size_t id = hash_id(src);
+ if (sched->node_copies[id][cur_backend_id] == NULL) {
+ struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+ sched->node_copies[id][cur_backend_id] = tensor_copy;
+ node_allocr(tensor_copy) = cur_allocr;
+ ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
+ ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
+ }
+ node->src[j] = sched->node_copies[id][cur_backend_id];
+ }
+ }
+ }
+ sched->splits[cur_split].i_end = graph->n_nodes;
+ sched->n_splits = cur_split + 1;
+
+ //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
+
+#if 1
+ // sanity check: all sources should have the same backend as the node
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ ggml_tallocr_t node_allocr = node_allocr(node);
+ if (node_allocr == NULL) {
+ fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
+ }
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ break;
+ }
+ ggml_tallocr_t src_allocr = node_allocr(src);
+ if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
+ fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
+ node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
+ j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
+ }
+ }
+ }
+#endif
+
+ // create copies of the graph for each split
+ // FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way
+ struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
+ for (int i = 0; i < sched->n_splits; i++) {
+ struct ggml_backend_sched_split * split = &sched->splits[i];
+ split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
+
+ // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
+ for (int j = 0; j < split->n_inputs; j++) {
+ struct ggml_tensor * input = split->inputs[j];
+ struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
+ input_cpy->src[0] = input;
+ graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
+ }
+
+ for (int j = split->i_start; j < split->i_end; j++) {
+ graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
+ }
+ }
+ sched->graph = graph_copy;
+}
+
+static void sched_alloc_splits(ggml_backend_sched_t sched) {
+ ggml_gallocr_alloc_graph_n(
+ sched->galloc,
+ sched->graph,
+ sched->hash_set,
+ sched->node_talloc);
+}
+
+static void sched_compute_splits(ggml_backend_sched_t sched) {
+ uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
+ uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
+
+ struct ggml_backend_sched_split * splits = sched->splits;
+
+ for (int i = 0; i < sched->n_splits; i++) {
+ struct ggml_backend_sched_split * split = &splits[i];
+ ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
+ int split_backend_id = sched_backend_prio(sched, split_backend);
+
+ // copy the input tensors to the split backend
+ uint64_t copy_start_us = ggml_time_us();
+ for (int j = 0; j < split->n_inputs; j++) {
+ struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
+ if (split->inputs[j]->buffer == NULL) {
+ if (split->inputs[j]->view_src == NULL) {
+ fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
+ exit(1);
+ }
+ struct ggml_tensor * view = split->inputs[j];
+ view->backend = view->view_src->backend;
+ view->buffer = view->view_src->buffer;
+ view->data = (char *)view->view_src->data + view->view_offs;
+ ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
+ }
+ if (input_cpy->buffer == NULL) {
+ fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
+ exit(1);
+ }
+ GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
+ GGML_ASSERT(input_cpy->buffer->backend == split_backend);
+ ggml_backend_tensor_copy(split->inputs[j], input_cpy);
+ }
+ // ggml_backend_synchronize(split_backend);
+ int64_t copy_end_us = ggml_time_us();
+ copy_us[split_backend_id] += copy_end_us - copy_start_us;
+
+#if 0
+ char split_filename[GGML_MAX_NAME];
+ snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend));
+ ggml_graph_dump_dot(split->graph, NULL, split_filename);
+#endif
+
+ uint64_t compute_start_us = ggml_time_us();
+ ggml_backend_graph_compute(split_backend, split->graph);
+ // ggml_backend_synchronize(split_backend);
+ uint64_t compute_end_us = ggml_time_us();
+ compute_us[split_backend_id] += compute_end_us - compute_start_us;
+ }
+
+#if 0
+ // per-backend timings
+ fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
+ for (int i = 0; i < sched->n_backends; i++) {
+ if (copy_us[i] > 0 || compute_us[i] > 0) {
+ fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
+ }
+ }
+#endif
+}
+
+static void sched_reset(ggml_backend_sched_t sched) {
+ for (int i = 0; i < sched->n_backends; i++) {
+ ggml_tallocr_reset(sched->tallocs[i]);
+ }
+}
+
+ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) {
+ GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS);
+
+ struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
+ memset(sched, 0, sizeof(struct ggml_backend_sched));
+
+ fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024);
+
+ sched->n_backends = n_backends;
+ for (int i = 0; i < n_backends; i++) {
+ sched->backends[i] = backends[i];
+ }
+
+ sched->galloc = ggml_gallocr_new();
+
+ // init measure allocs for each backend
+ for (int i = 0; i < n_backends; i++) {
+ sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]);
+ }
+
+ return sched;
+}
+
+void ggml_backend_sched_free(ggml_backend_sched_t sched) {
+ if (sched == NULL) {
+ return;
+ }
+ for (int i = 0; i < sched->n_backends; i++) {
+ ggml_tallocr_free(sched->tallocs[i]);
+ }
+ ggml_gallocr_free(sched->galloc);
+ free(sched->hash_set.keys);
+ free(sched->node_talloc);
+ free(sched->node_copies);
+ free(sched);
+}
+
+void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+ // initialize hash tables
+ size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS;
+ sched->hash_set.size = hash_size;
+ sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
+ sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size);
+ sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size);
+
+ sched_split_graph(sched, measure_graph);
+ sched_alloc_splits(sched);
+
+ // allocate buffers and reset allocators
+ for (int i = 0; i < sched->n_backends; i++) {
+ size_t size = ggml_tallocr_max_size(sched->tallocs[i]);
+ ggml_tallocr_free(sched->tallocs[i]);
+ sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size);
+ }
+
+ sched_reset(sched);
+}
+
+void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+ GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
+
+ sched_split_graph(sched, graph);
+ sched_alloc_splits(sched);
+ sched_compute_splits(sched);
+ sched_reset(sched);
+}
+
+ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ int backend_index = sched_backend_prio(sched, backend);
+ return sched->tallocs[backend_index];
+}
+
+ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ int backend_index = sched_backend_prio(sched, backend);
+ return ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
+}
+
+void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+ int backend_index = sched_backend_prio(sched, backend);
+ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+ node_allocr(node) = sched->tallocs[backend_index];
+}
#pragma once
#include "ggml.h"
+#include "ggml-alloc.h"
#ifdef __cplusplus
extern "C" {
#endif
- struct ggml_backend;
- struct ggml_backend_buffer;
-
- // type-erased backend-specific types / wrappers
- typedef void * ggml_backend_context_t;
- typedef void * ggml_backend_graph_plan_t;
- typedef void * ggml_backend_buffer_context_t;
-
- // avoid accessing internals of these types
- typedef struct ggml_backend * ggml_backend_t;
- typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
//
- // backend buffer
+ // Backend buffer
//
- struct ggml_backend_buffer_i {
- void (*free_buffer) (ggml_backend_buffer_t buffer);
- void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
- size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
- void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
- void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
- };
-
- // TODO: hide behind API
- struct ggml_backend_buffer {
- struct ggml_backend_buffer_i iface;
-
- ggml_backend_t backend;
- ggml_backend_buffer_context_t context;
-
- size_t size;
- };
+ struct ggml_backend_buffer;
+ typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
// backend buffer functions
- GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
- struct ggml_backend * backend,
- struct ggml_backend_buffer_i iface,
- ggml_backend_buffer_context_t context,
- size_t size);
-
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
//
- // backend
+ // Backend
//
- struct ggml_backend_i {
- const char * (*get_name)(ggml_backend_t backend);
-
- void (*free)(ggml_backend_t backend);
-
- // buffer allocation
- ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
-
- // get buffer alignment
- size_t (*get_alignment)(ggml_backend_t backend);
-
- // tensor data access
- // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
- void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
- void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
- void (*synchronize) (ggml_backend_t backend);
-
- // (optional) copy tensor between different backends, allow for single-copy tranfers
- void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
- void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
-
- // compute graph with a plan
- ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
- void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
- void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-
- // compute graph without a plan
- void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
-
- // check if the backend supports an operation
- bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
- };
-
- // TODO: hide behind API
- struct ggml_backend {
- struct ggml_backend_i iface;
-
- ggml_backend_context_t context;
- };
+ struct ggml_backend;
+ typedef struct ggml_backend * ggml_backend_t;
+ typedef void * ggml_backend_graph_plan_t;
- // backend helper functions
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_cpu_init(void);
GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend);
-
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
+ // Create a backend buffer from an existing pointer
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
+
+ //
+ // Backend scheduler
+ //
+
+ // The backend scheduler allows for multiple backends to be used together
+ // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
+ // The backends are selected based on:
+ // - the backend that supports the operation
+ // - the location of the pre-allocated tensors (e.g. the weights)
+ /*
+ Example usage:
+
+ sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends);
+ // sched is initialized with measure allocators and cannot be used until allocated with a measure graph
+
+ // initialize buffers from a measure graph
+ measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed
+
+ // in build_graph:
+ build_graph(...) {
+ // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
+ alloc_cpu = ggml_backend_sched_get_tallocr(sched, backend_cpu);
+ ggml_tallocr_alloc(alloc_cpu, tensor);
+
+ // manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
+ struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+ ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
+ }
+
+ // allocate backend buffers from measure graph
+ ggml_backend_sched_init_measure(sched, measure_graph);
+
+ // the scheduler is now ready to compute graphs
+
+ // compute
+ graph = build_graph(sched);
+ ggml_backend_sched_graph_compute(sched, graph);
+ */
+
+ struct ggml_backend_sched;
+ typedef struct ggml_backend_sched * ggml_backend_sched_t;
+
+ // Initialize a backend scheduler
+ GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends);
+
+ GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
+
+ // Initialize backend buffers from a measure graph
+ GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+
+ GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
+ GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
+
+ GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+
+ // Allocate and compute a graph on the backend scheduler
+ GGML_API void ggml_backend_sched_graph_compute(
+ ggml_backend_sched_t sched,
+ struct ggml_cgraph * graph);
+
#ifdef __cplusplus
}
#endif
#include "ggml-cuda.h"
#include "ggml.h"
+#include "ggml-backend-impl.h"
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
if (g_temp_tensor_extras == nullptr) {
- g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+ g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
}
size_t alloc_index = g_temp_tensor_extra_index;
- g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+ g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
memset(extra, 0, sizeof(*extra));
ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
if (temp_tensor_extras == nullptr) {
- temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+ temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
}
size_t alloc_index = temp_tensor_extra_index;
- temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+ temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
memset(extra, 0, sizeof(*extra));
ggml_cuda_set_device(g_main_device);
ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda;
+
+ size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
+
+ ggml_cuda_set_device(g_main_device);
CUDA_CHECK(cudaMalloc(&ctx->device, size));
+
return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size);
}
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
+ if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
+ continue;
assert(node->backend == GGML_BACKEND_GPU);
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
#endif
- // TODO: backend v2 PR
+#define GGML_HASHTABLE_FULL ((size_t)-1)
+#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
+
+bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
+size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
+
+// returns index, asserts if table is full
+size_t ggml_hash_find_or_insert( struct ggml_hash_set hash_set, struct ggml_tensor * key);
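+
+// example usage (sketch), assuming `set.keys` holds `set.size` zero-initialized entries
+// and `t` is a tensor that should be visited at most once:
+//
+//   if (ggml_hash_insert(set, t) == GGML_HASHTABLE_ALREADY_EXISTS) {
+//       // t was already recorded -> skip it
+//   }
+//   // from here on, ggml_hash_contains(set, t) is true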
#ifdef __cplusplus
}
#import "ggml-metal.h"
+#import "ggml-backend-impl.h"
#import "ggml.h"
#import <Foundation/Foundation.h>
#define UNUSED(x) (void)(x)
-#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+#define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
struct ggml_metal_buffer {
const char * name;
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
struct ggml_tensor * dst = gf->nodes[i];
+ switch (dst->op) {
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_PERMUTE:
+ {
+ // noop -> next node
+ } continue;
+ default:
+ {
+ } break;
+ }
+
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
const int64_t ne02 = src0 ? src0->ne[2] : 0;
//}
switch (dst->op) {
- case GGML_OP_NONE:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_PERMUTE:
- {
- // noop
- } break;
case GGML_OP_CONCAT:
{
const int64_t nb = ne00;
#include <hbwmalloc.h>
#endif
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#endif
+
+#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
+ (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
+
+#include <sys/wait.h>
+
+void ggml_print_backtrace(void) {
+ /*
+ #include <execinfo.h>
+ #include <dlfcn.h>
+
+ void * trace[100];
+
+ int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
+
+ backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
+ */
+
+ // backtrace_symbols does not show line numbers, use gdb instead
+ char attach[32];
+ snprintf(attach, sizeof(attach), "attach %d", getpid());
+ int pid = fork();
+ if (pid == 0) {
+ execlp("gdb", "gdb", "--batch",
+ "-ex", "set style enabled on",
+ "-ex", attach,
+ "-ex", "bt -frame-info source-and-location",
+ "-ex", "detach",
+ "-ex", "quit",
+ NULL);
+ } else {
+ waitpid(pid, NULL, 0);
+ }
+}
+#else
+void ggml_print_backtrace(void) {
+ // platform not supported
+}
+#endif
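+
+// minimal usage sketch (the wrapper name below is hypothetical, not part of ggml):
+//
+//   #define MY_ASSERT(x) do { if (!(x)) { ggml_print_backtrace(); abort(); } } while (0)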
+
/*#define GGML_PERF*/
#define GGML_DEBUG 0
#define GGML_GELU_FP16
inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
}
+// ggml_leaky
+
+struct ggml_tensor * ggml_leaky(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
+}
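+
+// minimal sketch of using the new op, assuming a context `ctx` and an F32 tensor `x`:
+//
+//   struct ggml_tensor * y  = ggml_leaky(ctx, x);   // y[i] = x[i] > 0 ? x[i] : 0.1f*x[i]
+//   struct ggml_cgraph * gf = ggml_new_graph(ctx);
+//   ggml_build_forward_expand(gf, y);
+//   ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);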
+
// ggml_gelu
struct ggml_tensor * ggml_gelu(
// ggml_pool_*
-static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
+static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
return (ins + 2 * p - ks) / s + 1;
}
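+// e.g. with ins = 5, ks = 3, s = 1 the formula above gives (5 + 2*1 - 3)/1 + 1 = 5 for p = 1
+// ("same" pooling) and (5 - 3)/1 + 1 = 3 for p = 0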
int k1,
int s0,
int s1,
- int p0,
- int p1) {
+ float p0,
+ float p1) {
bool is_node = false;
}
}
+// ggml_compute_forward_leaky
+
+static void ggml_compute_forward_leaky_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert(dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_leaky_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_leaky(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_leaky_f32(params, src0, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_silu_back
static void ggml_compute_forward_silu_back_f32(
ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
}
-// ggml_compute_forward_pool_2d_sk_p0
+// ggml_compute_forward_pool_2d
-static void ggml_compute_forward_pool_2d_sk_p0(
+static void ggml_compute_forward_pool_2d(
const struct ggml_compute_params * params,
- const enum ggml_op_pool op,
const struct ggml_tensor * src,
- const int k0,
- const int k1,
struct ggml_tensor * dst) {
assert(src->type == GGML_TYPE_F32);
assert(params->ith == 0);
return;
}
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = opts[0];
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
const char * cdata = (const char*)src->data;
const char * const data_end = cdata + ggml_nbytes(src);
float * dplane = (float *)dst->data;
const int ka = k0 * k1;
+ const int offset0 = -p0;
+ const int offset1 = -p1;
while (cdata < data_end) {
for (int oy = 0; oy < py; ++oy) {
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
}
- const int ix = ox * k0;
- const int iy = oy * k1;
+ const int ix = offset0 + ox * s0;
+ const int iy = offset1 + oy * s1;
for (int ky = 0; ky < k1; ++ky) {
+ if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
for (int kx = 0; kx < k0; ++kx) {
int j = ix + kx;
+ if (j < 0 || j >= src->ne[0]) continue;
switch (op) {
case GGML_OP_POOL_AVG: *out += srow[j]; break;
case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
}
}
-// ggml_compute_forward_pool_2d
-
-static void ggml_compute_forward_pool_2d(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * src0,
- struct ggml_tensor * dst) {
-
- const int32_t * opts = (const int32_t *)dst->op_params;
- enum ggml_op_pool op = opts[0];
- const int k0 = opts[1];
- const int k1 = opts[2];
- const int s0 = opts[3];
- const int s1 = opts[4];
- const int p0 = opts[5];
- const int p1 = opts[6];
- GGML_ASSERT(p0 == 0);
- GGML_ASSERT(p1 == 0); // padding not supported
- GGML_ASSERT(k0 == s0);
- GGML_ASSERT(k1 == s1); // only s = k supported
-
- ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
-}
-
// ggml_compute_forward_upscale
static void ggml_compute_forward_upscale_f32(
{
ggml_compute_forward_silu(params, src0, dst);
} break;
+ case GGML_UNARY_OP_LEAKY:
+ {
+ ggml_compute_forward_leaky(params, src0, dst);
+ } break;
default:
{
GGML_ASSERT(false);
////////////////////////////////////////////////////////////////////////////////
-static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+static size_t ggml_hash_size(size_t min_sz) {
+ // next primes after powers of two
+ static const size_t primes[] = {
+ 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
+ 2053, 4099, 8209, 16411, 32771, 65537, 131101,
+ 262147, 524309, 1048583, 2097169, 4194319, 8388617,
+ 16777259, 33554467, 67108879, 134217757, 268435459,
+ 536870923, 1073741827, 2147483659
+ };
+ static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
+
+ // find the smallest prime that is larger than or equal to min_sz
+ size_t l = 0;
+ size_t r = n_primes;
+ while (l < r) {
+ size_t m = (l + r)/2;
+ if (primes[m] < min_sz) {
+ l = m + 1;
+ } else {
+ r = m;
+ }
+ }
+ size_t sz = l < n_primes ? primes[l] : min_sz | 1;
+ return sz;
+}
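+
+// e.g. ggml_hash_size(2048) returns 2053 (the next prime after 2^11); a request beyond the
+// last table entry falls back to min_sz | 1 (forced odd)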
-static size_t hash(void * p) {
- return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+static size_t ggml_hash(const void * p) {
+ return (size_t)p;
}
-static size_t hash_find(void * hash_table[], void * p) {
- size_t h = hash(p);
+size_t ggml_hash_find(const struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+ size_t h = ggml_hash(key) % hash_set.size;
// linear probing
size_t i = h;
- while (hash_table[i] != NULL && hash_table[i] != p) {
- i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+ while (hash_set.keys[i] != NULL && hash_set.keys[i] != key) {
+ i = (i + 1) % hash_set.size;
if (i == h) {
// visited all hash table entries -> not found
- return GGML_GRAPH_HASHTABLE_SIZE;
+ return GGML_HASHTABLE_FULL;
}
}
return i;
}
-static bool hash_insert(void * hash_table[], void * p) {
- size_t i = hash_find(hash_table, p);
+bool ggml_hash_contains(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+ size_t i = ggml_hash_find(hash_set, key);
+ return i != GGML_HASHTABLE_FULL && hash_set.keys[i] == key;
+}
- GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
+size_t ggml_hash_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+ size_t i = ggml_hash_find(hash_set, key);
- if (hash_table[i] == p) {
- return true;
+ GGML_ASSERT(i != GGML_HASHTABLE_FULL);
+
+ if (hash_set.keys[i] == key) {
+ return GGML_HASHTABLE_ALREADY_EXISTS;
}
// insert
- GGML_ASSERT(hash_table[i] == NULL);
- hash_table[i] = p;
- return false;
+ GGML_ASSERT(hash_set.keys[i] == NULL);
+ hash_set.keys[i] = key;
+ return i;
+}
+
+size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) {
+ size_t i = ggml_hash_find(hash_set, key);
+
+ GGML_ASSERT(i != GGML_HASHTABLE_FULL);
+
+ hash_set.keys[i] = key;
+ return i;
+}
+
+static struct ggml_hash_set ggml_hash_set_new(size_t size) {
+ size = ggml_hash_size(size);
+ struct ggml_hash_set result;
+ result.size = size;
+ result.keys = malloc(sizeof(struct ggml_tensor *) * size);
+ memset(result.keys, 0, sizeof(struct ggml_tensor *) * size);
+ return result;
}
-static bool hash_contains(void * hash_table[], void * p) {
- size_t i = hash_find(hash_table, p);
- return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p);
+static void ggml_hash_set_free(struct ggml_hash_set hash_set) {
+ free(hash_set.keys);
}
struct hash_map {
- void * keys[GGML_GRAPH_HASHTABLE_SIZE];
- void * vals[GGML_GRAPH_HASHTABLE_SIZE];
+ struct ggml_hash_set set;
+ struct ggml_tensor ** vals;
};
-static struct hash_map * new_hash_map(void) {
+static struct hash_map * ggml_new_hash_map(size_t size) {
struct hash_map * result = malloc(sizeof(struct hash_map));
- for (int i=0; i<GGML_GRAPH_HASHTABLE_SIZE; ++i) {
- result->keys[i] = NULL;
- result->vals[i] = NULL;
- }
+ result->set = ggml_hash_set_new(size);
+ result->vals = malloc(sizeof(struct ggml_tensor *) * result->set.size);
+ memset(result->vals, 0, sizeof(struct ggml_tensor *) * result->set.size);
return result;
}
-static void free_hash_map(struct hash_map * map) {
+static void ggml_hash_map_free(struct hash_map * map) {
+ ggml_hash_set_free(map->set);
+ free(map->vals);
free(map);
}
return node;
}
- if (!hash_contains(graph->visited_hash_table, node)) {
+ if (!ggml_hash_contains(graph->visited_hash_table, node)) {
return node;
}
return node;
}
- size_t i = hash_find(replacements->keys, node);
- GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
- if (replacements->keys[i] == node) {
- return (struct ggml_tensor *) replacements->vals[i];
+ size_t i = ggml_hash_find(replacements->set, node);
+ GGML_ASSERT(i != GGML_HASHTABLE_FULL); // assert that not full
+ if (replacements->set.keys[i] == node) {
+ return replacements->vals[i];
}
struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne);
// insert clone into replacements
- GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite
- replacements->keys[i] = node;
+ GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
+ replacements->set.keys[i] = node;
replacements->vals[i] = clone;
clone->op = node->op;
struct ggml_cgraph * gb_tmp,
struct ggml_tensor * * checkpoints,
int n_checkpoints) {
- *gb_tmp = *gf;
+ ggml_graph_cpy(gf, gb_tmp);
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
if (n_checkpoints <= 0) {
- *gb = *gb_tmp;
+ ggml_graph_cpy(gb_tmp, gb);
return;
}
- struct hash_map * replacements = new_hash_map();
+ struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints);
// insert checkpoints in replacements
for (int i = 0; i < n_checkpoints; ++i) {
- size_t k = hash_find(replacements->keys, checkpoints[i]);
- GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full
- GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite
- replacements->keys[k] = checkpoints[i];
- replacements->vals[k] = checkpoints[i];
+ size_t k = ggml_hash_find(replacements->set, checkpoints[i]);
+ GGML_ASSERT(k != GGML_HASHTABLE_FULL); // assert that not full
+ GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite
+ replacements->set.keys[k] = checkpoints[i];
+ replacements->vals[k] = checkpoints[i];
}
- *gb = *gf;
+ ggml_graph_cpy(gf, gb);
// rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
// replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
// by recomputing them from checkpoints
ggml_build_forward_expand(gb, node);
}
- free_hash_map(replacements);
+ ggml_hash_map_free(replacements);
}
// functions to change gradients considering the case that input a might be initial gradient with zero value
-static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
- if (hash_contains(zero_table, a)) {
+static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+ if (ggml_hash_contains(zero_table, a)) {
return b;
} else {
return ggml_add_impl(ctx, a, b, false);
}
}
-static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
- if (hash_contains(zero_table, a)) {
+static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) {
+ if (ggml_hash_contains(zero_table, a)) {
struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
} else {
}
}
-static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
- if (hash_contains(zero_table, a)) {
+static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+ if (ggml_hash_contains(zero_table, a)) {
return ggml_repeat(ctx, b, a);
} else {
return ggml_add1_impl(ctx, a, b, false);
}
}
-static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
- if (hash_contains(zero_table, a)) {
+static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) {
+ if (ggml_hash_contains(zero_table, a)) {
return ggml_neg(ctx, b);
} else {
return ggml_sub_impl(ctx, a, b, false);
}
}
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) {
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
struct ggml_tensor * src0 = tensor->src[0];
struct ggml_tensor * src1 = tensor->src[1];
}
// check if already visited
- if (hash_insert(cgraph->visited_hash_table, node)) {
+ if (ggml_hash_insert(cgraph->visited_hash_table, node) == GGML_HASHTABLE_ALREADY_EXISTS) {
return;
}
if (node->op == GGML_OP_NONE && node->grad == NULL) {
// reached a leaf node, not part of the gradient graph (e.g. a constant)
- GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES);
+ GGML_ASSERT(cgraph->n_leafs < cgraph->size);
if (strlen(node->name) == 0) {
ggml_format_name(node, "leaf_%d", cgraph->n_leafs);
cgraph->leafs[cgraph->n_leafs] = node;
cgraph->n_leafs++;
} else {
- GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES);
+ GGML_ASSERT(cgraph->n_nodes < cgraph->size);
if (strlen(node->name) == 0) {
ggml_format_name(node, "node_%d", cgraph->n_nodes);
}
cgraph->nodes[cgraph->n_nodes] = node;
- cgraph->grads[cgraph->n_nodes] = node->grad;
+ if (cgraph->grads) {
+ cgraph->grads[cgraph->n_nodes] = node->grad;
+ }
cgraph->n_nodes++;
}
}
static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
if (!expand) {
- cgraph->n_nodes = 0;
- cgraph->n_leafs = 0;
+ // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
+ ggml_graph_clear(cgraph);
}
const int n0 = cgraph->n_nodes;
ggml_build_forward_impl(cgraph, tensor, true);
}
-struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
- struct ggml_cgraph result = {
- /*.n_nodes =*/ 0,
- /*.n_leafs =*/ 0,
- /*.nodes =*/ { NULL },
- /*.grads =*/ { NULL },
- /*.leafs =*/ { NULL },
- /*.hash_table =*/ { NULL },
- /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
- /*.perf_runs =*/ 0,
- /*.perf_cycles =*/ 0,
- /*.perf_time_us =*/ 0,
- };
-
- ggml_build_forward_impl(&result, tensor, false);
-
- return result;
-}
-
void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
GGML_ASSERT(gf->n_nodes > 0);
}
// remember original gradients which start with zero values
- void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
- memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
+ struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
for (int i = 0; i < gf->n_nodes; i++) {
if (gf->grads[i]) {
- hash_insert(zero_table, gf->grads[i]);
+ ggml_hash_insert(zero_table, gf->grads[i]);
}
}
}
}
- free(zero_table);
+ ggml_hash_set_free(zero_table);
}
-struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
- struct ggml_cgraph result = *gf;
- ggml_build_backward_expand(ctx, gf, &result, keep);
- return result;
+static size_t ggml_graph_nbytes(size_t size, bool grads) {
+ size_t nbytes = sizeof(struct ggml_cgraph);
+ nbytes += size * sizeof(struct ggml_tensor *) * 2; // leafs + nodes
+ if (grads) {
+ nbytes += size * sizeof(struct ggml_tensor *); // grads
+ }
+ nbytes += ggml_hash_size(size * 2) * sizeof(struct ggml_tensor *); // hash set
+ return nbytes;
}
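+
+// e.g. on a 64-bit build, ggml_graph_nbytes(2048, /*grads=*/false) is
+// sizeof(struct ggml_cgraph) + 2048*2*8 bytes (nodes + leafs) + 4099*8 bytes for the
+// hash keys, since ggml_hash_size(2*2048) == 4099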
-struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
- struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE);
+size_t ggml_graph_overhead_custom(size_t size, bool grads) {
+ return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN);
+}
+
+size_t ggml_graph_overhead(void) {
+ return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false);
+}
+
+struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) {
+ const size_t obj_size = ggml_graph_nbytes(size, grads);
+ struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size);
struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
+ struct ggml_tensor ** data_start = (struct ggml_tensor **) (cgraph + 1);
+
+ size_t hash_size = ggml_hash_size(size * 2);
+ struct ggml_tensor ** nodes_ptr = data_start;
+ struct ggml_tensor ** leafs_ptr = nodes_ptr + size;
+ struct ggml_tensor ** hash_keys_ptr = leafs_ptr + size;
+ struct ggml_tensor ** grads_ptr = grads ? hash_keys_ptr + hash_size : NULL;
+
+ // check that we allocated the correct amount of memory
+ assert(obj_size == (size_t) (
+ (grads ? (char *)(grads_ptr + size) : (char *)(hash_keys_ptr + hash_size)) - (char *)cgraph));
+
+ memset(hash_keys_ptr, 0, hash_size * sizeof(struct ggml_tensor *));
+
*cgraph = (struct ggml_cgraph) {
+ /*.size =*/ size,
/*.n_nodes =*/ 0,
/*.n_leafs =*/ 0,
- /*.nodes =*/ { NULL },
- /*.grads =*/ { NULL },
- /*.leafs =*/ { NULL },
- /*.hash_table =*/ { NULL },
+ /*.nodes =*/ nodes_ptr,
+ /*.grads =*/ grads_ptr,
+ /*.leafs =*/ leafs_ptr,
+ /*.hash_table =*/ { hash_size, hash_keys_ptr },
/*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
return cgraph;
}
-struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) {
- struct ggml_cgraph * cgraph = ggml_new_graph(ctx);
- ggml_build_forward_impl(cgraph, tensor, false);
+struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
+ return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
+}
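+
+// minimal sketch of sizing a no-alloc context for a large custom graph; `n_nodes_max` is a
+// caller-chosen upper bound (hypothetical name):
+//
+//   struct ggml_init_params params = {
+//       /*.mem_size   =*/ ggml_tensor_overhead()*n_nodes_max + ggml_graph_overhead_custom(n_nodes_max, true),
+//       /*.mem_buffer =*/ NULL,
+//       /*.no_alloc   =*/ true,
+//   };
+//   struct ggml_context * ctx = ggml_init(params);
+//   struct ggml_cgraph  * gf  = ggml_new_graph_custom(ctx, n_nodes_max, /*grads=*/true);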
+
+struct ggml_cgraph * ggml_graph_view(struct ggml_context * ctx, struct ggml_cgraph * cgraph0, int i0, int i1) {
+ const size_t obj_size = sizeof(struct ggml_cgraph);
+ struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, obj_size);
+ struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
+
+ *cgraph = (struct ggml_cgraph) {
+ /*.size =*/ 0,
+ /*.n_nodes =*/ i1 - i0,
+ /*.n_leafs =*/ 0,
+ /*.nodes =*/ cgraph0->nodes + i0,
+ /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
+ /*.leafs =*/ NULL,
+ /*.hash_table =*/ { 0, NULL },
+ /*.order =*/ cgraph0->order,
+ /*.perf_runs =*/ 0,
+ /*.perf_cycles =*/ 0,
+ /*.perf_time_us =*/ 0,
+ };
+
return cgraph;
}
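+
+// a graph view shares the node (and grad) pointers of the parent graph; with size == 0 and
+// no hash table it cannot grow, so it is meant for computing a sub-range of an existing
+// graph (as the scheduler splits do), e.g.:
+//
+//   struct ggml_cgraph * part = ggml_graph_view(ctx, gf, 0, gf->n_nodes/2);
+//   ggml_graph_compute_with_ctx(ctx, part, /*n_threads=*/4);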
-size_t ggml_graph_overhead(void) {
- return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN);
+void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
+ GGML_ASSERT(dst->size >= src->n_leafs);
+ GGML_ASSERT(dst->size >= src->n_nodes);
+ GGML_ASSERT(dst->visited_hash_table.size >= src->visited_hash_table.size);
+
+ dst->n_leafs = src->n_leafs;
+ dst->n_nodes = src->n_nodes;
+ dst->order = src->order;
+
+ for (int i = 0; i < src->n_leafs; ++i) {
+ dst->leafs[i] = src->leafs[i];
+ }
+
+ for (int i = 0; i < src->n_nodes; ++i) {
+ dst->nodes[i] = src->nodes[i];
+ }
+
+ if (src->grads) {
+ GGML_ASSERT(dst->grads != NULL);
+ for (int i = 0; i < src->n_nodes; ++i) {
+ dst->grads[i] = src->grads[i];
+ }
+ }
+
+ for (size_t i = 0; i < src->visited_hash_table.size; ++i) {
+ if (src->visited_hash_table.keys[i]) {
+ ggml_hash_insert(dst->visited_hash_table, src->visited_hash_table.keys[i]);
+ }
+ }
+}
+
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+ struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+ ggml_graph_cpy(cgraph, result);
+ return result;
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+ GGML_ASSERT(cgraph->grads != NULL);
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ struct ggml_tensor * grad = cgraph->grads[i];
+
+ if (grad) {
+ ggml_set_zero(grad);
+ }
+ }
+}
+
+void ggml_graph_clear(struct ggml_cgraph * cgraph) {
+ cgraph->n_leafs = 0;
+ cgraph->n_nodes = 0;
+ memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *));
}
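+
+// sketch: ggml_graph_reset zeroes the gradients before another backward pass, while
+// ggml_graph_clear empties the graph so it can be rebuilt in place, e.g.:
+//
+//   ggml_graph_clear(gf);
+//   ggml_build_forward_expand(gf, new_output);   // `new_output` built by the caller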
//
node->perf_time_us += time_us_cur;
}
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+ int n_tasks = 0;
+
+ switch (node->op) {
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_ADD:
+ case GGML_OP_ADD1:
+ case GGML_OP_ACC:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_SUB:
+ case GGML_OP_DIV:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_LOG:
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
+ case GGML_OP_ARGMAX:
+ case GGML_OP_REPEAT:
+ case GGML_OP_REPEAT_BACK:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(node)) {
+ case GGML_UNARY_OP_ABS:
+ case GGML_UNARY_OP_SGN:
+ case GGML_UNARY_OP_NEG:
+ case GGML_UNARY_OP_STEP:
+ case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_ELU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_LEAKY:
+ {
+ n_tasks = 1;
+ } break;
+
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_SILU:
+ {
+ n_tasks = n_threads;
+ } break;
+ }
+ break;
+ case GGML_OP_SILU_BACK:
+ case GGML_OP_MUL:
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_CONCAT:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_MUL_MAT:
+ {
+ n_tasks = n_threads;
+
+ // TODO: use different scheduling for different matrix sizes
+ //const int nr0 = ggml_nrows(node->src[0]);
+ //const int nr1 = ggml_nrows(node->src[1]);
+
+ //n_tasks = MIN(n_threads, MAX(1, nr0/128));
+ //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
+
+#if defined(GGML_USE_CUBLAS)
+ if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
+ n_tasks = 1; // TODO: this actually is doing nothing
+ // the threads are still spinning
+ }
+#elif defined(GGML_USE_CLBLAST)
+ if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
+ n_tasks = 1; // TODO: this actually is doing nothing
+ // the threads are still spinning
+ }
+#endif
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+ if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
+ n_tasks = 1; // TODO: this actually is doing nothing
+ // the threads are still spinning
+ }
+#endif
+ } break;
+ case GGML_OP_OUT_PROD:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_SCALE:
+ case GGML_OP_SET:
+ case GGML_OP_CONT:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_GET_ROWS_BACK:
+ case GGML_OP_DIAG:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_DIAG_MASK_ZERO:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
+ case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
+ case GGML_OP_ADD_REL_POS:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_ALIBI:
+ {
+ n_tasks = 1; //TODO
+ } break;
+ case GGML_OP_CLAMP:
+ {
+ n_tasks = 1; //TODO
+ } break;
+ case GGML_OP_CONV_1D:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CONV_1D_STAGE_0:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CONV_1D_STAGE_1:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CONV_2D:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CONV_2D_STAGE_0:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CONV_2D_STAGE_1:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_POOL_1D:
+ case GGML_OP_POOL_2D:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_UPSCALE:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_FLASH_ATTN:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_FLASH_FF:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_FLASH_ATTN_BACK:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_WIN_PART:
+ case GGML_OP_WIN_UNPART:
+ case GGML_OP_GET_REL_POS:
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ case GGML_OP_MAP_CUSTOM1_F32:
+ case GGML_OP_MAP_CUSTOM2_F32:
+ case GGML_OP_MAP_CUSTOM3_F32:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_MAP_CUSTOM1:
+ {
+ struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params;
+ if (p->n_tasks == GGML_N_TASKS_MAX) {
+ n_tasks = n_threads;
+ } else {
+ n_tasks = MIN(p->n_tasks, n_threads);
+ }
+ } break;
+ case GGML_OP_MAP_CUSTOM2:
+ {
+ struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params;
+ if (p->n_tasks == GGML_N_TASKS_MAX) {
+ n_tasks = n_threads;
+ } else {
+ n_tasks = MIN(p->n_tasks, n_threads);
+ }
+ } break;
+ case GGML_OP_MAP_CUSTOM3:
+ {
+ struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params;
+ if (p->n_tasks == GGML_N_TASKS_MAX) {
+ n_tasks = n_threads;
+ } else {
+ n_tasks = MIN(p->n_tasks, n_threads);
+ }
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_NONE:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+
+ assert(n_tasks > 0);
+
+ return n_tasks;
+}
+
static thread_ret_t ggml_graph_compute_thread(void * data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
const struct ggml_cgraph * cgraph = state->shared->cgraph;
const struct ggml_cplan * cplan = state->shared->cplan;
- const int * n_tasks_arr = cplan->n_tasks;
const int n_threads = state->shared->n_threads;
set_numa_thread_affinity(state->ith, n_threads);
if (node_n != -1) {
/* FINALIZE */
- struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
+ struct ggml_tensor * node = cgraph->nodes[node_n];
if (GGML_OP_HAS_FINALIZE[node->op]) {
- params.nth = n_tasks_arr[node_n];
+ params.nth = ggml_get_n_tasks(node, n_threads);
ggml_compute_forward(¶ms, node);
}
ggml_graph_compute_perf_stats_node(node, state->shared);
GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
struct ggml_tensor * node = cgraph->nodes[node_n];
- const int n_tasks = n_tasks_arr[node_n];
+ const int n_tasks = ggml_get_n_tasks(node, n_threads);
state->shared->perf_node_start_cycles = ggml_perf_cycles();
state->shared->perf_node_start_time_us = ggml_perf_time_us();
/* COMPUTE */
struct ggml_tensor * node = cgraph->nodes[node_n];
- const int n_tasks = n_tasks_arr[node_n];
+ const int n_tasks = ggml_get_n_tasks(node, n_threads);
struct ggml_compute_params params = {
/*.type =*/ GGML_TASK_COMPUTE,
struct ggml_tensor * node = cgraph->nodes[i];
+ size_t cur = 0;
+
switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
{
n_tasks = n_threads;
- size_t cur = 0;
if (ggml_is_quantized(node->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
}
-
- work_size = MAX(work_size, cur);
} break;
case GGML_OP_ADD:
case GGML_OP_ADD1:
{
n_tasks = n_threads;
- size_t cur = 0;
-
if (ggml_is_quantized(node->src[0]->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
-
- work_size = MAX(work_size, cur);
} break;
case GGML_OP_ACC:
{
n_tasks = n_threads;
- size_t cur = 0;
-
if (ggml_is_quantized(node->src[0]->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
}
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_SUB:
- case GGML_OP_DIV:
- case GGML_OP_SQR:
- case GGML_OP_SQRT:
- case GGML_OP_LOG:
- case GGML_OP_SUM:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_MEAN:
- case GGML_OP_ARGMAX:
- case GGML_OP_REPEAT:
- case GGML_OP_REPEAT_BACK:
- {
- n_tasks = 1;
- } break;
-
- case GGML_OP_UNARY:
- {
- switch (ggml_get_unary_op(node)) {
- case GGML_UNARY_OP_ABS:
- case GGML_UNARY_OP_SGN:
- case GGML_UNARY_OP_NEG:
- case GGML_UNARY_OP_STEP:
- case GGML_UNARY_OP_TANH:
- case GGML_UNARY_OP_ELU:
- case GGML_UNARY_OP_RELU:
- {
- n_tasks = 1;
- } break;
-
- case GGML_UNARY_OP_GELU:
- case GGML_UNARY_OP_GELU_QUICK:
- case GGML_UNARY_OP_SILU:
- {
- n_tasks = n_threads;
- } break;
- }
} break;
- case GGML_OP_SILU_BACK:
- case GGML_OP_MUL:
- case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_RMS_NORM_BACK:
- case GGML_OP_GROUP_NORM:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_CONCAT:
case GGML_OP_MUL_MAT:
{
- n_tasks = n_threads;
-
- // TODO: use different scheduling for different matrix sizes
- //const int nr0 = ggml_nrows(node->src[0]);
- //const int nr1 = ggml_nrows(node->src[1]);
-
- //n_tasks = MIN(n_threads, MAX(1, nr0/128));
- //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
-
- size_t cur = 0;
const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
-#if defined(GGML_USE_CUBLAS)
- if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
- n_tasks = 1; // TODO: this actually is doing nothing
- // the threads are still spinning
- } else
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
- n_tasks = 1; // TODO: this actually is doing nothing
- // the threads are still spinning
cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
} else
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
- n_tasks = 1; // TODO: this actually is doing nothing
- // the threads are still spinning
if (node->src[0]->type != GGML_TYPE_F32) {
// here we need memory just for single 2D matrix from src0
cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
#endif
if (node->src[1]->type != vec_dot_type) {
cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
- } else {
- cur = 0;
}
-
- work_size = MAX(work_size, cur);
} break;
case GGML_OP_OUT_PROD:
{
n_tasks = n_threads;
- size_t cur = 0;
-
if (ggml_is_quantized(node->src[0]->type)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
}
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_SCALE:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_SET:
- case GGML_OP_CONT:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_GET_ROWS:
- case GGML_OP_GET_ROWS_BACK:
- case GGML_OP_DIAG:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_DIAG_MASK_ZERO:
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_SOFT_MAX_BACK:
- case GGML_OP_ROPE:
- case GGML_OP_ROPE_BACK:
- case GGML_OP_ADD_REL_POS:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_ALIBI:
- {
- n_tasks = 1; //TODO
- } break;
- case GGML_OP_CLAMP:
- {
- n_tasks = 1; //TODO
} break;
case GGML_OP_CONV_1D:
{
- n_tasks = n_threads;
-
GGML_ASSERT(node->src[0]->ne[3] == 1);
GGML_ASSERT(node->src[1]->ne[2] == 1);
GGML_ASSERT(node->src[1]->ne[3] == 1);
UNUSED(ne10);
UNUSED(ne11);
- size_t cur = 0;
-
if (node->src[0]->type == GGML_TYPE_F16 &&
node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
} else {
GGML_ASSERT(false);
}
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_CONV_1D_STAGE_0:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_CONV_1D_STAGE_1:
- {
- n_tasks = n_threads;
} break;
case GGML_OP_CONV_TRANSPOSE_1D:
{
- n_tasks = n_threads;
-
GGML_ASSERT(node->src[0]->ne[3] == 1);
GGML_ASSERT(node->src[1]->ne[2] == 1);
GGML_ASSERT(node->src[1]->ne[3] == 1);
const int64_t ne10 = node->src[1]->ne[0]; // L
const int64_t ne11 = node->src[1]->ne[1]; // Cin
- size_t cur = 0;
if (node->src[0]->type == GGML_TYPE_F16 &&
node->src[1]->type == GGML_TYPE_F32) {
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
} else {
GGML_ASSERT(false);
}
-
- work_size = MAX(work_size, cur);
} break;
case GGML_OP_CONV_2D:
{
- n_tasks = n_threads;
-
const int64_t ne00 = node->src[0]->ne[0]; // W
const int64_t ne01 = node->src[0]->ne[1]; // H
const int64_t ne02 = node->src[0]->ne[2]; // C
UNUSED(ne03);
UNUSED(ne2);
- size_t cur = 0;
-
if (node->src[0]->type == GGML_TYPE_F16 &&
node->src[1]->type == GGML_TYPE_F32) {
// im2col: [N*OH*OW, IC*KH*KW]
} else {
GGML_ASSERT(false);
}
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_CONV_2D_STAGE_0:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_CONV_2D_STAGE_1:
- {
- n_tasks = n_threads;
} break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
- n_tasks = n_threads;
-
const int64_t ne00 = node->src[0]->ne[0]; // W
const int64_t ne01 = node->src[0]->ne[1]; // H
const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
const int64_t ne11 = node->src[1]->ne[1]; // H
const int64_t ne12 = node->src[1]->ne[2]; // Channels In
- size_t cur = 0;
cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_POOL_1D:
- case GGML_OP_POOL_2D:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_UPSCALE:
- {
- n_tasks = n_threads;
} break;
case GGML_OP_FLASH_ATTN:
{
n_tasks = n_threads;
- size_t cur = 0;
-
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
- }
-
- if (node->src[1]->type == GGML_TYPE_F16) {
+ } else if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
-
- work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_FF:
{
n_tasks = n_threads;
- size_t cur = 0;
-
if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
- }
-
- if (node->src[1]->type == GGML_TYPE_F16) {
+ } else if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
}
-
- work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
n_tasks = n_threads;
- size_t cur = 0;
-
const int64_t D = node->src[0]->ne[0];
const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
- }
-
- if (node->src[1]->type == GGML_TYPE_F16) {
+ } else if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_WIN_PART:
- case GGML_OP_WIN_UNPART:
- case GGML_OP_GET_REL_POS:
- case GGML_OP_MAP_UNARY:
- case GGML_OP_MAP_BINARY:
- case GGML_OP_MAP_CUSTOM1_F32:
- case GGML_OP_MAP_CUSTOM2_F32:
- case GGML_OP_MAP_CUSTOM3_F32:
- {
- n_tasks = 1;
- } break;
- case GGML_OP_MAP_CUSTOM1:
- {
- struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params;
- if (p->n_tasks == GGML_N_TASKS_MAX) {
- n_tasks = n_threads;
- } else {
- n_tasks = MIN(p->n_tasks, n_threads);
- }
- } break;
- case GGML_OP_MAP_CUSTOM2:
- {
- struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params;
- if (p->n_tasks == GGML_N_TASKS_MAX) {
- n_tasks = n_threads;
- } else {
- n_tasks = MIN(p->n_tasks, n_threads);
- }
- } break;
- case GGML_OP_MAP_CUSTOM3:
- {
- struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params;
- if (p->n_tasks == GGML_N_TASKS_MAX) {
- n_tasks = n_threads;
- } else {
- n_tasks = MIN(p->n_tasks, n_threads);
- }
} break;
+
case GGML_OP_CROSS_ENTROPY_LOSS:
{
n_tasks = n_threads;
- size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
- {
- n_tasks = n_threads;
- } break;
- case GGML_OP_NONE:
- {
- n_tasks = 1;
+ cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
} break;
+ default:
+ break;
}
- cplan.n_tasks[i] = n_tasks;
+ work_size = MAX(work_size, cur);
}
if (work_size > 0) {
if (cplan->work_size > 0) {
GGML_ASSERT(cplan->work_data);
}
-
- for (int i = 0; i < cgraph->n_nodes; ++i) {
- if (cgraph->nodes[i]->op != GGML_OP_NONE) {
- GGML_ASSERT(cplan->n_tasks[i] > 0);
- }
- }
}
const int n_threads = cplan->n_threads;
return compute_status;
}
-void ggml_graph_reset(struct ggml_cgraph * cgraph) {
- for (int i = 0; i < cgraph->n_nodes; i++) {
- struct ggml_tensor * grad = cgraph->grads[i];
-
- if (grad) {
- ggml_set_zero(grad);
- }
- }
-}
-
void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
const uint32_t magic = GGML_FILE_MAGIC;
const uint32_t version = GGML_FILE_VERSION;
const uint32_t n_leafs = cgraph->n_leafs;
- const uint32_t nodes = cgraph->n_nodes;
+ const uint32_t n_nodes = cgraph->n_nodes;
fwrite(&magic, sizeof(uint32_t), 1, fout);
fwrite(&version, sizeof(uint32_t), 1, fout);
fwrite(&n_leafs, sizeof(uint32_t), 1, fout);
- fwrite(&nodes, sizeof(uint32_t), 1, fout);
+ fwrite(&n_nodes, sizeof(uint32_t), 1, fout);
fwrite(&size_eval, sizeof(uint64_t), 1, fout);
}
if (idx == -1) {
for (int k = 0; k < cgraph->n_nodes; ++k) {
if (args[j] == cgraph->nodes[k]) {
- idx = GGML_MAX_NODES + k;
+ idx = cgraph->n_leafs + k;
break;
}
}
}
}
-struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) {
+struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) {
assert(*ctx_data == NULL);
assert(*ctx_eval == NULL);
- struct ggml_cgraph result = { 0 };
+ struct ggml_cgraph * result = NULL;
struct ggml_tensor * data = NULL;
const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs);
const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes);
const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval);
-
- result.n_leafs = n_leafs;
- result.n_nodes = n_nodes;
+ const int graph_size = MAX(n_leafs, n_nodes);
// create the data context
{
- const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead();
+ const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph_size, false);
struct ggml_init_params params = {
.mem_size = size_eval + overhead,
}
}
+ result = ggml_new_graph_custom(*ctx_eval, graph_size, false);
+
+ result->n_leafs = n_leafs;
+ result->n_nodes = n_nodes;
+
// leafs
{
uint32_t type;
tensor->nb[j] = nb[j];
}
- result.leafs[i] = tensor;
+ result->leafs[i] = tensor;
ptr += ggml_nbytes(tensor);
continue;
}
- if (arg_idx < GGML_MAX_NODES) {
- args[j] = result.leafs[arg_idx];
+ if (arg_idx < result->n_leafs) {
+ args[j] = result->leafs[arg_idx];
} else {
- args[j] = result.nodes[arg_idx - GGML_MAX_NODES];
+ args[j] = result->nodes[arg_idx - result->n_leafs];
}
}
tensor->src[j] = args[j];
}
- result.nodes[i] = tensor;
+ result->nodes[i] = tensor;
fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor));
}
case GGML_OPT_ADAM:
{
result = (struct ggml_opt_params) {
- .type = GGML_OPT_ADAM,
- .n_threads = 1,
- .past = 0,
- .delta = 1e-5f,
+ .type = GGML_OPT_ADAM,
+ .graph_size = GGML_DEFAULT_GRAPH_SIZE,
+ .n_threads = 1, // FIXME: GGML_DEFAULT_N_THREADS ?
+ .past = 0,
+ .delta = 1e-5f,
.max_no_improvement = 100,
case GGML_OPT_LBFGS:
{
result = (struct ggml_opt_params) {
- .type = GGML_OPT_LBFGS,
- .n_threads = 1,
- .past = 0,
- .delta = 1e-5f,
+ .type = GGML_OPT_LBFGS,
+ .graph_size = GGML_DEFAULT_GRAPH_SIZE,
+ .n_threads = 1,
+ .past = 0,
+ .delta = 1e-5f,
.max_no_improvement = 0,
struct ggml_tensor * f) {
// build forward + backward compute graphs
- struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
- struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
-
- struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
- struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, opt->params.graph_size, true);
+ ggml_build_forward_expand(gf, f);
- *gf = ggml_build_forward (f);
- *gb = ggml_build_backward(ctx, gf, true);
+ struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
+ ggml_build_backward_expand(ctx, gf, gb, true);
return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
}
// {
// ...
//
-// struct ggml_cgraph gf = ggml_build_forward(f);
+// struct ggml_cgraph * gf = ggml_new_graph(ctx);
+// ggml_build_forward_expand(gf, f);
//
// // set the input variable and parameter values
// ggml_set_f32(x, 2.0f);
#define GGML_QNT_VERSION 2 // bump this on quantization format changes
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
-#define GGML_MAX_DIMS 4
-#define GGML_MAX_NODES 16384
-#define GGML_MAX_PARAMS 1024
-#define GGML_MAX_CONTEXTS 64
-#define GGML_MAX_SRC 6
-#define GGML_MAX_NAME 64
-#define GGML_MAX_OP_PARAMS 64
-#define GGML_DEFAULT_N_THREADS 4
-
+#define GGML_MAX_DIMS 4
+#define GGML_MAX_PARAMS 1024
+#define GGML_MAX_CONTEXTS 64
+#define GGML_MAX_SRC 6
+#define GGML_MAX_NAME 64
+#define GGML_MAX_OP_PARAMS 64
+#define GGML_DEFAULT_N_THREADS 4
+#define GGML_DEFAULT_GRAPH_SIZE 2048
#if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4
#else
do { \
if (!(x)) { \
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
- abort(); \
+ fflush(stderr); \
+ fflush(stdout); \
+ ggml_print_backtrace(); \
+ exit(1); \
} \
} while (0)
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
+ GGML_UNARY_OP_LEAKY
};
enum ggml_object_type {
int n_threads;
- // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
- int n_tasks[GGML_MAX_NODES];
-
// abort ggml_graph_compute when true
bool (*abort_callback)(void * data);
void * abort_callback_data;
};
- // next prime after GGML_MAX_NODES
- // #define GGML_GRAPH_HASHTABLE_SIZE 4099
- // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
- // #define GGML_GRAPH_HASHTABLE_SIZE 8273
- // #define GGML_GRAPH_HASHTABLE_SIZE 16411
- #define GGML_GRAPH_HASHTABLE_SIZE 32771
-
enum ggml_cgraph_eval_order {
GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
GGML_CGRAPH_EVAL_ORDER_COUNT
};
+ struct ggml_hash_set {
+ size_t size;
+ struct ggml_tensor ** keys;
+ };
+
// computation graph
struct ggml_cgraph {
+ int size;
int n_nodes;
int n_leafs;
- struct ggml_tensor * nodes[GGML_MAX_NODES];
- struct ggml_tensor * grads[GGML_MAX_NODES];
- struct ggml_tensor * leafs[GGML_MAX_NODES];
+ struct ggml_tensor ** nodes;
+ struct ggml_tensor ** grads;
+ struct ggml_tensor ** leafs;
- void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+ struct ggml_hash_set visited_hash_table;
enum ggml_cgraph_eval_order order;
int64_t perf_time_us;
};
- static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
-
// scratch buffer
struct ggml_scratch {
size_t offs;
GGML_API int64_t ggml_cycles(void);
GGML_API int64_t ggml_cycles_per_ms(void);
+ GGML_API void ggml_print_backtrace(void);
+
GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems
GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
// Context tensor enumeration and lookup
GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx);
GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor);
- GGML_API struct ggml_tensor * ggml_get_tensor (struct ggml_context * ctx, const char * name);
+ GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
struct ggml_context * ctx,
struct ggml_tensor * a);
+ GGML_API struct ggml_tensor * ggml_leaky(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
int s0, // stride
int p0); // padding
+ // the result will have 2*p0 padding for the first dimension
+ // and 2*p1 padding for the second dimension
GGML_API struct ggml_tensor * ggml_pool_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k1,
int s0,
int s1,
- int p0,
- int p1);
+ float p0,
+ float p1);
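    // editor's note, illustrative sketch (not part of the patch): a 2x2 average pool with
    // zero padding, assuming "ctx" is an initialized ggml_context; with p0 == p1 == 0,
    // pooling a [64, 64, 3] tensor yields [32, 32, 3].
    //
    //   struct ggml_tensor * inp = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 64, 3);
    //   struct ggml_tensor * out = ggml_pool_2d(ctx, inp, GGML_OP_POOL_AVG,
    //                                           /*k0*/ 2, /*k1*/ 2,
    //                                           /*s0*/ 2, /*s1*/ 2,
    //                                           /*p0*/ 0.0f, /*p1*/ 0.0f);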
// nearest interpolate
// used in stable-diffusion
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
- GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
- GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
-
// graph allocation in a context
- GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx);
- GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
+ GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
+ GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
+ GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+ GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
+ GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
+ GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
+ GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
+
GGML_API size_t ggml_graph_overhead(void);
+ GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
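    // editor's note, illustrative sketch (not part of the patch): since graphs now live inside
    // a ggml_context, the context must be sized for the graph metadata as well as the tensor
    // metadata; "max_nodes" is a hypothetical application-chosen limit and tensor data is
    // assumed to be allocated elsewhere (no_alloc = true).
    //
    //   const size_t max_nodes = 8192;
    //   struct ggml_init_params params = {
    //       /*.mem_size   =*/ ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false),
    //       /*.mem_buffer =*/ NULL,
    //       /*.no_alloc   =*/ true,
    //   };
    //   struct ggml_context * ctx = ggml_init(params);
    //   struct ggml_cgraph  * gf  = ggml_new_graph_custom(ctx, max_nodes, /*grads =*/ false);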
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
- GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
- GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
+ GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
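    // editor's note, illustrative sketch (not part of the patch): the explicit plan/compute
    // sequence with a caller-managed work buffer; "gf" is assumed to be an already-built graph.
    //
    //   struct ggml_cplan cplan = ggml_graph_plan(gf, /*n_threads =*/ 4);
    //   uint8_t * work = NULL;
    //   if (cplan.work_size > 0) {
    //       work = malloc(cplan.work_size);
    //       cplan.work_data = work;
    //   }
    //   ggml_graph_compute(gf, &cplan);
    //   free(work);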
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
- GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
- GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
+ GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
+ GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
// print info and performance information for the graph
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
struct ggml_opt_params {
enum ggml_opt_type type;
+ size_t graph_size;
+
int n_threads;
// delta-based convergence test
#define LLAMA_ATTRIBUTE_FORMAT(...)
#endif
+#define LLAMA_MAX_NODES 4096
+
//
// logging
//
}
struct ggml_cgraph * build_llama() {
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
GGML_ASSERT(n_embd_head == hparams.n_rot);
}
struct ggml_cgraph * build_baichuan() {
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
}
struct ggml_cgraph * build_falcon() {
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
}
struct ggml_cgraph * build_starcoder() {
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
struct ggml_tensor * cur;
struct ggml_tensor * pos;
}
struct ggml_cgraph * build_persimmon() {
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_rot = n_embd_head / 2;
}
struct ggml_cgraph * build_refact() {
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
}
struct ggml_cgraph * build_bloom() {
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
}
struct ggml_cgraph * build_mpt() {
- struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
{
static const size_t tensor_alignment = 32;
// the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
- ctx->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+ ctx->buf_compute.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
// create measure allocator
ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
if (kv_buf_size) {
const size_t elt_size = ggml_element_size(kv_self.k);
- ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
- ggml_cgraph gf{};
+ ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
+ ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
kv_head, n_embd, n_layer,
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
- ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
- ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
- ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
+ ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+ ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+ ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);
const size_t elt_size = ggml_element_size(kv_self.k);
- ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
- ggml_cgraph gf{};
+ ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
+ ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
kin3d->data = (void *) inp;
kv_head, n_embd, n_layer,
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
- ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
- ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
- ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
+ ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d));
+ ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d));
+ ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);
}
cp -rpv ../ggml/src/ggml.c ./ggml.c
cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
+cp -rpv ../ggml/src/ggml-backend-impl.h ./ggml-backend-impl.h
cp -rpv ../ggml/src/ggml-backend.c ./ggml-backend.c
-cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
-cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
-cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
+cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
+cp -rpv ../ggml/src/ggml-impl.h ./ggml-impl.h
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
+cp -rpv ../ggml/src/ggml-mpi.h ./ggml-mpi.h
+cp -rpv ../ggml/src/ggml-mpi.c ./ggml-mpi.c
+cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
+cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
+cp -rpv ../ggml/src/ggml-quants.c ./ggml-quants.c
+cp -rpv ../ggml/src/ggml-quants.h ./ggml-quants.h
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
cp -rpv ../ggml/include/ggml/ggml-backend.h ./ggml-backend.h
printf("GGML_N_THREADS = %d\n", n_threads);
}
- struct ggml_cgraph * gf = ggml_build_forward_ctx(ctx0, f);
- struct ggml_cgraph * gb = ggml_new_graph(ctx0);
- *gb = *gf;
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+ struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
+ ggml_build_forward_expand(gf, f);
+ ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx0, gf, gb, false);
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
struct ggml_tensor * d = ggml_sub(ctx, c, ab);
struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d));
- struct ggml_cgraph ge = ggml_build_forward(e);
- ggml_graph_reset(&ge);
+ struct ggml_cgraph * ge = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+ ggml_build_forward_expand(ge, e);
+ ggml_graph_reset(ge);
- ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+ ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
const float fe = ggml_get_f32_1d(e, 0);
printf("%s: e = %.4f\n", __func__, fe);
ggml_opt(ctx, opt_params, e);
- ggml_graph_reset(&ge);
+ ggml_graph_reset(ge);
- ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+ ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
const float fe_opt = ggml_get_f32_1d(e, 0);
printf("%s: original e = %.4f\n", __func__, fe);