params.mem_buffer = NULL;
params.no_alloc = true;
struct ggml_context * ctx = NULL;
- struct ggml_allocr * alloc = NULL;
- struct ggml_cgraph * gf = NULL;
+ struct ggml_gallocr * alloc = NULL;
+ struct ggml_cgraph * gf = NULL;
ctx = ggml_init(params);
- alloc = ggml_allocr_new_measure(tensor_alignment);
+ alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
- size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf);
- ggml_allocr_free(alloc);
- ggml_free(ctx);
-
- static std::vector<uint8_t> data_compute;
- data_compute.resize(alloc_size + tensor_alignment);
- ctx = ggml_init(params);
- alloc = ggml_allocr_new(data_compute.data(), data_compute.size(), tensor_alignment);
- gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
- ggml_allocr_alloc_graph(alloc, gf);
- ggml_allocr_free(alloc);
+ ggml_gallocr_alloc_graph(alloc, gf);
struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
static std::vector<uint8_t> data_work;
ggml_graph_compute(gf, &cplan);
+ ggml_gallocr_free(alloc);
ggml_free(ctx);
return true;
}
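
For orientation, the hunk above replaces the old two-pass measure/allocate flow with a single graph allocator. A condensed sketch of the new flow, assuming the same build_graph_lora helper and the CPU backend (error handling omitted):

    // no_alloc context: tensor/graph metadata only, no tensor data
    struct ggml_init_params params = { ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), NULL, /*no_alloc*/ true };
    struct ggml_context * ctx   = ggml_init(params);
    ggml_gallocr_t        alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    struct ggml_cgraph  * gf    = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
    ggml_gallocr_alloc_graph(alloc, gf);            // allocates all graph tensors in one backend buffer
    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
    std::vector<uint8_t> work(cplan.work_size);
    cplan.work_data = work.data();
    ggml_graph_compute(gf, &cplan);
    ggml_gallocr_free(alloc);                       // also frees the buffer it allocated
    ggml_free(ctx);
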
#include "ggml.h"
#include "ggml-alloc.h"
+#include "ggml-backend.h"
#include "llama.h"
#include "common.h"
#include "train.h"
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
-static const size_t tensor_alignment = 32;
-
struct my_llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512;
struct my_llama_lora {
struct ggml_context * ctx = NULL;
- std::vector<uint8_t> data;
+ ggml_backend_buffer_t data;
my_llama_lora_hparams hparams;
}
}
-static void alloc_lora(struct ggml_allocr * alloc, struct my_llama_lora * lora) {
- ggml_allocr_alloc(alloc, lora->tok_embeddings_a);
- ggml_allocr_alloc(alloc, lora->tok_embeddings_b);
- ggml_allocr_alloc(alloc, lora->norm_a);
- ggml_allocr_alloc(alloc, lora->norm_b);
- ggml_allocr_alloc(alloc, lora->output_a);
- ggml_allocr_alloc(alloc, lora->output_b);
- for (uint32_t i = 0; i < lora->layers.size(); ++i) {
- auto & layer = lora->layers[i];
- ggml_allocr_alloc(alloc, layer.attention_norm_a);
- ggml_allocr_alloc(alloc, layer.attention_norm_b);
- ggml_allocr_alloc(alloc, layer.wq_a);
- ggml_allocr_alloc(alloc, layer.wq_b);
- ggml_allocr_alloc(alloc, layer.wk_a);
- ggml_allocr_alloc(alloc, layer.wk_b);
- ggml_allocr_alloc(alloc, layer.wv_a);
- ggml_allocr_alloc(alloc, layer.wv_b);
- ggml_allocr_alloc(alloc, layer.wo_a);
- ggml_allocr_alloc(alloc, layer.wo_b);
- ggml_allocr_alloc(alloc, layer.ffn_norm_a);
- ggml_allocr_alloc(alloc, layer.ffn_norm_b);
- ggml_allocr_alloc(alloc, layer.w1_a);
- ggml_allocr_alloc(alloc, layer.w1_b);
- ggml_allocr_alloc(alloc, layer.w2_a);
- ggml_allocr_alloc(alloc, layer.w2_b);
- ggml_allocr_alloc(alloc, layer.w3_a);
- ggml_allocr_alloc(alloc, layer.w3_b);
- }
- ggml_allocr_alloc(alloc, lora->tok_embeddings_a->grad);
- ggml_allocr_alloc(alloc, lora->tok_embeddings_b->grad);
- ggml_allocr_alloc(alloc, lora->norm_a->grad);
- ggml_allocr_alloc(alloc, lora->norm_b->grad);
- ggml_allocr_alloc(alloc, lora->output_a->grad);
- ggml_allocr_alloc(alloc, lora->output_b->grad);
- for (uint32_t i = 0; i < lora->layers.size(); ++i) {
- auto & layer = lora->layers[i];
- ggml_allocr_alloc(alloc, layer.attention_norm_a->grad);
- ggml_allocr_alloc(alloc, layer.attention_norm_b->grad);
- ggml_allocr_alloc(alloc, layer.wq_a->grad);
- ggml_allocr_alloc(alloc, layer.wq_b->grad);
- ggml_allocr_alloc(alloc, layer.wk_a->grad);
- ggml_allocr_alloc(alloc, layer.wk_b->grad);
- ggml_allocr_alloc(alloc, layer.wv_a->grad);
- ggml_allocr_alloc(alloc, layer.wv_b->grad);
- ggml_allocr_alloc(alloc, layer.wo_a->grad);
- ggml_allocr_alloc(alloc, layer.wo_b->grad);
- ggml_allocr_alloc(alloc, layer.ffn_norm_a->grad);
- ggml_allocr_alloc(alloc, layer.ffn_norm_b->grad);
- ggml_allocr_alloc(alloc, layer.w1_a->grad);
- ggml_allocr_alloc(alloc, layer.w1_b->grad);
- ggml_allocr_alloc(alloc, layer.w2_a->grad);
- ggml_allocr_alloc(alloc, layer.w2_b->grad);
- ggml_allocr_alloc(alloc, layer.w3_a->grad);
- ggml_allocr_alloc(alloc, layer.w3_b->grad);
- }
-}
-
static void init_lora(const struct my_llama_model * model, struct my_llama_lora * lora) {
const auto & lparams = lora->hparams;
set_param_lora(lora);
- // measure data size
- size_t size = 0;
- for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
- }
-
- // allocate data
- struct ggml_allocr * alloc = NULL;
- lora->data.resize(size + tensor_alignment);
- alloc = ggml_allocr_new(lora->data.data(), lora->data.size(), tensor_alignment);
- alloc_lora(alloc, lora);
- ggml_allocr_free(alloc);
+ // allocate data for lora tensors
+ lora->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
}
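
ggml_backend_alloc_ctx_tensors_from_buft walks every tensor created in the no_alloc context and places them in a single backend buffer, replacing the manual size measurement removed above. A minimal sketch with hypothetical tensors a and b:

    struct ggml_init_params ip = { 2*ggml_tensor_overhead(), NULL, /*no_alloc*/ true };
    struct ggml_context * ctx = ggml_init(ip);
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);   // metadata only
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
    // a->data and b->data now point into buf; keep buf alive as long as the tensors are used
    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
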
static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) {
static struct ggml_tensor * llama_build_lora_finetune_graphs(
struct my_llama_model * model,
struct my_llama_lora * lora,
- struct ggml_allocr * alloc,
+ ggml_gallocr_t alloc,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
const int n_tokens,
const int n_batch,
const bool enable_flash_attn,
- const bool enable_checkpointing) {
+ const bool enable_checkpointing,
+ const bool measure_only) {
ggml_set_scratch(ctx, { 0, 0, nullptr, });
const int n_past = 0;
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
- ggml_allocr_alloc(alloc, KQ_pos);
- if (!ggml_allocr_is_measure(alloc)) {
- int * data = (int *) KQ_pos->data;
- for (int i = 0; i < N; ++i) {
- data[i] = n_past + i;
- }
- }
+ ggml_set_input(KQ_pos);
// rope has so many parameters that we make a custom function for it
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
// input gradient
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
- ggml_allocr_alloc(alloc, t36->grad);
+ ggml_set_input(t36->grad);
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
// note: they will be freed in reverse order
for (unsigned int i = 0; i < checkpoints.size(); ++i) {
if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
- ggml_allocr_alloc(alloc, checkpoints[i]);
+ ggml_set_input(checkpoints[i]);
}
}
- ggml_allocr_alloc_graph(alloc, gb);
+ if (measure_only) {
+ ggml_gallocr_reserve(alloc, gb);
+ } else {
+ ggml_gallocr_alloc_graph(alloc, gb);
+
+ // set KQ_pos
+ {
+ int * data = (int *) KQ_pos->data;
+ for (int i = 0; i < N; ++i) {
+ data[i] = n_past + i;
+ }
+ }
+ }
// remove the additional nodes and leafs
for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
- printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + lora.data.size()), (float) (ggml_used_mem(lora.ctx) + lora.data.size()) / (1024.0f*1024.0f));
+ printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)), (float) (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)) / (1024.0f*1024.0f));
if (params.only_write_lora) {
save_train_files_data save_data;
int n_vocab = model.hparams.n_vocab;
int n_batch = params.common.n_batch;
-
- std::vector<uint8_t> mem_input_data;
- std::vector<uint8_t> mem_compute_data;
-
// context for input tensors without their data
struct ggml_init_params ctx_input_params = {
ggml_tensor_overhead() * 2, // mem_size
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+ // allocate input tensors
// measure required memory for input tensors
- size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
- GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
- tensor_alignment;
+ ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
+ size_t max_input_size = ggml_backend_buffer_get_size(input_data);
printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
- // allocate input tensors
- mem_input_data.resize(max_input_size);
- ggml_allocr_t alloc_inps = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
- ggml_allocr_alloc(alloc_inps, tokens_input);
- ggml_allocr_alloc(alloc_inps, target_probs);
-
// context for compute tensors without their data
const size_t estimated_compute_size_wo_data = (
2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
// find best evaluation order
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
- ggml_allocr_t alloc = ggml_allocr_new_measure(tensor_alignment);
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
- params.common.use_checkpointing
+ params.common.use_checkpointing,
+ true
);
- size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+ size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
if (max_compute_size < best_compute_size) {
best_compute_size = max_compute_size;
best_order = gf->order;
}
- ggml_allocr_free(alloc);
+ ggml_gallocr_free(alloc);
ggml_free(ctx_compute);
}
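
The FIXME above records that measuring this way still allocates a real backend buffer for each candidate order. A sketch of the measure-then-allocate pattern being used (gb as in the surrounding code):

    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    ggml_gallocr_reserve(alloc, gb);                      // size the allocator for the worst case
    size_t size = ggml_gallocr_get_buffer_size(alloc, 0); // index 0: the allocator's only buffer
    // ...
    ggml_gallocr_alloc_graph(alloc, gb);                  // assign tensor data within the reserved buffer
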
size_t max_compute_size = best_compute_size;
"invalid");
// allocate compute tensors
- mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
- ggml_allocr_t alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
- params.common.use_checkpointing
+ params.common.use_checkpointing,
+ false
);
- ggml_allocr_free(alloc);
- ggml_allocr_free(alloc_inps);
-
// tokenize data
std::vector<llama_token> train_tokens;
ggml_free(ctx_work);
ggml_free(ctx_compute);
ggml_free(ctx_input);
+ ggml_gallocr_free(alloc);
+
int64_t t1 = ggml_time_ms();
printf("%s: total training time: ", __func__);
ggml_backend_buffer_t params_buffer = NULL;
ggml_backend_buffer_t compute_buffer = NULL;
ggml_backend_t backend = NULL;
- ggml_allocr * compute_alloc = NULL;
+ ggml_gallocr_t compute_alloc = NULL;
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
- ggml_allocr_alloc(ctx->compute_alloc, inp_raw);
-
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- float * data = (float *)malloc(ggml_nbytes(inp_raw));
-
- for (size_t i = 0; i < imgs->size; i++) {
- const int nx = imgs->data[i].nx;
- const int ny = imgs->data[i].ny;
- GGML_ASSERT(nx == image_size && ny == image_size);
-
- const int n = nx * ny;
-
- for (int b = 0; b < batch_size; b++) {
- for (int k = 0; k < 3; k++) {
- for (int y = 0; y < ny; y++) {
- for (int x = 0; x < nx; x++) {
- data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
- }
- }
- }
- }
- }
- ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
- free(data);
- }
+ ggml_set_name(inp_raw, "inp_raw");
+ ggml_set_input(inp_raw);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
// concat class_embeddings and patch_embeddings
struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
- ggml_allocr_alloc(ctx->compute_alloc, embeddings);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- void* zero_mem = malloc(ggml_nbytes(embeddings));
- memset(zero_mem, 0, ggml_nbytes(embeddings));
- ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
- free(zero_mem);
- }
+ ggml_set_name(embeddings, "embeddings");
+ ggml_set_input(embeddings);
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
- ggml_allocr_alloc(ctx->compute_alloc, positions);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- int* positions_data = (int*)malloc(ggml_nbytes(positions));
- for (int i = 0; i < num_positions; i++) {
- positions_data[i] = i;
- }
- ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
- free(positions_data);
- }
+ ggml_set_name(positions, "positions");
+ ggml_set_input(positions);
embeddings =
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
- ggml_allocr_alloc(ctx->compute_alloc, patches);
- if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
- int* patches_data = (int*)malloc(ggml_nbytes(patches));
- for (int i = 0; i < num_patches; i++) {
- patches_data[i] = i + 1;
- }
- ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
- free(patches_data);
- }
+ ggml_set_name(patches, "patches");
+ ggml_set_input(patches);
// shape [1, 576, 1024]
// ne is whcn, ne = [1024, 576, 1, 1]
}
// data
- size_t buffer_size = 0;
+ size_t model_size = 0;
{
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i);
enum ggml_type type = gguf_get_tensor_type(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(meta, name);
size_t tensor_size = ggml_nbytes(cur);
- buffer_size += tensor_size;
+ model_size += tensor_size;
if (verbosity >= 3) {
printf("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
__func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
}
}
- buffer_size += n_tensors * 128 /* CLIP PADDING */;
-
clip_ctx * new_clip = new clip_ctx;
// update projector type
printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
- printf("%s: model size: %.2f MB\n", __func__, buffer_size / 1024.0 / 1024.0);
+ printf("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
}
}
- printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, buffer_size / (1024.0 * 1024.0), n_tensors);
+ printf("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors);
// load tensors
{
}
// alloc memory and offload data
- new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size);
- ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
+ new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend);
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
- ggml_allocr_alloc(alloc, cur);
const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
fin.seekg(offset, std::ios::beg);
if (!fin) {
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
}
}
- ggml_allocr_free(alloc);
fin.close();
}
// measure mem requirement and allocate
{
new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
- new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend);
+ new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
clip_image_f32_batch batch;
batch.size = 1;
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
- size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf);
- ggml_allocr_free(new_clip->compute_alloc);
- new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
- new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
-
+ ggml_gallocr_reserve(new_clip->compute_alloc, gf);
+ size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
}
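
In outline, the clip changes follow the same pattern: the compute allocator is sized once from a worst-case graph at load time, and each call allocates the real graph and then fills the named inputs. A condensed sketch, where dummy_batch and pixels are placeholders for the caller's data:

    // load time: reserve the compute allocator from a representative graph
    ctx->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(ctx->backend));
    ggml_gallocr_reserve(ctx->compute_alloc, clip_image_build_graph(ctx, &dummy_batch));

    // per call: allocate the graph, then write inputs looked up by name
    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
    ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
    struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "inp_raw");
    ggml_backend_tensor_set(inp, pixels, 0, ggml_nbytes(inp));
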
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
}
- // reset alloc buffer to clean the memory from previous invocations
- ggml_allocr_reset(ctx->compute_alloc);
-
// build the inference graph
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
- ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
+ ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+
+ // set inputs
+ const auto & model = ctx->vision_model;
+ const auto & hparams = model.hparams;
+ const int image_size = hparams.image_size;
+ const int patch_size = hparams.patch_size;
+ const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
+ const int num_positions = num_patches + 1;
+
+ {
+ struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+ float * data = (float *)malloc(ggml_nbytes(inp_raw));
+
+ for (size_t i = 0; i < imgs->size; i++) {
+ const int nx = imgs->data[i].nx;
+ const int ny = imgs->data[i].ny;
+ GGML_ASSERT(nx == image_size && ny == image_size);
+
+ const int n = nx * ny;
+
+ for (int b = 0; b < batch_size; b++) {
+ for (int k = 0; k < 3; k++) {
+ for (int y = 0; y < ny; y++) {
+ for (int x = 0; x < nx; x++) {
+ data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
+ }
+ }
+ }
+ }
+ }
+ ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+ free(data);
+ }
+
+ {
+ struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+
+ void* zero_mem = malloc(ggml_nbytes(embeddings));
+ memset(zero_mem, 0, ggml_nbytes(embeddings));
+ ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+ free(zero_mem);
+ }
+
+ {
+ struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+ int* positions_data = (int*)malloc(ggml_nbytes(positions));
+ for (int i = 0; i < num_positions; i++) {
+ positions_data[i] = i;
+ }
+ ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+ free(positions_data);
+ }
+
+ {
+ struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+ int* patches_data = (int*)malloc(ggml_nbytes(patches));
+ for (int i = 0; i < num_patches; i++) {
+ patches_data[i] = i + 1;
+ }
+ ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+ free(patches_data);
+ }
if (ggml_backend_is_cpu(ctx->backend)) {
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
#include "ggml.h"
#include "ggml-alloc.h"
+#include "ggml-backend.h"
#include "common.h"
#include "train.h"
#include "llama.h"
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
-static const size_t tensor_alignment = 32;
-
struct my_llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512;
struct my_llama_model {
struct ggml_context * ctx = NULL;
- std::vector<uint8_t> data;
+ ggml_backend_buffer_t data = NULL;
my_llama_hparams hparams;
}
}
-static void alloc_model(struct ggml_allocr * alloc, struct my_llama_model * model) {
- ggml_allocr_alloc(alloc, model->tok_embeddings);
- ggml_allocr_alloc(alloc, model->norm);
- ggml_allocr_alloc(alloc, model->output);
- for (uint32_t i = 0; i < model->layers.size(); ++i) {
- auto & layer = model->layers[i];
- ggml_allocr_alloc(alloc, layer.attention_norm);
- ggml_allocr_alloc(alloc, layer.wq);
- ggml_allocr_alloc(alloc, layer.wk);
- ggml_allocr_alloc(alloc, layer.wv);
- ggml_allocr_alloc(alloc, layer.wo);
- ggml_allocr_alloc(alloc, layer.ffn_norm);
- ggml_allocr_alloc(alloc, layer.w1);
- ggml_allocr_alloc(alloc, layer.w2);
- ggml_allocr_alloc(alloc, layer.w3);
- }
- ggml_allocr_alloc(alloc, model->tok_embeddings->grad);
- ggml_allocr_alloc(alloc, model->norm->grad);
- ggml_allocr_alloc(alloc, model->output->grad);
- for (uint32_t i = 0; i < model->layers.size(); ++i) {
- auto & layer = model->layers[i];
- ggml_allocr_alloc(alloc, layer.attention_norm->grad);
- ggml_allocr_alloc(alloc, layer.wq->grad);
- ggml_allocr_alloc(alloc, layer.wk->grad);
- ggml_allocr_alloc(alloc, layer.wv->grad);
- ggml_allocr_alloc(alloc, layer.wo->grad);
- ggml_allocr_alloc(alloc, layer.ffn_norm->grad);
- ggml_allocr_alloc(alloc, layer.w1->grad);
- ggml_allocr_alloc(alloc, layer.w2->grad);
- ggml_allocr_alloc(alloc, layer.w3->grad);
- }
-}
-
static void init_model(struct my_llama_model * model) {
const auto & hparams = model->hparams;
set_param_model(model);
- // measure data size
- size_t size = 0;
- for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
- }
-
// allocate data
- struct ggml_allocr * alloc = NULL;
- model->data.resize(size + tensor_alignment);
- alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
- alloc_model(alloc, model);
+ model->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
}
static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
static struct ggml_tensor * llama_build_train_graphs(
struct my_llama_model * model,
- struct ggml_allocr * alloc,
+ ggml_gallocr_t alloc,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
const int n_tokens,
const int n_batch,
const bool enable_flash_attn,
- const bool enable_checkpointing) {
+ const bool enable_checkpointing,
+ const bool measure_only) {
ggml_set_scratch(ctx, { 0, 0, nullptr, });
const int n_past = 0;
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
- ggml_allocr_alloc(alloc, KQ_pos);
- if (!ggml_allocr_is_measure(alloc)) {
- int * data = (int *) KQ_pos->data;
- for (int i = 0; i < N; ++i) {
- data[i] = n_past + i;
- }
- }
+ ggml_set_input(KQ_pos);
// rope has so many parameters that we make a custom function for it
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
// KQ_pos
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
-
- ggml_allocr_alloc(alloc, t36->grad);
+ ggml_set_input(t36->grad);
// allocating checkpoints in one block to reduce memory fragmentation
// note: they will be freed in reverse order
for (int i = 0; i < (int) checkpoints.size(); ++i) {
if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
- ggml_allocr_alloc(alloc, checkpoints[i]);
+ ggml_set_input(checkpoints[i]);
}
}
//int n_leafs_after = gb->n_leafs;
//int n_nodes_after = gb->n_nodes;
- ggml_allocr_alloc_graph(alloc, gb);
+ if (measure_only) {
+ // FIXME: will still allocate
+ ggml_gallocr_reserve(alloc, gb);
+ } else {
+ ggml_gallocr_alloc_graph(alloc, gb);
+
+ // set KQ_pos
+ {
+ int * data = (int *) KQ_pos->data;
+ for (int i = 0; i < N; ++i) {
+ data[i] = n_past + i;
+ }
+ }
+ }
// remove the additional nodes and leafs
for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
- printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + model.data.size()), (float) (ggml_used_mem(model.ctx) + model.data.size()) / (1024.0f*1024.0f));
+ printf("%s: model_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)), (float) (ggml_used_mem(model.ctx) + ggml_backend_buffer_get_size(model.data)) / (1024.0f*1024.0f));
if (params.only_write_model) {
save_train_files_data save_data;
int n_vocab = model.hparams.n_vocab;
int n_batch = params.common.n_batch;
- std::vector<uint8_t> mem_input_data;
- std::vector<uint8_t> mem_compute_data;
-
- ggml_allocr * alloc = NULL;
-
// context for input tensors without their data
struct ggml_init_params ctx_input_params = {
ggml_tensor_overhead() * 2, // mem_size
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
// measure required memory for input tensors
- size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
- GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
- tensor_alignment;
- printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
-
// allocate input tensors
- mem_input_data.resize(max_input_size);
- alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
- ggml_allocr_alloc(alloc, tokens_input);
- ggml_allocr_alloc(alloc, target_probs);
+ ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
+ size_t max_input_size = ggml_backend_buffer_get_size(input_data);
+ printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
// context for compute tensors without their data
const size_t estimated_compute_size_wo_data = (
// find best evaluation order
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
- alloc = ggml_allocr_new_measure(tensor_alignment);
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
- params.common.use_checkpointing
+ params.common.use_checkpointing,
+ true
);
- size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+ size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
if (max_compute_size < best_compute_size) {
best_compute_size = max_compute_size;
best_order = gf->order;
"invalid");
// allocate compute tensors
- mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
- alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
- params.common.use_checkpointing
+ params.common.use_checkpointing,
+ false
);
std::vector<llama_token> train_tokens;
//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
#define AT_PRINTF(...)
+
+static bool ggml_is_view(const struct ggml_tensor * t) {
+ return t->view_src != NULL;
+}
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+ if (a->type != b->type) {
+ return false;
+ }
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (a->ne[i] != b->ne[i]) {
+ return false;
+ }
+ if (a->nb[i] != b->nb[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static bool ggml_op_can_inplace(enum ggml_op op) {
+ switch (op) {
+ case GGML_OP_SCALE:
+ case GGML_OP_DIAG_MASK_ZERO:
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_ADD:
+ case GGML_OP_ADD1:
+ case GGML_OP_SUB:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_LOG:
+ case GGML_OP_UNARY:
+ case GGML_OP_ROPE:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_SOFT_MAX:
+ return true;
+
+ default:
+ return false;
+ }
+}
+
// TODO: GGML_PAD ?
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
assert(alignment && !(alignment & (alignment - 1))); // power of 2
return offset + align;
}
+// tallocr
+struct ggml_tallocr {
+ ggml_backend_buffer_t buffer;
+ void * base;
+ size_t alignment;
+ size_t offset;
+};
+
+ggml_tallocr_t ggml_tallocr_new(ggml_backend_buffer_t buffer) {
+ ggml_tallocr_t talloc = malloc(sizeof(struct ggml_tallocr));
+ if (talloc == NULL) {
+ return NULL;
+ }
+
+ void * base = ggml_backend_buffer_get_base(buffer);
+ size_t align = ggml_backend_buffer_get_alignment(buffer);
+
+ assert(align && !(align & (align - 1))); // power of 2
+
+ *talloc = (struct ggml_tallocr) {
+ /*.buffer = */ buffer,
+ /*.base = */ base,
+ /*.alignment = */ align,
+ /*.offset = */ aligned_offset(base, 0, align),
+ };
+ return talloc;
+}
+
+void ggml_tallocr_free(ggml_tallocr_t talloc) {
+ free(talloc);
+}
+
+void ggml_tallocr_alloc(ggml_tallocr_t talloc, struct ggml_tensor * tensor) {
+ size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
+ size = GGML_PAD(size, talloc->alignment);
+
+ if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {
+ fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
+ __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
+ GGML_ASSERT(!"not enough space in the buffer");
+ return;
+ }
+
+ void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset;
+ talloc->offset += size;
+
+ assert(((uintptr_t)addr % talloc->alignment) == 0);
+
+ ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
+}
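
This version of ggml_tallocr is a simple bump allocator over an existing backend buffer; it only advances an offset and never frees individual tensors. A minimal usage sketch (buffer size and tensors are illustrative):

    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), 1024*1024);
    ggml_tallocr_t talloc = ggml_tallocr_new(buf);
    ggml_tallocr_alloc(talloc, tensor_a);   // tensors created in a no_alloc context
    ggml_tallocr_alloc(talloc, tensor_b);
    ggml_tallocr_free(talloc);              // frees only the allocator state, not the buffer
    // ... use the tensors, then free the buffer when done
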
+
+// dynamic tensor allocator
+
struct free_block {
- void * addr;
+ size_t offset;
size_t size;
};
-struct ggml_tallocr {
- struct ggml_backend_buffer * buffer;
- bool buffer_owned;
- void * base;
+struct ggml_dyn_tallocr {
size_t alignment;
-
int n_free_blocks;
struct free_block free_blocks[MAX_FREE_BLOCKS];
-
size_t max_size;
- bool measure;
-
#ifdef GGML_ALLOCATOR_DEBUG
- struct ggml_tensor * allocated_tensors[1024];
+ struct {
+ const struct ggml_tensor * tensor;
+ size_t offset;
+ } allocated_tensors[1024];
#endif
};
#ifdef GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
+static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
- if (alloc->allocated_tensors[i] == NULL) {
- alloc->allocated_tensors[i] = tensor;
+ if (alloc->allocated_tensors[i].tensor == NULL) {
+ alloc->allocated_tensors[i].tensor = tensor;
+ alloc->allocated_tensors[i].offset = offset;
return;
}
}
GGML_ASSERT(!"out of allocated_tensors");
}
-static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
+static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
- if (alloc->allocated_tensors[i] == tensor ||
- (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
- alloc->allocated_tensors[i] = NULL;
+ if (alloc->allocated_tensors[i].offset == offset) {
+ alloc->allocated_tensors[i].tensor = NULL;
return;
}
}
- printf("tried to free tensor %s not found\n", tensor->name);
+ fprintf(stderr, "tried to free tensor %s not found\n", tensor->name);
GGML_ASSERT(!"tensor not found");
}
#endif
-// check if a tensor is allocated by this buffer
-static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
- return tensor->buffer == alloc->buffer && (!tensor->view_src || tensor->view_src->buffer == alloc->buffer);
-}
-
-static bool ggml_is_view(struct ggml_tensor * t) {
- return t->view_src != NULL;
-}
-
-void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
- GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
- GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
-
- size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
+static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t size, const struct ggml_tensor * tensor) {
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
if (block->size >= size) {
best_fit_block = alloc->n_free_blocks - 1;
} else {
- fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, largest block available %zu)\n",
- __func__, tensor->name, size, max_avail);
+ // this should never happen
+ fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
+ __func__, size, max_avail);
GGML_ASSERT(!"not enough space in the buffer");
- return;
+ GGML_UNREACHABLE();
}
}
struct free_block * block = &alloc->free_blocks[best_fit_block];
- void * addr = block->addr;
- block->addr = (char*)block->addr + size;
+ size_t offset = block->offset;
+ block->offset = offset + size;
block->size -= size;
if (block->size == 0) {
// remove block if empty
}
}
- AT_PRINTF("block %d, addr %p\n", best_fit_block, addr);
-
- tensor->data = addr;
- tensor->buffer = alloc->buffer;
- if (!alloc->measure) {
- ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
- }
+ AT_PRINTF("block %d, offset %zu\n", best_fit_block, offset);
#ifdef GGML_ALLOCATOR_DEBUG
- add_allocated_tensor(alloc, tensor);
- size_t cur_max = (char*)addr - (char*)alloc->base + size;
+ add_allocated_tensor(alloc, offset, tensor);
+ size_t cur_max = offset + size;
if (cur_max > alloc->max_size) {
- printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
+ // sort allocated_tensors by offset
+ for (int i = 0; i < 1024; i++) {
+ for (int j = i + 1; j < 1024; j++) {
+ if (alloc->allocated_tensors[i].offset > alloc->allocated_tensors[j].offset) {
+ const struct ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor;
+ size_t tmp_offset = alloc->allocated_tensors[i].offset;
+ alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor;
+ alloc->allocated_tensors[i].offset = alloc->allocated_tensors[j].offset;
+ alloc->allocated_tensors[j].tensor = tmp_tensor;
+ alloc->allocated_tensors[j].offset = tmp_offset;
+ }
+ }
+ }
+ fprintf(stderr, "max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) {
- if (alloc->allocated_tensors[i]) {
- printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
+ if (alloc->allocated_tensors[i].tensor) {
+ fprintf(stderr, "%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
+ alloc->allocated_tensors[i].offset,
+ alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor),
+ ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0);
}
}
- printf("\n");
+ fprintf(stderr, "\n");
}
#endif
- alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size);
-}
+ alloc->max_size = MAX(alloc->max_size, offset + size);
-// this is a very naive implementation, but for our case the number of free blocks should be very small
-static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
- if (ggml_tallocr_is_own(alloc, tensor) == false) {
- // the tensor was not allocated in this buffer
- // this can happen because the graph allocator will try to free weights and other tensors from different buffers
- // the easiest way to deal with this is just to ignore it
- // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
- return;
- }
+ return offset;
- void * ptr = tensor->data;
+ GGML_UNUSED(tensor);
+}
- size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
+// this is a very naive implementation, but for our case the number of free blocks should be very small
+static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, size_t size, const struct ggml_tensor * tensor) {
size = aligned_offset(NULL, size, alloc->alignment);
- AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
+
+ AT_PRINTF("%s: freeing %s at %zu (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, offset, size, alloc->n_free_blocks);
#ifdef GGML_ALLOCATOR_DEBUG
- remove_allocated_tensor(alloc, tensor);
+ remove_allocated_tensor(alloc, offset, tensor);
#endif
// see if we can merge with an existing block
for (int i = 0; i < alloc->n_free_blocks; i++) {
struct free_block * block = &alloc->free_blocks[i];
// check if ptr is at the end of the block
- if ((char*)block->addr + block->size == ptr) {
+ if (block->offset + block->size == offset) {
block->size += size;
// check if we can merge with the next block
- if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) {
+ if (i < alloc->n_free_blocks - 1 && block->offset + block->size == alloc->free_blocks[i+1].offset) {
block->size += alloc->free_blocks[i+1].size;
alloc->n_free_blocks--;
for (int j = i+1; j < alloc->n_free_blocks; j++) {
return;
}
// check if ptr is at the beginning of the block
- if ((char*)ptr + size == block->addr) {
- block->addr = ptr;
+ if (offset + size == block->offset) {
+ block->offset = offset;
block->size += size;
// check if we can merge with the previous block
- if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) {
+ if (i > 0 && alloc->free_blocks[i-1].offset + alloc->free_blocks[i-1].size == block->offset) {
alloc->free_blocks[i-1].size += block->size;
alloc->n_free_blocks--;
for (int j = i; j < alloc->n_free_blocks; j++) {
GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
// insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
int insert_pos = 0;
- while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) {
+ while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].offset < offset) {
insert_pos++;
}
// shift all blocks from insert_pos onward to make room for the new block
alloc->free_blocks[i] = alloc->free_blocks[i-1];
}
// insert the new block
- alloc->free_blocks[insert_pos].addr = ptr;
+ alloc->free_blocks[insert_pos].offset = offset;
alloc->free_blocks[insert_pos].size = size;
alloc->n_free_blocks++;
+
+ GGML_UNUSED(tensor);
}
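
To make the offset-based best-fit behaviour concrete, a hypothetical trace with 32-byte alignment:

    // initial free list:  { [0, SIZE_MAX/2) }
    // alloc A (100 B -> padded to 128) at offset 0    -> { [128, ...) }
    // alloc B ( 64 B)                  at offset 128  -> { [192, ...) }
    // free  A                                         -> { [0, 128), [192, ...) }   (kept sorted by offset)
    // alloc C ( 96 B): best fit is [0, 128), C at 0   -> { [96, 128), [192, ...) }
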
-void ggml_tallocr_reset(ggml_tallocr_t alloc) {
+static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
alloc->n_free_blocks = 1;
- size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment);
- alloc->free_blocks[0].addr = (char *)alloc->base + align_offset;
-
- if (alloc->measure) {
- alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
- } else {
- alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
- ggml_backend_buffer_reset(alloc->buffer);
- }
+ alloc->free_blocks[0].offset = 0;
+ alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
+ alloc->max_size = 0;
}
-ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
- struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
-
- ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
+static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
+ struct ggml_dyn_tallocr * alloc = (struct ggml_dyn_tallocr *)malloc(sizeof(struct ggml_dyn_tallocr));
- *alloc = (struct ggml_tallocr) {
- /*.buffer = */ buffer,
- /*.buffer_owned = */ true,
- /*.base = */ ggml_backend_buffer_get_base(buffer),
+ *alloc = (struct ggml_dyn_tallocr) {
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.max_size = */ 0,
- /*.measure = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
- /*.allocated_tensors = */ {0},
+ /*.allocated_tensors = */ {{0}},
#endif
};
- ggml_tallocr_reset(alloc);
-
- return alloc;
-}
-
-ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) {
- ggml_tallocr_t alloc = ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment);
- alloc->measure = true;
+ ggml_dyn_tallocr_reset(alloc);
return alloc;
}
-ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_buffer_type * buft) {
- // create a backend buffer to get the correct tensor allocation sizes
- ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, 1);
-
- // TODO: move alloc initialization to a common ggml_tallocr_new_impl function
- ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
- alloc->buffer_owned = true;
- alloc->measure = true;
- ggml_tallocr_reset(alloc);
- return alloc;
-}
-
-ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) {
- return ggml_tallocr_new_measure_from_buft(ggml_backend_get_default_buffer_type(backend));
-}
-
-ggml_tallocr_t ggml_tallocr_new_from_buft(struct ggml_backend_buffer_type * buft, size_t size) {
- // create a backend buffer to get the correct tensor allocation sizes
- ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
- ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
- alloc->buffer_owned = true;
- return alloc;
-}
-
-ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) {
- return ggml_tallocr_new_from_buft(ggml_backend_get_default_buffer_type(backend), size);
-}
-
-ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
- ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
-
- *alloc = (struct ggml_tallocr) {
- /*.buffer = */ buffer,
- /*.buffer_owned = */ false,
- /*.base = */ ggml_backend_buffer_get_base(buffer),
- /*.alignment = */ ggml_backend_buffer_get_alignment(buffer),
- /*.n_free_blocks = */ 0,
- /*.free_blocks = */ {{0}},
- /*.max_size = */ 0,
- /*.measure = */ false,
-#ifdef GGML_ALLOCATOR_DEBUG
- /*.allocated_tensors = */ {0},
-#endif
- };
-
- ggml_tallocr_reset(alloc);
-
- return alloc;
-}
-
-struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) {
- return alloc->buffer;
-}
-
-void ggml_tallocr_free(ggml_tallocr_t alloc) {
- if (alloc == NULL) {
- return;
- }
-
- if (alloc->buffer_owned) {
- ggml_backend_buffer_free(alloc->buffer);
- }
+static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {
free(alloc);
}
-bool ggml_tallocr_is_measure(ggml_tallocr_t alloc) {
- return alloc->measure;
+static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) {
+ return alloc->max_size;
}
-size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) {
- // FIXME: changes in the tensor sizes compared to the measure graph may cause allocations to fail
- // to avoid this, we add a 10% margin to the buffer size
- return alloc->max_size + alloc->max_size/10;
-}
+
+/////////////////////////////////////
// graph allocator
struct hash_node {
int n_children;
int n_views;
+ int buffer_id;
+ size_t offset; // offset within the buffer
+ bool allocated;
+};
+
+//
+struct tensor_alloc {
+ size_t offset;
+ size_t size_max; // 0 = pre-allocated, unused, or view
+};
+
+struct node_alloc {
+ int buffer_id;
+ struct tensor_alloc dst;
+ struct tensor_alloc src[GGML_MAX_SRC];
};
struct ggml_gallocr {
- ggml_tallocr_t talloc;
+ ggml_backend_buffer_type_t * bufts; // [n_buffers]
+ ggml_backend_buffer_t * buffers; // [n_buffers]
+ struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
+ int n_buffers;
+
struct ggml_hash_set hash_set;
- struct hash_node * hash_values;
- size_t hash_values_size;
- ggml_tallocr_t * hash_allocs;
- int * parse_seq;
- int parse_seq_len;
+ struct hash_node * hash_values; // [hash_set.size]
+
+ struct node_alloc * node_allocs; // [n_nodes]
+ int n_nodes;
};
-ggml_gallocr_t ggml_gallocr_new(void) {
- ggml_gallocr_t galloc = (ggml_gallocr_t)malloc(sizeof(struct ggml_gallocr));
-
- *galloc = (struct ggml_gallocr) {
- /*.talloc = */ NULL,
- /*.hash_set = */ {0},
- /*.hash_values = */ NULL,
- /*.hash_values_size = */ 0,
- /*.hash_allocs = */ NULL,
- /*.parse_seq = */ NULL,
- /*.parse_seq_len = */ 0,
- };
+ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs) {
+ ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(sizeof(struct ggml_gallocr), 1);
+ GGML_ASSERT(galloc != NULL);
+
+ galloc->bufts = calloc(sizeof(ggml_backend_buffer_type_t) * n_bufs, 1);
+ GGML_ASSERT(galloc->bufts != NULL);
+
+ galloc->buffers = calloc(sizeof(ggml_backend_buffer_t) * n_bufs, 1);
+ GGML_ASSERT(galloc->buffers != NULL);
+
+ galloc->buf_tallocs = calloc(sizeof(struct ggml_dyn_tallocr *) * n_bufs, 1);
+ GGML_ASSERT(galloc->buf_tallocs != NULL);
+
+ for (int i = 0; i < n_bufs; i++) {
+ galloc->bufts[i] = bufts[i];
+ galloc->buffers[i] = NULL;
+ size_t alignment = ggml_backend_buft_get_alignment(bufts[i]);
+ galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment);
+ }
+ galloc->n_buffers = n_bufs;
return galloc;
}
+ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft) {
+ return ggml_gallocr_new_n(&buft, 1);
+}
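
ggml_gallocr_new_n lets one graph allocator manage several buffer types (typically one per backend), with ggml_gallocr_new as the single-buffer convenience wrapper defined just above. A minimal sketch; a second CPU buffer type stands in for another backend's type:

    ggml_backend_buffer_type_t bufts[2] = {
        ggml_backend_cpu_buffer_type(),   // buffer 0
        ggml_backend_cpu_buffer_type(),   // buffer 1 (another backend's default type in practice)
    };
    ggml_gallocr_t galloc = ggml_gallocr_new_n(bufts, 2);
    // per-node buffer ids are threaded through ggml_gallocr_alloc_graph_impl below to pick a buffer per node
    ggml_gallocr_free(galloc);
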
+
void ggml_gallocr_free(ggml_gallocr_t galloc) {
if (galloc == NULL) {
return;
}
- if (galloc->hash_set.keys != NULL) {
- free(galloc->hash_set.keys);
- }
- if (galloc->hash_values != NULL) {
- free(galloc->hash_values);
- }
- if (galloc->hash_allocs != NULL) {
- free(galloc->hash_allocs);
- }
- if (galloc->parse_seq != NULL) {
- free(galloc->parse_seq);
+ for (int i = 0; i < galloc->n_buffers; i++) {
+ if (galloc->buffers != NULL) {
+ ggml_backend_buffer_free(galloc->buffers[i]);
+ }
+ if (galloc->buf_tallocs != NULL) {
+ ggml_dyn_tallocr_free(galloc->buf_tallocs[i]);
+ }
}
+
+ free(galloc->hash_set.keys);
+ free(galloc->hash_values);
+ free(galloc->bufts);
+ free(galloc->buffers);
+ free(galloc->buf_tallocs);
+ free(galloc->node_allocs);
free(galloc);
}
-void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n) {
- free(galloc->parse_seq);
- galloc->parse_seq = malloc(sizeof(int) * n);
+typedef struct ggml_gallocr * ggml_gallocr_t;
- for (int i = 0; i < n; i++) {
- galloc->parse_seq[i] = list[i];
- }
- galloc->parse_seq_len = n;
-}
-
-static struct hash_node * hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+static struct hash_node * ggml_gallocr_hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
size_t i = ggml_hash_find_or_insert(galloc->hash_set, t);
return &galloc->hash_values[i];
}
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
- if (a->type != b->type) {
- return false;
- }
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- if (a->ne[i] != b->ne[i]) {
- return false;
- }
- if (a->nb[i] != b->nb[i]) {
- return false;
- }
- }
- return true;
+static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+ return ggml_gallocr_hash_get(galloc, t)->allocated;
}
-static bool ggml_op_can_inplace(enum ggml_op op) {
- switch (op) {
- case GGML_OP_SCALE:
- case GGML_OP_DIAG_MASK_ZERO:
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_ADD:
- case GGML_OP_ADD1:
- case GGML_OP_SUB:
- case GGML_OP_MUL:
- case GGML_OP_DIV:
- case GGML_OP_SQR:
- case GGML_OP_SQRT:
- case GGML_OP_LOG:
- case GGML_OP_UNARY:
- case GGML_OP_ROPE:
- case GGML_OP_RMS_NORM:
- case GGML_OP_SOFT_MAX:
- return true;
-
- default:
- return false;
- }
+static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) {
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+ hn->buffer_id = buffer_id;
+ hn->offset = offset;
+ hn->allocated = true;
}
-static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * node) {
- if (galloc->talloc != NULL) {
- return galloc->talloc;
- }
-
- return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
+static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+ return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
}
-static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
- ggml_tallocr_t alloc = node_tallocr(galloc, view);
-
- GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
- if (update_backend) {
- view->backend = view->view_src->backend;
- }
- // views are initialized in the alloc buffer rather than the view_src buffer
- view->buffer = alloc->buffer;
- view->data = (char *)view->view_src->data + view->view_offs;
+static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
- assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
+ if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
+ hn->allocated = true;
+ assert(hn->offset == 0);
- if (!alloc->measure) {
- ggml_backend_buffer_init_tensor(alloc->buffer, view);
- }
-}
+ // try to reuse a parent's buffer (inplace)
+ if (ggml_op_can_inplace(node->op)) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ struct ggml_tensor * parent = node->src[i];
+ if (parent == NULL) {
+ break;
+ }
-static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
- ggml_tallocr_t alloc = node_tallocr(galloc, node);
+ // if the node's data is external, then we cannot re-use it
+ if (!ggml_gallocr_is_own(galloc, parent)) {
+ AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
+ continue;
+ }
- if (node->data == NULL) {
- if (ggml_is_view(node)) {
- init_view(galloc, node, true);
- } else {
- // see if we can reuse a parent's buffer (inplace)
- if (ggml_op_can_inplace(node->op)) {
- for (int i = 0; i < GGML_MAX_SRC; i++) {
- struct ggml_tensor * parent = node->src[i];
- if (parent == NULL) {
- break;
- }
+ // outputs cannot be reused
+ if (parent->flags & GGML_TENSOR_FLAG_OUTPUT || (parent->view_src != NULL && parent->view_src->flags & GGML_TENSOR_FLAG_OUTPUT)) {
+ AT_PRINTF("not reusing parent %s for %s as it is an output\n", parent->name, node->name);
+ continue;
+ }
- // if the node's data is external, then we cannot re-use it
- if (ggml_tallocr_is_own(alloc, parent) == false) {
- AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
- continue;
- }
+ if (!ggml_are_same_layout(node, parent)) {
+ AT_PRINTF("not reusing parent %s for %s as layouts are different\n", parent->name, node->name);
+ continue;
+ }
- struct hash_node * p_hn = hash_get(galloc, parent);
- if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
- if (ggml_is_view(parent)) {
- struct ggml_tensor * view_src = parent->view_src;
- struct hash_node * view_src_hn = hash_get(galloc, view_src);
- if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
- // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
- // the parent's data that it will need later (same layout requirement). the problem is that then
- // we cannot free the tensor because the original address of the allocation is lost.
- // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
- // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
- AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
- node->view_src = view_src;
- view_src_hn->n_views += 1;
- init_view(galloc, node, false);
- return;
- }
- } else {
- AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
- node->view_src = parent;
- p_hn->n_views += 1;
- init_view(galloc, node, false);
+ struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+ if (p_hn->n_children == 1 && p_hn->n_views == 0) {
+ if (ggml_is_view(parent)) {
+ struct ggml_tensor * view_src = parent->view_src;
+ struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
+ if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
+ AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
+ assert(view_src_hn->offset == p_hn->offset);
+ hn->buffer_id = p_hn->buffer_id;
+ hn->offset = p_hn->offset;
+ p_hn->allocated = false; // avoid freeing the parent
+ view_src_hn->allocated = false;
return;
}
+ } else {
+ AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
+ hn->buffer_id = p_hn->buffer_id;
+ hn->offset = p_hn->offset;
+ p_hn->allocated = false; // avoid freeing the parent
+ return;
}
}
}
- ggml_tallocr_alloc(alloc, node);
}
+ // allocate tensor from the buffer
+ struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+ ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
+ size_t size = ggml_backend_buft_get_alloc_size(buft, node);
+ size_t offset = ggml_dyn_tallocr_alloc(alloc, size, node);
+ hn->buffer_id = buffer_id;
+ hn->offset = offset;
+ return;
}
}
-static void free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
- ggml_tallocr_t alloc = node_tallocr(galloc, node);
+static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+ // graph outputs are never freed
+ if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+ AT_PRINTF("not freeing output %s\n", node->name);
+ return;
+ }
- ggml_tallocr_free_tensor(alloc, node);
+ struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+ ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+ size_t offset = hn->offset;
+ size_t size = ggml_backend_buft_get_alloc_size(buft, node);
+ ggml_dyn_tallocr_free_tensor(alloc, offset, size, node);
+ hn->allocated = false;
}
-static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * gf) {
- const int * parse_seq = galloc->parse_seq;
- int parse_seq_len = galloc->parse_seq_len;
+static int get_node_buffer_id(const int * node_buffer_ids, int i) {
+ return node_buffer_ids ? node_buffer_ids[i] : 0;
+}
+
+static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids) {
+ // clear hash tables
+ memset(galloc->hash_set.keys, 0, galloc->hash_set.size * sizeof(struct ggml_tensor *));
+ memset(galloc->hash_values, 0, galloc->hash_set.size * sizeof(struct hash_node));
+
+ // allocate all graph inputs first to avoid overwriting them
+ for (int i = 0; i < graph->n_nodes; i++) {
+ if (graph->nodes[i]->flags & GGML_TENSOR_FLAG_INPUT) {
+ ggml_gallocr_allocate_node(galloc, graph->nodes[i], get_node_buffer_id(node_buffer_ids, i));
+ }
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (graph->nodes[i]->src[j] == NULL) {
+ break;
+ }
+ if (graph->nodes[i]->src[j]->flags & GGML_TENSOR_FLAG_INPUT) {
+ ggml_gallocr_allocate_node(galloc, graph->nodes[i]->src[j], get_node_buffer_id(node_buffer_ids, i));
+ }
+ }
+ }
// count number of children and views
- for (int i = 0; i < gf->n_nodes; i++) {
- struct ggml_tensor * node = gf->nodes[i];
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view(node)) {
struct ggml_tensor * view_src = node->view_src;
- hash_get(galloc, view_src)->n_views += 1;
- if (node->buffer == NULL && node->data != NULL) {
- // view of a pre-allocated tensor, didn't call init_view() yet
- init_view(galloc, node, true);
- }
+ ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (parent == NULL) {
break;
}
- hash_get(galloc, parent)->n_children += 1;
- if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
- init_view(galloc, parent, true);
- }
+ ggml_gallocr_hash_get(galloc, parent)->n_children += 1;
}
}
// allocate tensors
- // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
- int last_barrier_pos = 0;
- int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes;
-
- for (int ind = 0; ind < n_nodes; ind++) {
- // allocate a node if there is no parse_seq or this is not a barrier
- if (parse_seq_len == 0 || parse_seq[ind] != -1) {
- int i = parse_seq_len ? parse_seq[ind] : ind;
- struct ggml_tensor * node = gf->nodes[i];
-
- // allocate parents (leafs)
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * parent = node->src[j];
- if (parent == NULL) {
- break;
- }
- allocate_node(galloc, parent);
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ int buffer_id = get_node_buffer_id(node_buffer_ids, i);
+
+ // allocate parents (only leafs need to be allocated at this point)
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ break;
}
+ ggml_gallocr_allocate_node(galloc, parent, buffer_id);
+ }
- // allocate node
- allocate_node(galloc, node);
+ // allocate node
+ ggml_gallocr_allocate_node(galloc, node, buffer_id);
- AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * parent = node->src[j];
- if (parent == NULL) {
- break;
- }
- AT_PRINTF("%s", parent->name);
- if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
- AT_PRINTF(", ");
- }
+ AT_PRINTF("exec: %s (%s) <= ", ggml_op_desc(node), node->name);
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ break;
+ }
+ AT_PRINTF("%s", parent->name);
+ if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+ AT_PRINTF(", ");
}
- AT_PRINTF("\n");
}
+ AT_PRINTF("\n");
// update parents
- // update immediately if there is no parse_seq
- // update only at barriers if there is parse_seq
- if ((parse_seq_len == 0) || parse_seq[ind] == -1) {
- int update_start = parse_seq_len ? last_barrier_pos : ind;
- int update_end = parse_seq_len ? ind : ind + 1;
- for (int i = update_start; i < update_end; i++) {
- int node_i = parse_seq_len ? parse_seq[i] : i;
- struct ggml_tensor * node = gf->nodes[node_i];
-
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- struct ggml_tensor * parent = node->src[j];
- if (parent == NULL) {
- break;
- }
- struct hash_node * p_hn = hash_get(galloc, parent);
- p_hn->n_children -= 1;
-
- //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
-
- if (p_hn->n_children == 0 && p_hn->n_views == 0) {
- if (ggml_is_view(parent)) {
- struct ggml_tensor * view_src = parent->view_src;
- struct hash_node * view_src_hn = hash_get(galloc, view_src);
- view_src_hn->n_views -= 1;
- AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
- if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) {
- free_node(galloc, view_src);
- }
- }
- else {
- free_node(galloc, parent);
- }
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * parent = node->src[j];
+ if (parent == NULL) {
+ break;
+ }
+ struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+ p_hn->n_children -= 1;
+
+ AT_PRINTF("parent %s: %d children, %d views, allocated: %d\n",
+ parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
+
+ if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+ if (ggml_is_view(parent)) {
+ struct ggml_tensor * view_src = parent->view_src;
+ struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
+ view_src_hn->n_views -= 1;
+ AT_PRINTF("view_src %s: %d children, %d views\n",
+ view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+ if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) {
+ ggml_gallocr_free_node(galloc, view_src, buffer_id);
}
}
+ else if (p_hn->allocated) {
+ ggml_gallocr_free_node(galloc, parent, buffer_id);
+ }
}
AT_PRINTF("\n");
- if (parse_seq_len) {
- last_barrier_pos = ind + 1;
- }
}
}
}
-size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph) {
+bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids) {
size_t hash_size = graph->visited_hash_table.size;
- // check if the hash table is initialized and large enough
+ // initialize hash table
if (galloc->hash_set.size < hash_size) {
- if (galloc->hash_set.keys != NULL) {
- free(galloc->hash_set.keys);
- }
- if (galloc->hash_values != NULL) {
- free(galloc->hash_values);
- }
- galloc->hash_set.keys = malloc(sizeof(struct ggml_tensor *) * hash_size);
+ free(galloc->hash_set.keys);
+ free(galloc->hash_values);
galloc->hash_set.size = hash_size;
- galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
+ galloc->hash_set.keys = calloc(sizeof(struct ggml_tensor *), hash_size);
+ galloc->hash_values = calloc(sizeof(struct hash_node), hash_size);
+ GGML_ASSERT(galloc->hash_set.keys != NULL);
+ GGML_ASSERT(galloc->hash_values != NULL);
+ } else {
+ // reset hash table
+ memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * galloc->hash_set.size);
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size);
}
- // reset hash table
- memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * hash_size);
- memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
-
- galloc->talloc = talloc;
- ggml_tallocr_alloc_graph_impl(galloc, graph);
- galloc->talloc = NULL;
-
- size_t max_size = ggml_tallocr_max_size(talloc);
-
- return max_size;
-}
-
-void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) {
- const size_t hash_size = hash_set.size;
-
- GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
+ // reset allocators
+ for (int i = 0; i < galloc->n_buffers; i++) {
+ ggml_dyn_tallocr_reset(galloc->buf_tallocs[i]);
+ }
- galloc->talloc = NULL;
+ // allocate in hash table
+ ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids);
- // alloc hash_values if needed
- if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) {
- free(galloc->hash_values);
- galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
- galloc->hash_values_size = hash_size;
+ // set the node_allocs from the hash table
+ if (galloc->n_nodes < graph->n_nodes) {
+ free(galloc->node_allocs);
+ galloc->node_allocs = calloc(sizeof(struct node_alloc), graph->n_nodes);
+ GGML_ASSERT(galloc->node_allocs != NULL);
}
-
- // free hash_set.keys if needed
- if (galloc->hash_set.keys != NULL) {
- free(galloc->hash_set.keys);
+ galloc->n_nodes = graph->n_nodes;
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ struct node_alloc * node_alloc = &galloc->node_allocs[i];
+ node_alloc->buffer_id = get_node_buffer_id(node_buffer_ids, i);
+ if (node->view_src || node->data) {
+ node_alloc->dst.offset = SIZE_MAX;
+ node_alloc->dst.size_max = 0;
+ } else {
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+ node_alloc->dst.offset = hn->offset;
+ node_alloc->dst.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+ }
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (!src || src->view_src || src->data) {
+ node_alloc->src[j].offset = SIZE_MAX;
+ node_alloc->src[j].size_max = 0;
+ } else {
+ struct hash_node * hn = ggml_gallocr_hash_get(galloc, src);
+ node_alloc->src[j].offset = hn->offset;
+ node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src);
+ }
+ }
}
- galloc->hash_set = hash_set;
- // reset hash values
- memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+ // reallocate buffers if needed
+ for (int i = 0; i < galloc->n_buffers; i++) {
+ size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0;
+ size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
- galloc->hash_allocs = hash_node_talloc;
-
- ggml_tallocr_alloc_graph_impl(galloc, graph);
+ if (new_size > cur_size) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+ ggml_backend_buffer_free(galloc->buffers[i]);
+ galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
+ if (galloc->buffers[i] == NULL) {
+ fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
+ return false;
+ }
+ }
+ }
- // remove unowned resources
- galloc->hash_set.keys = NULL;
- galloc->hash_allocs = NULL;
+ return true;
}
-// legacy API wrapper
-
-struct ggml_allocr {
- ggml_tallocr_t talloc;
- ggml_gallocr_t galloc;
-};
-
-static ggml_allocr_t ggml_allocr_new_impl(ggml_tallocr_t talloc) {
- ggml_allocr_t alloc = (ggml_allocr_t)malloc(sizeof(struct ggml_allocr));
- *alloc = (struct ggml_allocr) {
- /*.talloc = */ talloc,
- /*.galloc = */ ggml_gallocr_new(),
- };
- return alloc;
+bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+ return ggml_gallocr_reserve_n(galloc, graph, NULL);
}
-ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment) {
- return ggml_allocr_new_impl(ggml_tallocr_new(data, size, alignment));
-}
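+// places the tensor at its reserved offset in the backend buffer, or initializes it as a view of its view_src;
+// tensors that were allocated outside ggml-backend are left untouched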
+static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * node, struct node_alloc * node_alloc, struct tensor_alloc * tensor_alloc) {
+ assert(node->data || node->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[node_alloc->buffer_id], node) <= tensor_alloc->size_max);
-ggml_allocr_t ggml_allocr_new_measure(size_t alignment) {
- return ggml_allocr_new_impl(ggml_tallocr_new_measure(alignment));
-}
+ if (node->view_src != NULL) {
+ if (node->buffer == NULL) {
+ assert(tensor_alloc->offset == SIZE_MAX);
+ if (node->view_src->buffer == NULL) {
+ // this tensor was allocated without ggml-backend
+ return;
+ }
+ ggml_backend_view_init(galloc->buffers[node_alloc->buffer_id], node);
+ }
+ } else {
+ if (node->data == NULL) {
+ assert(tensor_alloc->offset != SIZE_MAX);
+ assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[node_alloc->buffer_id], node) <= tensor_alloc->size_max);
+ void * base = ggml_backend_buffer_get_base(galloc->buffers[node_alloc->buffer_id]);
+ void * addr = (char *)base + tensor_alloc->offset;
+ ggml_backend_tensor_alloc(galloc->buffers[node_alloc->buffer_id], node, addr);
+ } else {
+ if (node->buffer == NULL) {
+ // this tensor was allocated without ggml-backend
+ return;
+ }
-ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
- return ggml_allocr_new_impl(ggml_tallocr_new_from_buffer(buffer));
+#ifndef NDEBUG
+ size_t offset =
+ (char *)node->data -
+ (char *)ggml_backend_buffer_get_base(node->buffer);
+ size_t size = ggml_backend_buffer_get_alloc_size(node->buffer, node);
+ assert(tensor_alloc->offset == SIZE_MAX || offset == tensor_alloc->offset);
+ assert(tensor_alloc->offset == SIZE_MAX || size <= tensor_alloc->size_max);
+#endif
+ }
+ }
}
-ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size) {
- return ggml_allocr_new_impl(ggml_tallocr_new_from_backend(backend, size));
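+// note: despite the name, this returns true when the size reserved for the tensor is still sufficient,
+// i.e. no reallocation is needed; callers treat a false result as "reallocation required"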
+static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct node_alloc * nalloc, struct tensor_alloc * talloc) {
+ ggml_backend_buffer_type_t buft = galloc->bufts[nalloc->buffer_id];
+ size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(buft, node);
+ return talloc->size_max >= node_size;
}
-ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend) {
- return ggml_allocr_new_impl(ggml_tallocr_new_measure_from_backend(backend));
-}
+static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+ if (galloc->n_nodes != graph->n_nodes) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: graph has different number of nodes\n", __func__);
+#endif
+ return true;
+ }
-struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc) {
- return ggml_tallocr_get_buffer(alloc->talloc);
-}
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ struct node_alloc * node_alloc = &galloc->node_allocs[i];
-void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
- ggml_gallocr_set_parse_seq(alloc->galloc, list, n);
-}
+ if (!ggml_gallocr_node_needs_realloc(galloc, node, node_alloc, &node_alloc->dst)) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name);
+#endif
+ return true;
+ }
-void ggml_allocr_free(ggml_allocr_t alloc) {
- if (alloc == NULL) {
- return;
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ break;
+ }
+ if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
+#endif
+ return true;
+ }
+ }
}
- ggml_gallocr_free(alloc->galloc);
- ggml_tallocr_free(alloc->talloc);
- free(alloc);
+ return false;
}
-bool ggml_allocr_is_measure(ggml_allocr_t alloc) {
- return ggml_tallocr_is_measure(alloc->talloc);
-}
+bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+ if (ggml_gallocr_needs_realloc(galloc, graph)) {
+ if (galloc->n_buffers == 1) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: reallocating buffers automatically\n", __func__);
+#endif
+ if (!ggml_gallocr_reserve(galloc, graph)) {
+ return false;
+ }
+ } else {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
+#endif
+ return false;
+ }
+ }
-void ggml_allocr_reset(ggml_allocr_t alloc) {
- ggml_tallocr_reset(alloc->talloc);
-}
+ // reset buffers
+ for (int i = 0; i < galloc->n_buffers; i++) {
+ // zero size buffers are not allocated
+ if (galloc->buffers[i] != NULL) {
+ ggml_backend_buffer_reset(galloc->buffers[i]);
+ }
+ }
-void ggml_allocr_alloc(ggml_allocr_t alloc, struct ggml_tensor * tensor) {
- ggml_tallocr_alloc(alloc->talloc, tensor);
-}
+ // allocate the graph tensors from the previous assignments
+ for (int i = 0; i < graph->n_nodes; i++) {
+ struct ggml_tensor * node = graph->nodes[i];
+ struct node_alloc * node_alloc = &galloc->node_allocs[i];
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ struct ggml_tensor * src = node->src[j];
+ if (src == NULL) {
+ break;
+ }
+ ggml_gallocr_init_tensor(galloc, src, node_alloc, &node_alloc->src[j]);
+ }
+ ggml_gallocr_init_tensor(galloc, node, node_alloc, &node_alloc->dst);
+ }
-size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
- return ggml_tallocr_max_size(alloc->talloc);
+ return true;
}
-size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
- return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
+size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
+ GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);
+
+ if (galloc->buffers[buffer_id] == NULL) {
+ return 0;
+ }
+ return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
}
// utils
return false;
}
- ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
+ struct ggml_tallocr * tallocr = ggml_tallocr_new(buffer);
for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) {
if (t->data == NULL) {
if (t->view_src == NULL) {
ggml_tallocr_alloc(tallocr, t);
- } else {
+ } else if (t->buffer == NULL) {
ggml_backend_view_init(buffer, t);
}
} else {
- if (t->view_src != NULL) {
+ if (t->view_src != NULL && t->buffer == NULL) {
// view of a pre-allocated tensor
ggml_backend_view_init(buffer, t);
}
}
if (this_size > max_size) {
- // tensor is too large to fit in a single buffer
fprintf(stderr, "%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
__func__, t->name,
ggml_backend_buft_name(buft),
}
if (n_buffers == 0) {
- // all the tensors in the context are already allocated
#ifndef NDEBUG
fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
#endif
extern "C" {
#endif
-struct ggml_backend;
-struct ggml_backend_buffer;
-struct ggml_backend_buffer_type;
+typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+typedef struct ggml_backend * ggml_backend_t;
-//
-// Legacy API
-//
-
-typedef struct ggml_allocr * ggml_allocr_t;
-
-// initialize allocator for use with CPU backend only
-GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment);
-GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment);
-
-// initialize allocator for use with ggml-backend
-GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
-GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
-GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend);
-
-GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc);
-
-// tell the allocator to parse nodes following the order described in the list
-// you should call this if your graph are optimized to execute out-of-order
-GGML_API void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n);
-
-GGML_API void ggml_allocr_free (ggml_allocr_t alloc);
-GGML_API bool ggml_allocr_is_measure (ggml_allocr_t alloc);
-GGML_API void ggml_allocr_reset (ggml_allocr_t alloc);
-GGML_API void ggml_allocr_alloc (ggml_allocr_t alloc, struct ggml_tensor * tensor);
-GGML_API size_t ggml_allocr_max_size (ggml_allocr_t alloc);
-
-GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph);
+// Tensor allocator
+typedef struct ggml_tallocr * ggml_tallocr_t;
-//
-// ggml-backend v2 API
-//
+GGML_API ggml_tallocr_t ggml_tallocr_new(ggml_backend_buffer_t buffer);
+GGML_API void ggml_tallocr_free(ggml_tallocr_t talloc);
+GGML_API void ggml_tallocr_alloc(ggml_tallocr_t talloc, struct ggml_tensor * tensor);
-// Separate tensor and graph allocator objects
-// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
-// The original API is kept as a wrapper around the new API
+// Graph allocator
+/*
+ Example usage:
+ ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
-// Tensor allocator
-typedef struct ggml_tallocr * ggml_tallocr_t;
+ // optional: create a worst-case graph and reserve the buffers to avoid reallocations
+ ggml_gallocr_reserve(galloc, build_graph(max_batch));
-GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
-GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
-GGML_API ggml_tallocr_t ggml_tallocr_new_from_buft(struct ggml_backend_buffer_type * buft, size_t size);
-GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
-GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
-GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_buft(struct ggml_backend_buffer_type * buft);
-GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
+ // allocate the graph
+ struct ggml_cgraph * graph = build_graph(batch);
+ ggml_gallocr_alloc_graph(galloc, graph);
-GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
+ printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
-GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc);
-GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc);
-GGML_API void ggml_tallocr_reset (ggml_tallocr_t talloc);
-GGML_API void ggml_tallocr_alloc (ggml_tallocr_t talloc, struct ggml_tensor * tensor);
-GGML_API size_t ggml_tallocr_max_size (ggml_tallocr_t talloc);
+ // evaluate the graph
+ ggml_backend_graph_compute(backend, graph);
+*/
+// special tensor flags for use with the graph allocator:
+// ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
+// ggml_set_output(): output tensors are never freed and never overwritten
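+//
+// minimal sketch of the flags in a graph build (the context, sizes and weight tensor are illustrative assumptions):
+//   struct ggml_tensor * inp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
+//   ggml_set_input(inp);   // allocated first, in a non-overlapping address
+//   struct ggml_tensor * out = ggml_mul_mat(ctx, w, inp);
+//   ggml_set_output(out);  // never freed, so its data can be read back after the graph is computed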
-// Graph allocator
typedef struct ggml_gallocr * ggml_gallocr_t;
-GGML_API ggml_gallocr_t ggml_gallocr_new(void);
-GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
+GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
+GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
+GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
-GGML_API void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n);
-GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph);
+// pre-allocate buffers from a measure graph - only reserves buffer space, does not allocate or modify the graph tensors
+// call with a worst-case graph to avoid buffer reallocations
+// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
+// returns false if the buffer allocation failed
+GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+GGML_API bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids);
-// Allocate tensors from the allocators given by the hash table
-GGML_API void ggml_gallocr_alloc_graph_n(
- ggml_gallocr_t galloc,
- struct ggml_cgraph * graph,
- struct ggml_hash_set hash_set,
- ggml_tallocr_t * hash_node_talloc);
+// allocates the graph tensors in the reserved buffers
+// when using a single buffer, the buffer is reallocated automatically if the graph topology changes
+// returns false if using multiple buffers and a reallocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
+GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
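+
+// sketch of multi-buffer usage (illustrative only; the buffer types and the per-node ids are provided by the caller):
+//   ggml_backend_buffer_type_t bufts[2] = { ggml_backend_cpu_buffer_type(), ggml_backend_cpu_buffer_type() };
+//   ggml_gallocr_t galloc = ggml_gallocr_new_n(bufts, 2);
+//   int node_buffer_ids[GGML_DEFAULT_GRAPH_SIZE]; // one buffer id per graph node, filled by the caller
+//   ggml_gallocr_reserve_n(galloc, graph, node_buffer_ids);
+//   ggml_gallocr_alloc_graph(galloc, graph);      // allocates in the buffers reserved above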
// Utils
// Create a buffer and allocate all the tensors in a ggml_context
-GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
-GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
#ifdef __cplusplus
}
// backend CPU
+static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment
+
GGML_CALL static const char * ggml_backend_cpu_buffer_name(ggml_backend_buffer_t buffer) {
return "CPU";
}
GGML_CALL static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
- return (void *)buffer->context;
+ uintptr_t data = (uintptr_t)buffer->context;
+
+ // align the buffer
+ if (data % TENSOR_ALIGNMENT != 0) {
+ data = GGML_PAD(data, TENSOR_ALIGNMENT);
+ }
+
+ return (void *)data;
}
GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
/* .reset = */ NULL,
};
-static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
-
GGML_CALL static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU";
GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
- void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
-
- GGML_ASSERT(data != NULL && "failed to allocate buffer");
+ void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h)
+ if (data == NULL) {
+ fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+ return NULL;
+ }
return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
}
ggml_backend_t ggml_backend_cpu_init(void) {
struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
+ if (ctx == NULL) {
+ return NULL;
+ }
ctx->n_threads = GGML_DEFAULT_N_THREADS;
ctx->work_data = NULL;
ctx->abort_callback_data = NULL;
ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
+ if (cpu_backend == NULL) {
+ free(ctx);
+ return NULL;
+ }
*cpu_backend = (struct ggml_backend) {
/* .interface = */ cpu_backend_i,
}
GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+ GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
}
ctx->n_buffers = n_buffers;
ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t));
+ GGML_ASSERT(ctx->buffers != NULL);
+
size_t total_size = 0;
for (size_t i = 0; i < n_buffers; i++) {
ctx->buffers[i] = buffers[i];
}
}
+// creates a copy of the tensor with the same memory layout
+static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
+ struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ dup->nb[i] = tensor->nb[i];
+ }
+ return dup;
+}
+
+static bool ggml_is_view_op(enum ggml_op op) {
+ return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
// scheduler
#define GGML_MAX_SPLIT_INPUTS 16
struct ggml_backend_sched_split {
- ggml_tallocr_t tallocr;
+ int backend_id;
int i_start;
int i_end;
struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
int n_backends;
ggml_backend_t backends[GGML_MAX_BACKENDS];
ggml_backend_buffer_type_t bufts[GGML_MAX_BACKENDS];
- ggml_tallocr_t tallocs[GGML_MAX_BACKENDS];
ggml_gallocr_t galloc;
// hash keys of the nodes in the graph
struct ggml_hash_set hash_set;
- // hash values (arrays of [hash_set.size])
- ggml_tallocr_t * node_talloc; // tallocr assigned to each node (indirectly this is the backend)
- struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // copies of each node for each destination backend
+ // hash values
+ int * tensor_backend_id;
+ struct ggml_tensor * (* tensor_copies)[GGML_MAX_BACKENDS];
+
+ int * node_backend_ids; // [n_nodes]
+ int n_nodes;
// copy of the graph with modified inputs
struct ggml_cgraph * graph;
struct ggml_context * ctx;
+ ggml_backend_sched_eval_callback callback_eval;
+ void * callback_eval_user_data;
+
// align context_buffer to GGML_MEM_ALIGN
#ifdef _MSC_VER
__declspec(align(GGML_MEM_ALIGN))
#else
__attribute__((aligned(GGML_MEM_ALIGN)))
#endif
- char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
-
- ggml_backend_sched_eval_callback callback_eval;
- void * callback_eval_user_data;
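+ // 2 tensors per split input: the backend copy and the dependency view added in ggml_backend_sched_split_graph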
+ char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
};
#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
-#define node_allocr(node) sched->node_talloc[hash_id(node)]
-
-static bool ggml_is_view_op(enum ggml_op op) {
- return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
-}
+#define tensor_backend_id(node) sched->tensor_backend_id[hash_id(node)]
+#define tensor_backend(node) (tensor_backend_id(node) == -1 ? NULL : sched->backends[tensor_backend_id(node)])
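+// tensor_backend_id is initialized to -1 by ggml_backend_sched_reset and means "not assigned to a backend yet"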
-// returns the priority of the backend, lower is better
-static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) {
+// returns the id of the backend (its index in sched->backends), or -1 if it is not in the scheduler; a lower id means a higher priority
+static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) {
for (int i = 0; i < sched->n_backends; i++) {
if (sched->backends[i] == backend) {
return i;
}
}
- return INT_MAX;
+ return -1;
}
-static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
- for (int i = 0; i < sched->n_backends; i++) {
- if (sched->tallocs[i] == allocr) {
- return i;
- }
- }
- return INT_MAX;
-}
-
-static ggml_tallocr_t sched_allocr_from_buffer(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) {
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, ggml_backend_buffer_t buffer) {
if (buffer == NULL) {
- return NULL;
- }
-
- // check if this is already allocate in a allocr buffer (from user manual allocations)
- for (int i = 0; i < sched->n_backends; i++) {
- if (ggml_tallocr_get_buffer(sched->tallocs[i]) == buffer) {
- return sched->tallocs[i];
- }
+ return -1;
}
// find highest prio backend that supports the buffer type
for (int i = 0; i < sched->n_backends; i++) {
if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
- return sched->tallocs[i];
+ return i;
}
}
GGML_ASSERT(false && "tensor buffer type not supported by any backend");
}
-static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
- if (allocr == NULL) {
- return NULL;
- }
- for (int i = 0; i < sched->n_backends; i++) {
- if (sched->tallocs[i] == allocr) {
- return sched->backends[i];
- }
- }
- GGML_UNREACHABLE();
-}
-
#if 0
static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug only
#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
#endif
// returns the backend that should be used for the node based on the current locations
-static ggml_tallocr_t sched_allocr_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
+ // TODO: use supports_op to check if the backend supports the op
+
// assign pre-allocated nodes to their backend
// dst
- ggml_tallocr_t cur_allocr = sched_allocr_from_buffer(sched, node->buffer);
- if (cur_allocr != NULL) {
+ int cur_backend = ggml_backend_sched_backend_from_buffer(sched, tensor->buffer);
+ if (cur_backend != -1) {
SET_CAUSE(node, "1.dst");
- return cur_allocr;
+ return cur_backend;
}
// view_src
- if (node->view_src != NULL) {
- cur_allocr = sched_allocr_from_buffer(sched, node->view_src->buffer);
- if (cur_allocr != NULL) {
+ if (tensor->view_src != NULL) {
+ cur_backend = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src->buffer);
+ if (cur_backend != -1) {
SET_CAUSE(node, "1.vsrc");
- return cur_allocr;
+ return cur_backend;
}
}
// assign nodes that use weights to the backend of the weights
for (int i = 0; i < GGML_MAX_SRC; i++) {
- const struct ggml_tensor * src = node->src[i];
+ const struct ggml_tensor * src = tensor->src[i];
if (src == NULL) {
break;
}
if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
- ggml_tallocr_t src_allocr = sched_allocr_from_buffer(sched, src->buffer);
+ int src_backend = ggml_backend_sched_backend_from_buffer(sched, src->buffer);
// operations with weights are always run on the same backend as the weights
SET_CAUSE(node, "1.wgt%d", i);
- return src_allocr;
+ return src_backend;
}
}
- return NULL;
+ return -1;
}
static char * fmt_size(size_t size) {
return buffer;
}
-static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
int cur_split = 0;
for (int i = 0; i < graph->n_nodes; i++) {
if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
- ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
+ ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];
fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
sched->splits[cur_split].n_inputs);
for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
if (ggml_is_view_op(node->op)) {
continue;
}
- ggml_tallocr_t node_allocr = node_allocr(node);
- ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
+ ggml_backend_t tensor_backend = tensor_backend(node);
fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
- fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
+ fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
- ggml_tallocr_t src_allocr = node_allocr(src);
- ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
+ ggml_backend_t src_backend = tensor_backend(src);
fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
}
}
}
-// creates a copy of the tensor with the same memory layout
-static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
- struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
- dup->nb[i] = tensor->nb[i];
- }
- return dup;
-}
-
-
//#define DEBUG_PASS1
//#define DEBUG_PASS2
//#define DEBUG_PASS3
//#define DEBUG_PASS4
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset splits
sched->n_splits = 0;
sched->is_reset = false;
// pass 1: assign backends to ops with pre-allocated inputs
for (int i = 0; i < graph->n_leafs; i++) {
struct ggml_tensor * leaf = graph->leafs[i];
- if (node_allocr(leaf) != NULL) {
+ if (tensor_backend_id(leaf) != -1) {
// do not overwrite user assignments
continue;
}
- node_allocr(leaf) = sched_allocr_from_cur(sched, leaf);
+ tensor_backend_id(leaf) = ggml_backend_sched_backend_id_from_cur(sched, leaf);
}
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
- if (node_allocr(node) != NULL) {
+ if (tensor_backend_id(node) != -1) {
// do not overwrite user assignments
continue;
}
- node_allocr(node) = sched_allocr_from_cur(sched, node);
+ tensor_backend_id(node) = ggml_backend_sched_backend_id_from_cur(sched, node);
// src
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
- if (node_allocr(src) == NULL) {
- node_allocr(src) = sched_allocr_from_cur(sched, src);
+ if (tensor_backend_id(src) == -1) {
+ tensor_backend_id(src) = ggml_backend_sched_backend_id_from_cur(sched, src);
}
}
}
// pass 2.1 expand gpu up
{
- ggml_tallocr_t cur_allocr = NULL;
+ int cur_backend_id = -1;
for (int i = graph->n_nodes - 1; i >= 0; i--) {
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
- ggml_tallocr_t node_allocr = node_allocr(node);
- if (node_allocr != NULL) {
- if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
+ int tensor_backend_id = tensor_backend_id(node);
+ if (tensor_backend_id != -1) {
+ if (tensor_backend_id == sched->n_backends - 1) {
// skip cpu (lowest prio backend)
- cur_allocr = NULL;
+ cur_backend_id = -1;
} else {
- cur_allocr = node_allocr;
+ cur_backend_id = tensor_backend_id;
}
} else {
- node_allocr(node) = cur_allocr;
+ tensor_backend_id(node) = cur_backend_id;
SET_CAUSE(node, "2.1");
}
}
// pass 2.2 expand gpu down
{
- ggml_tallocr_t cur_allocr = NULL;
+ int cur_backend_id = -1;
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
- ggml_tallocr_t node_allocr = node_allocr(node);
- if (node_allocr != NULL) {
- if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
+ int tensor_backend_id = tensor_backend_id(node);
+ if (tensor_backend_id != -1) {
+ if (tensor_backend_id == sched->n_backends - 1) {
// skip cpu (lowest prio backend)
- cur_allocr = NULL;
+ cur_backend_id = -1;
} else {
- cur_allocr = node_allocr;
+ cur_backend_id = tensor_backend_id;
}
} else {
- node_allocr(node) = cur_allocr;
+ tensor_backend_id(node) = cur_backend_id;
SET_CAUSE(node, "2.2");
}
}
// pass 2.3 expand rest up
{
- ggml_tallocr_t cur_allocr = NULL;
+ int cur_backend_id = -1;
for (int i = graph->n_nodes - 1; i >= 0; i--) {
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
- ggml_tallocr_t node_allocr = node_allocr(node);
- if (node_allocr != NULL) {
- cur_allocr = node_allocr;
+ int tensor_backend_id = tensor_backend_id(node);
+ if (tensor_backend_id != -1) {
+ cur_backend_id = tensor_backend_id;
} else {
- node_allocr(node) = cur_allocr;
+ tensor_backend_id(node) = cur_backend_id;
SET_CAUSE(node, "2.3");
}
}
// pass 2.4 expand rest down
{
- ggml_tallocr_t cur_allocr = NULL;
+ int cur_backend_id = -1;
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
- ggml_tallocr_t node_allocr = node_allocr(node);
- if (node_allocr != NULL) {
- cur_allocr = node_allocr;
+ int tensor_backend_id = tensor_backend_id(node);
+ if (tensor_backend_id != -1) {
+ cur_backend_id = tensor_backend_id;
} else {
- node_allocr(node) = cur_allocr;
+ tensor_backend_id(node) = cur_backend_id;
SET_CAUSE(node, "2.4");
}
}
// pass 3: assign backends to remaining src from dst and view_src
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
- ggml_tallocr_t cur_allocr = node_allocr(node);
- if (node->view_src != NULL && cur_allocr == NULL) {
- cur_allocr = node_allocr(node) = node_allocr(node->view_src);
+ int cur_backend_id = tensor_backend_id(node);
+ if (node->view_src != NULL && cur_backend_id == -1) {
+ cur_backend_id = tensor_backend_id(node) = tensor_backend_id(node->view_src);
SET_CAUSE(node, "3.vsrc");
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (src == NULL) {
break;
}
- ggml_tallocr_t src_allocr = node_allocr(src);
- if (src_allocr == NULL) {
+ int src_backend_id = tensor_backend_id(src);
+ if (src_backend_id == -1) {
if (src->view_src != NULL) {
// views are always on the same backend as the source
- node_allocr(src) = node_allocr(src->view_src);
+ tensor_backend_id(src) = tensor_backend_id(src->view_src);
SET_CAUSE(src, "3.vsrc");
} else {
- node_allocr(src) = cur_allocr;
+ tensor_backend_id(src) = cur_backend_id;
SET_CAUSE(src, "3.cur");
}
}
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
if (!ggml_is_view_op(node->op)) {
- sched->splits[0].tallocr = node_allocr(node);
+ sched->splits[0].backend_id = tensor_backend_id(node);
break;
}
}
sched->splits[0].i_start = 0;
sched->splits[0].n_inputs = 0;
memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
- ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
- size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
+ int cur_backend_id = sched->splits[0].backend_id;
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
continue;
}
- ggml_tallocr_t node_allocr = node_allocr(node);
+ int tensor_backend_id = tensor_backend_id(node);
- GGML_ASSERT(node_allocr != NULL); // all nodes should be assigned by now
+ GGML_ASSERT(tensor_backend_id != -1); // all nodes should be assigned by now
- if (node_allocr != cur_allocr) {
+ if (tensor_backend_id != cur_backend_id) {
sched->splits[cur_split].i_end = i;
cur_split++;
GGML_ASSERT(cur_split < GGML_MAX_SPLITS);
- sched->splits[cur_split].tallocr = node_allocr;
+ sched->splits[cur_split].backend_id = tensor_backend_id;
sched->splits[cur_split].i_start = i;
sched->splits[cur_split].n_inputs = 0;
- cur_allocr = node_allocr;
- cur_backend_id = sched_allocr_prio(sched, cur_allocr);
+ cur_backend_id = tensor_backend_id;
}
// find inputs that are not on the same backend
if (src == NULL) {
break;
}
- ggml_tallocr_t src_allocr = node_allocr(src);
- GGML_ASSERT(src_allocr != NULL); // all inputs should be assigned by now
- if (src_allocr != node_allocr) {
+ int src_backend_id = tensor_backend_id(src);
+ assert(src_backend_id != -1); // all inputs should be assigned by now
+ if (src_backend_id != tensor_backend_id) {
// create a copy of the input in the split's backend
size_t id = hash_id(src);
- if (sched->node_copies[id][cur_backend_id] == NULL) {
- ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
+ if (sched->tensor_copies[id][cur_backend_id] == NULL) {
+ ggml_backend_t backend = sched->backends[cur_backend_id];
struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
- sched->node_copies[id][cur_backend_id] = tensor_copy;
- node_allocr(tensor_copy) = cur_allocr;
+ sched->tensor_copies[id][cur_backend_id] = tensor_copy;
+ tensor_backend_id(tensor_copy) = cur_backend_id;
SET_CAUSE(tensor_copy, "4.cpy");
int n_inputs = sched->splits[cur_split].n_inputs++;
GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
sched->splits[cur_split].inputs[n_inputs] = src;
}
- node->src[j] = sched->node_copies[id][cur_backend_id];
-
-#if 0
- // check if the input is already in the split
- bool found = false;
- for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) {
- if (sched->splits[cur_split].inputs[k] == src) {
- found = true;
- break;
- }
- }
-
- if (!found) {
- int n_inputs = sched->splits[cur_split].n_inputs++;
- //printf("split %d input %d: %s (%s)\n", cur_split, n_inputs, src->name, ggml_backend_name(get_allocr_backend(sched, src_allocr)));
- GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
- sched->splits[cur_split].inputs[n_inputs] = src;
- }
-#endif
+ node->src[j] = sched->tensor_copies[id][cur_backend_id];
}
}
}
// sanity check: all sources should have the same backend as the node
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
- ggml_tallocr_t node_allocr = node_allocr(node);
- if (node_allocr == NULL) {
+ ggml_backend_t tensor_backend = tensor_backend(node);
+ if (tensor_backend == NULL) {
fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
}
- if (node->view_src != NULL && node_allocr != node_allocr(node->view_src)) {
+ if (node->view_src != NULL && tensor_backend != tensor_backend(node->view_src)) {
fprintf(stderr, "!!!!!!! %s has backend %s, view_src %s has backend %s\n",
- node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
- node->view_src->name, node_allocr(node->view_src) ? ggml_backend_name(get_allocr_backend(sched, node_allocr(node->view_src))) : "NULL");
+ node->name, tensor_backend ? ggml_backend_name(tensor_backend) : "NULL",
+ node->view_src->name, tensor_backend(node->view_src) ? ggml_backend_name(tensor_backend(node->view_src)) : "NULL");
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
- ggml_tallocr_t src_allocr = node_allocr(src);
- if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
+ ggml_backend_t src_backend = tensor_backend(src);
+ if (src_backend != tensor_backend /* && src_backend != NULL */) {
fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
- node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
- j, src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
+ node->name, tensor_backend ? ggml_backend_name(tensor_backend) : "NULL",
+ j, src->name, src_backend ? ggml_backend_name(src_backend) : "NULL");
}
- if (src->view_src != NULL && src_allocr != node_allocr(src->view_src)) {
+ if (src->view_src != NULL && src_backend != tensor_backend(src->view_src)) {
fprintf(stderr, "!!!!!!! [src] %s has backend %s, view_src %s has backend %s\n",
- src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL",
- src->view_src->name, node_allocr(src->view_src) ? ggml_backend_name(get_allocr_backend(sched, node_allocr(src->view_src))) : "NULL");
+ src->name, src_backend ? ggml_backend_name(src_backend) : "NULL",
+ src->view_src->name, tensor_backend(src->view_src) ? ggml_backend_name(tensor_backend(src->view_src)) : "NULL");
}
}
}
struct ggml_backend_sched_split * split = &sched->splits[i];
split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
- // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
for (int j = 0; j < split->n_inputs; j++) {
struct ggml_tensor * input = split->inputs[j];
- struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
+ struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split->backend_id];
+
// add a dependency to the input source so that it is not freed before the copy is done
- GGML_ASSERT(input_cpy->src[0] == NULL || input_cpy->src[0] == input);
- input_cpy->src[0] = input;
+ struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);
+ sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(input);
+ graph_copy->nodes[graph_copy->n_nodes++] = input_dep;
+
+ // add a dependency to the input copy so that it is allocated at the start of the split
+ sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id;
graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
}
for (int j = split->i_start; j < split->i_end; j++) {
+ sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]);
graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
}
}
sched->graph = graph_copy;
}
-static void sched_alloc_splits(ggml_backend_sched_t sched) {
- ggml_gallocr_alloc_graph_n(
- sched->galloc,
- sched->graph,
- sched->hash_set,
- sched->node_talloc);
+static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
+ // ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids);
+ if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
+#ifndef NDEBUG
+ fprintf(stderr, "ggml_backend_sched: failed to allocate graph, reserving\n");
+#endif
+ ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids);
+ if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) {
+ fprintf(stderr, "ggml_backend_sched: failed to allocate graph\n");
+ return false;
+ }
+ }
+
+ return true;
}
-static void sched_compute_splits(ggml_backend_sched_t sched) {
+static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
for (int i = 0; i < sched->n_splits; i++) {
struct ggml_backend_sched_split * split = &splits[i];
- ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
- int split_backend_id = sched_backend_prio(sched, split_backend);
+ int split_backend_id = split->backend_id;
+ ggml_backend_t split_backend = sched->backends[split_backend_id];
// copy the input tensors to the split backend
uint64_t copy_start_us = ggml_time_us();
for (int j = 0; j < split->n_inputs; j++) {
struct ggml_tensor * input = split->inputs[j];
- struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][split_backend_id];
+ struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id];
GGML_ASSERT(input->buffer != NULL);
GGML_ASSERT(input_cpy->buffer != NULL);
- // TODO: avoid this copy if it was already copied in a previous split, and the input didn't change
- // this is important to avoid copying constants such as KQ_mask and inp_pos multiple times
ggml_backend_tensor_copy_async(split_backend, input, input_cpy);
}
//ggml_backend_synchronize(split_backend); // necessary to measure copy time
uint64_t compute_start_us = ggml_time_us();
if (!sched->callback_eval) {
- ggml_backend_graph_compute(split_backend, &split->graph);
+ if (!ggml_backend_graph_compute(split_backend, &split->graph)) {
+ return false;
+ }
//ggml_backend_synchronize(split_backend); // necessary to measure compute time
} else {
// similar to ggml_backend_compare_graph_backend
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
- ggml_backend_graph_compute(split_backend, &gv);
+ if (!ggml_backend_graph_compute(split_backend, &gv)) {
+ return false;
+ }
if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
break;
}
}
#endif
-}
-
-static void sched_reset(ggml_backend_sched_t sched) {
- for (int i = 0; i < sched->n_backends; i++) {
- ggml_tallocr_reset(sched->tallocs[i]);
- }
- // reset state for the next run
- size_t hash_size = sched->hash_set.size;
- memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
- memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
- memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
- sched->is_reset = true;
+ return true;
}
ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1);
// initialize hash table
- sched->hash_set = ggml_hash_set_new(graph_size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
- sched->node_talloc = calloc(sizeof(sched->node_talloc[0]) * sched->hash_set.size, 1);
- sched->node_copies = calloc(sizeof(sched->node_copies[0]) * sched->hash_set.size, 1);
+ sched->hash_set = ggml_hash_set_new(graph_size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
+ sched->tensor_backend_id = calloc(sizeof(sched->tensor_backend_id[0]), sched->hash_set.size);
+ sched->tensor_copies = calloc(sizeof(sched->tensor_copies[0]), sched->hash_set.size);
+ sched->node_backend_ids = calloc(sizeof(sched->node_backend_ids[0]), graph_size);
sched->n_backends = n_backends;
for (int i = 0; i < n_backends; i++) {
sched->bufts[i] = bufts ? bufts[i] : ggml_backend_get_default_buffer_type(backends[i]);
}
- sched->galloc = ggml_gallocr_new();
+ sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
- // init measure allocs for each backend
- for (int i = 0; i < n_backends; i++) {
- sched->tallocs[i] = ggml_tallocr_new_measure_from_buft(sched->bufts[i]);
- }
-
- sched_reset(sched);
+ ggml_backend_sched_reset(sched);
return sched;
}
if (sched == NULL) {
return;
}
- for (int i = 0; i < sched->n_backends; i++) {
- ggml_tallocr_free(sched->tallocs[i]);
- }
ggml_gallocr_free(sched->galloc);
ggml_free(sched->ctx);
free(sched->hash_set.keys);
- free(sched->node_talloc);
- free(sched->node_copies);
+ free(sched->tensor_backend_id);
+ free(sched->tensor_copies);
+ free(sched->node_backend_ids);
free(sched);
}
-void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
- GGML_ASSERT(ggml_tallocr_is_measure(sched->tallocs[0])); // can only be initialized once
+void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
+ // reset state for the next run
+ size_t hash_size = sched->hash_set.size;
+ memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT
+ memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size);
+ memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size);
- sched_split_graph(sched, measure_graph);
- sched_alloc_splits(sched);
+ sched->is_reset = true;
+}
- // allocate buffers and reset allocators
- for (int i = 0; i < sched->n_backends; i++) {
- size_t size = ggml_tallocr_max_size(sched->tallocs[i]);
- ggml_tallocr_free(sched->tallocs[i]);
- sched->tallocs[i] = ggml_tallocr_new_from_buft(sched->bufts[i], size);
+bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+ ggml_backend_sched_split_graph(sched, measure_graph);
+
+ if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids)) {
+ return false;
}
- sched_reset(sched);
+ ggml_backend_sched_reset(sched);
+ return true;
}
-void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
if (!sched->is_reset) {
- sched_reset(sched);
+ ggml_backend_sched_reset(sched);
}
- sched_split_graph(sched, graph);
- sched_alloc_splits(sched);
- sched_compute_splits(sched);
-}
+ ggml_backend_sched_split_graph(sched, graph);
+ if (!ggml_backend_sched_alloc_splits(sched)) {
+ return false;
+ }
-void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
- sched_reset(sched);
-}
+ if (!ggml_backend_sched_compute_splits(sched)) {
+ return false;
+ }
+ return true;
+}
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
sched->callback_eval = callback;
return sched->n_splits;
}
-ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
- int backend_index = sched_backend_prio(sched, backend);
- GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
- return sched->tallocs[backend_index];
-}
-
-ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) {
- int backend_index = sched_backend_prio(sched, backend);
+size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
+ int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
- return ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
+ return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
}
void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
- int backend_index = sched_backend_prio(sched, backend);
+ int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
- node_allocr(node) = sched->tallocs[backend_index];
+ tensor_backend_id(node) = backend_index;
}
ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
- ggml_tallocr_t allocr = node_allocr(node);
- if (allocr == NULL) {
+ int backend_index = tensor_backend_id(node);
+ if (backend_index == -1) {
return NULL;
}
- return get_allocr_backend(sched, allocr);
+ return sched->backends[backend_index];
}
// utils
void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->buffer == NULL);
- //GGML_ASSERT(tensor->data == NULL); // views of pre-allocated tensors may have the data set in ggml_new_tensor, but still need to be initialized by the backend
GGML_ASSERT(tensor->view_src != NULL);
GGML_ASSERT(tensor->view_src->buffer != NULL);
GGML_ASSERT(tensor->view_src->data != NULL);
ggml_backend_buffer_init_tensor(buffer, tensor);
}
-static struct ggml_tensor * graph_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
+static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
GGML_ASSERT(src != NULL);
struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
if (src->view_src != NULL) {
- dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+ dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
dst->view_offs = src->view_offs;
}
dst->op = src->op;
if (s == NULL) {
break;
}
- dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+ dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
}
node_copies[id] = dst;
return dst;
}
-static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
size_t id = ggml_hash_find(hash_set, src);
if (node_init[id]) {
return;
struct ggml_tensor * dst = node_copies[id];
if (dst->view_src != NULL) {
- graph_init_tensor(hash_set, node_copies, node_init, src->view_src);
+ graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);
ggml_backend_view_init(dst->view_src->buffer, dst);
}
else {
if (s == NULL) {
break;
}
- graph_init_tensor(hash_set, node_copies, node_init, s);
+ graph_copy_init_tensor(hash_set, node_copies, node_init, s);
}
}
struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
struct ggml_hash_set hash_set = {
/* .size = */ graph->visited_hash_table.size,
- /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
+ /* .keys = */ calloc(sizeof(hash_set.keys[0]), graph->visited_hash_table.size) // NOLINT
};
- struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
- bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
+ struct ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]), hash_set.size); // NOLINT
+ bool * node_init = calloc(sizeof(node_init[0]), hash_set.size);
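// note on the argument order: calloc(count, element_size) lets calloc detect
// overflow of the count*size product itself, instead of multiplying in the caller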
struct ggml_init_params params = {
/* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
// dup nodes
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
- graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+ graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
}
// allocate nodes
// copy data and init views
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
- graph_init_tensor(hash_set, node_copies, node_init, node);
+ graph_copy_init_tensor(hash_set, node_copies, node_init, node);
}
// build graph copy
// in build_graph:
build_graph(...) {
- // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
- alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
- ggml_allocr_alloc(alloc_cpu, tensor);
-
- // manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
+ // manually assign nodes to a backend (optional, should not be needed in most cases)
struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
}
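// typical usage (illustrative sketch; names are placeholders, error checks omitted):
//
//     sched = ggml_backend_sched_new(backends, bufts, n_backends, graph_size);
//
//     // initialize the backend buffers from a worst-case (measure) graph
//     ggml_backend_sched_reserve(sched, build_graph(max_batch));
//
//     // for each computation:
//     ggml_backend_sched_reset(sched);   // clear previous node assignments
//     gf = build_graph(...);
//     ggml_backend_sched_graph_compute(sched, gf);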
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Initialize backend buffers from a measure graph
- GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+ GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
// Get the number of splits of the last graph
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
- GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
- GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
+ GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
// Allocate and compute graph on the backend scheduler
- GGML_API void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+ GGML_API bool ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
- // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs
+ // Reset all assignments and allocators - must be called before changing the node backends
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
// Set a callback to be called for each resulting node during graph compute
/*.nb =*/ { 0, 0, 0, 0 },
/*.op =*/ GGML_OP_NONE,
/*.op_params =*/ { 0 },
- /*.is_param =*/ false,
+ /*.flags =*/ 0,
/*.grad =*/ NULL,
/*.src =*/ { NULL },
/*.perf_runs =*/ 0,
void ggml_set_param(
struct ggml_context * ctx,
struct ggml_tensor * tensor) {
- tensor->is_param = true;
+ tensor->flags |= GGML_TENSOR_FLAG_PARAM;
GGML_ASSERT(tensor->grad == NULL);
tensor->grad = ggml_dup_tensor(ctx, tensor);
return NULL;
}
- if (node->is_param) {
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
return node;
}
clone->op = node->op;
clone->grad = node->grad;
- clone->is_param = node->is_param;
+ clone->flags = node->flags;
clone->extra = node->extra;
for (int k = 0; k < GGML_MAX_DIMS; ++k) {
clone->nb[k] = node->nb[k];
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
- if (node->is_param) {
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
ggml_build_forward_expand(gb, node->grad);
}
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
i,
node->ne[0], node->ne[1], node->ne[2],
- ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+ ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ", node->perf_runs,
(double) node->perf_cycles / (double) ggml_cycles_per_ms(),
(double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
(double) node->perf_time_us / 1000.0,
continue;
}
- if (node->is_param) {
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
snprintf(color, sizeof(color), "yellow");
} else if (node->grad) {
if (ggml_graph_find(gf, node)) {
int np = 0;
int64_t nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) {
- if (gf->nodes[i]->is_param) {
+ if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
GGML_ASSERT(np < GGML_MAX_PARAMS);
int np = 0;
int nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) {
- if (gf->nodes[i]->is_param) {
+ if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
GGML_ASSERT(np < GGML_MAX_PARAMS);
////////////////////////////////////////////////////////////////////////////////
+void ggml_set_input(struct ggml_tensor * tensor) {
+ tensor->flags |= GGML_TENSOR_FLAG_INPUT;
+}
+
+void ggml_set_output(struct ggml_tensor * tensor) {
+ tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
void ggml_quantize_init(enum ggml_type type) {
ggml_critical_section_start();
enum ggml_log_level {
GGML_LOG_LEVEL_ERROR = 2,
- GGML_LOG_LEVEL_WARN = 3,
- GGML_LOG_LEVEL_INFO = 4,
+ GGML_LOG_LEVEL_WARN  = 3,
+ GGML_LOG_LEVEL_INFO  = 4,
GGML_LOG_LEVEL_DEBUG = 5
};
+ enum ggml_tensor_flag {
+ GGML_TENSOR_FLAG_INPUT = 1,
+ GGML_TENSOR_FLAG_OUTPUT = 2,
+ GGML_TENSOR_FLAG_PARAM = 4,
+ };
+
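// illustrative sketch: the flag values are independent bits, so they can be
// combined and tested with bitwise operators, e.g.
//
//     tensor->flags |= GGML_TENSOR_FLAG_INPUT | GGML_TENSOR_FLAG_OUTPUT;
//     if (tensor->flags & GGML_TENSOR_FLAG_PARAM) { /* trainable parameter */ }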
// ggml object
struct ggml_object {
size_t offs;
// op params - allocated as int32_t for alignment
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
- bool is_param;
+ int32_t flags;
struct ggml_tensor * grad;
struct ggml_tensor * src[GGML_MAX_SRC];
ggml_opt_callback callback,
void * callback_data);
+ //
+ // tensor flags
+ //
+ GGML_API void ggml_set_input(struct ggml_tensor * tensor);
+ GGML_API void ggml_set_output(struct ggml_tensor * tensor);
+
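// illustrative sketch (hypothetical tensor names): marking graph inputs and
// outputs before allocating the graph, e.g.
//
//     ggml_set_input(inp_tokens);
//     ggml_set_output(logits);
//     ggml_build_forward_expand(gf, logits);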
//
// quantization
//
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_t sched = nullptr;
- // allocator for the input tensors
- ggml_tallocr * alloc = nullptr;
// input tensors
ggml_backend_buffer_t buf_input = nullptr;
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
- const llama_batch & batch) {
+ const llama_batch & batch,
+ bool worst_case) {
const auto & model = lctx.model;
- // check if we should build the worst-case graph (for memory measurement)
- const bool worst_case = ggml_tallocr_is_measure(lctx.alloc);
-
// this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
if (il >= 0) {
struct llm_build_context llm(lctx, batch, cb, worst_case);
- //
- // set input data
- //
-
- if (!ggml_tallocr_is_measure(lctx.alloc)) {
- if (batch.token) {
- const int64_t n_tokens = batch.n_tokens;
-
- ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
- }
-
- if (batch.embd) {
- const int64_t n_embd = llm.n_embd;
- const int64_t n_tokens = batch.n_tokens;
-
- ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
- }
-
- if (batch.pos) {
- const int64_t n_tokens = batch.n_tokens;
-
- ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
- }
-
- {
- const int64_t n_kv = llm.n_kv;
- const int64_t n_tokens = batch.n_tokens;
-
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
- float * data = (float *) lctx.inp_KQ_mask->data;
-
- for (int h = 0; h < 1; ++h) {
- for (int j = 0; j < n_tokens; ++j) {
- const llama_pos pos = batch.pos[j];
- const llama_seq_id seq_id = batch.seq_id[j][0];
-
- for (int i = 0; i < n_kv; ++i) {
- float f;
- if (!lctx.kv_self.cells[i].has_seq_id(seq_id) ||
- (llm.causal_attn && lctx.kv_self.cells[i].pos > pos)) {
- f = -INFINITY;
- } else {
- f = 0;
- }
- data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
- }
- }
- }
- }
-
- if (llm.do_rope_shift) {
- const int64_t n_ctx = llm.n_ctx;
-
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
- int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
- for (int i = 0; i < n_ctx; ++i) {
- data[i] = lctx.kv_self.cells[i].delta;
- }
- }
-
- {
- GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
- float * data = (float *) lctx.inp_sum->data;
-
- for (int i = 0; i < batch.n_tokens; ++i) {
- data[i] = 1.0f/float(batch.n_tokens);
- }
- }
- }
-
llm.init();
switch (model.arch) {
return result;
}
+static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+ //
+ // set input data
+ //
+
+ const auto & hparams = lctx.model.hparams;
+ const auto & cparams = lctx.cparams;
+ const auto & kv_self = lctx.kv_self;
+
+ if (batch.token) {
+ const int64_t n_tokens = batch.n_tokens;
+
+ ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
+ }
+
+ if (batch.embd) {
+ const int64_t n_embd = hparams.n_embd;
+ const int64_t n_tokens = batch.n_tokens;
+
+ ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+ }
+
+ if (batch.pos) {
+ const int64_t n_tokens = batch.n_tokens;
+
+ ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+ }
+
+ {
+ const int64_t n_kv = kv_self.n;
+ const int64_t n_tokens = batch.n_tokens;
+
+ assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+
+ float * data = (float *) lctx.inp_KQ_mask->data;
+
+ for (int h = 0; h < 1; ++h) {
+ for (int j = 0; j < n_tokens; ++j) {
+ const llama_pos pos = batch.pos[j];
+ const llama_seq_id seq_id = batch.seq_id[j][0];
+
+ for (int i = 0; i < n_kv; ++i) {
+ float f;
+ if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+ f = -INFINITY;
+ } else {
+ f = 0;
+ }
+ data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+ }
+ }
+ }
+ }
+
+ {
+ assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer));
+ float * data = (float *) lctx.inp_sum->data;
+
+ for (int i = 0; i < batch.n_tokens; ++i) {
+ data[i] = 1.0f/float(batch.n_tokens);
+ }
+ }
+
+ if (kv_self.has_shift) {
+ const int64_t n_ctx = cparams.n_ctx;
+
+ assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+
+ int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+
+ for (int i = 0; i < n_ctx; ++i) {
+ data[i] = lctx.kv_self.cells[i].delta;
+ }
+ }
+}
+
// decode a batch of tokens by evaluating the transformer
//
// - lctx: llama context
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
- ggml_cgraph * gf = llama_build_graph(lctx, batch);
+ ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
// the output is always the last tensor in the graph
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
if (lctx.backend_cpu != nullptr) {
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
}
+
+ llama_set_inputs(lctx, batch);
+
ggml_backend_sched_graph_compute(lctx.sched, gf);
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
- ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
// build worst-case graph
int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
int n_past = cparams.n_ctx - n_tokens;
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
- ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
+ ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
// initialize scheduler with the worst-case graph
- ggml_backend_sched_init_measure(ctx->sched, gf);
- ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
+ if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
+ LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+ llama_free(ctx);
+ return nullptr;
+ }
- for (ggml_backend_t backend : ctx->backends) {
- ggml_backend_buffer_t buf = ggml_backend_sched_get_buffer(ctx->sched, backend);
+ for (size_t i = 0; i < ctx->backends.size(); i++) {
+ ggml_backend_t backend = ctx->backends[i];
+ ggml_backend_buffer_type_t buft = backend_buft[i];
+ size_t size = ggml_backend_sched_get_buffer_size(ctx->sched, backend);
LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
- ggml_backend_buffer_name(buf),
- ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
+ ggml_backend_buft_name(buft),
+ size / 1024.0 / 1024.0);
}
// note: the number of splits during measure is higher than during inference due to the kv shift
-2c7cf49810d523b9632da393a9e8270b60bf3b24
+5070f078a67c18c11736e78316ab715ca9afde16