#include "ggml/ggml.h"
+#include "ggml/ggml-alloc.h"
#include "common.h"
#include "common-ggml.h"
return true;
}
-// evaluate the transformer
-//
-// - model: the model
-// - n_threads: number of threads to use
-// - n_past: the context size so far
-// - embd_inp: the embeddings of the tokens in the context
-// - embd_w: the predicted logits for the next token
-//
-bool gpt2_eval(
+// build the computation graph
+struct ggml_cgraph * gpt2_graph(
const gpt2_model & model,
- const int n_threads,
+ struct ggml_allocr * allocr,
const int n_past,
- const std::vector<gpt_vocab::id> & embd_inp,
- std::vector<float> & embd_w,
- size_t & mem_per_token) {
+ const std::vector<gpt_vocab::id> & embd_inp) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
- const int n_vocab = hparams.n_vocab;
- static size_t buf_size = 256u*1024*1024;
- static void * buf = malloc(buf_size);
-
- if (mem_per_token > 0 && mem_per_token*N > buf_size) {
- const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
- //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
-
- // reallocate
- buf_size = buf_size_new;
- buf = realloc(buf, buf_size);
- if (buf == nullptr) {
- fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
- return false;
- }
- }
+ // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
+ static size_t buf_size = ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead();
+ static std::vector<uint8_t> buf(buf_size);
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
- /*.mem_buffer =*/ buf,
- /*.no_alloc =*/ false,
+ /*.mem_buffer =*/ buf.data(),
+ /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
};
struct ggml_context * ctx0 = ggml_init(params);
- struct ggml_cgraph gf = {};
+
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
- memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+ ggml_allocr_alloc(allocr, embd);
+
+ // avoid writing to tensors if we are only measuring the memory usage
+ if (!ggml_allocr_is_measure(allocr)) {
+ memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
+ }
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
- for (int i = 0; i < N; ++i) {
- ((int32_t *) position->data)[i] = n_past + i;
+ ggml_allocr_alloc(allocr, position);
+ if (!ggml_allocr_is_measure(allocr)) {
+ for (int i = 0; i < N; ++i) {
+ ((int32_t *) position->data)[i] = n_past + i;
+ }
+ }
+
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ ggml_allocr_alloc(allocr, KQ_scale);
+ if (!ggml_allocr_is_measure(allocr)) {
+ ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
// wte + wpe
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
- ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
- ggml_scale_inplace(ctx0,
+ ggml_scale(ctx0,
KQ,
- ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
- );
+ KQ_scale);
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
// logits -> probs
- //inpL = ggml_soft_max_inplace(ctx0, inpL);
+ //inpL = ggml_soft_max(ctx0, inpL);
+
+ ggml_build_forward_expand(gf, inpL);
+
+ ggml_free(ctx0);
+
+ return gf;
+}
+
+// evaluate the transformer
+//
+// - model: the model
+// - allocr: ggml_allocr to use to allocate the compute buffer
+// - n_threads: number of threads to use
+// - n_past: the context size so far
+// - embd_inp: the embeddings of the tokens in the context
+// - embd_w: the predicted logits for the next token
+//
+bool gpt2_eval(
+ const gpt2_model & model,
+ struct ggml_allocr * allocr,
+ const int n_threads,
+ const int n_past,
+ const std::vector<gpt_vocab::id> & embd_inp,
+ std::vector<float> & embd_w) {
+ const int N = embd_inp.size();
+
+ const auto & hparams = model.hparams;
+
+ const int n_vocab = hparams.n_vocab;
+
+ // reset the allocator to free all the memory allocated during the previous inference
+ ggml_allocr_reset(allocr);
+
+ struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, embd_inp);
+
+ // allocate tensors
+ ggml_allocr_alloc_graph(allocr, gf);
// run the computation
- ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+ struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
+ static std::vector<uint8_t> work_buffer;
+ work_buffer.resize(plan.work_size);
+ plan.work_data = work_buffer.data();
+ ggml_graph_compute(gf, &plan);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
+ // in this case, the output tensor is the last one in the graph
+ struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];
+
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
- if (mem_per_token == 0) {
- mem_per_token = ggml_used_mem(ctx0)/N;
- }
- //printf("used_mem = %zu\n", ggml_used_mem(ctx0));
-
- ggml_free(ctx0);
-
return true;
}
test_gpt_tokenizer(vocab, params.token_test);
}
+ // keep this buffer alive while evaluating the model
+ std::vector<uint8_t> compute_buffer;
+
+ struct ggml_allocr * allocr = NULL;
+ // allocate the compute buffer
+ {
+ allocr = ggml_allocr_new_measure(GGML_MEM_ALIGN);
+
+ // create the worst case graph for memory usage estimation
+ int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
+ int n_past = model.hparams.n_ctx - n_tokens;
+ struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));
+
+ // compute the required memory
+ size_t mem_size = ggml_allocr_alloc_graph(allocr, gf) + GGML_MEM_ALIGN;
+
+ // recreate the allocator with the required memory
+ ggml_allocr_free(allocr);
+ compute_buffer.resize(mem_size);
+ allocr = ggml_allocr_new(compute_buffer.data(), mem_size, GGML_MEM_ALIGN);
+
+ fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
+ }
+
int n_past = 0;
int64_t t_sample_us = 0;
// this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
std::vector<gpt_vocab::id> embd;
- // determine the required inference memory per token:
- size_t mem_per_token = 0;
- gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
// predict
if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us();
- if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
+ if (!gpt2_eval(model, allocr, params.n_threads, n_past, embd, logits)) {
printf("Failed to predict\n");
return 1;
}
const int64_t t_main_end_us = ggml_time_us();
printf("\n\n");
- printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);