--- /dev/null
+#include "ggml-opt.h"
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <inttypes.h>
+#include <map>
+#include <random>
+#include <vector>
+
+struct ggml_opt_dataset {
+ struct ggml_context * ctx;
+ ggml_backend_buffer_t buf;
+ struct ggml_tensor * data;
+ struct ggml_tensor * labels;
+
+ int64_t ndata;
+ int64_t ndata_shard;
+ size_t nbs_data;
+ size_t nbs_labels;
+
+ std::vector<int64_t> permutation;
+};
+
+struct ggml_opt_context {
+ ggml_backend_sched_t backend_sched;
+ ggml_cgraph * allocated_graph;
+ ggml_cgraph * allocated_graph_copy;
+ struct ggml_context * ctx_static;
+ struct ggml_context * ctx_static_cpu;
+ struct ggml_context * ctx_compute;
+ struct ggml_context * ctx_copy;
+ ggml_backend_buffer_t buf_static;
+ ggml_backend_buffer_t buf_static_cpu;
+ std::mt19937 rng;
+
+ struct ggml_tensor * inputs;
+ struct ggml_tensor * outputs;
+ struct ggml_tensor * labels;
+
+ struct ggml_tensor * loss;
+ struct ggml_tensor * pred;
+ struct ggml_tensor * ncorrect;
+
+ struct ggml_cgraph * gf;
+ struct ggml_cgraph * gb_grad;
+ struct ggml_cgraph * gb_opt;
+
+ int64_t iter;
+ int32_t opt_period;
+ int32_t opt_i;
+ bool loss_per_datapoint;
+
+ ggml_opt_get_optimizer_params get_opt_pars;
+ void * get_opt_pars_ud;
+ struct ggml_tensor * adamw_params;
+};
+
+struct ggml_opt_result {
+ int64_t ndata = 0;
+ std::vector<float> loss;
+ std::vector<int32_t> pred;
+ int64_t ncorrect = 0;
+
+ bool loss_per_datapoint = false;
+ int64_t opt_period = -1;
+};
+
+// ====== Dataset ======
+
+ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
+ GGML_ASSERT(ne_datapoint > 0);
+ GGML_ASSERT(ne_label >= 0);
+ GGML_ASSERT(ndata > 0);
+ GGML_ASSERT(ndata_shard > 0);
+
+ ggml_opt_dataset_t result = new ggml_opt_dataset;
+ result->ndata = ndata;
+ result->ndata_shard = ndata_shard;
+
+ {
+ struct ggml_init_params params = {
+ /*.mem_size =*/ 2*ggml_tensor_overhead(),
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ result->ctx = ggml_init(params);
+ }
+
+ result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
+ result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
+
+ if (ne_label > 0) {
+ result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
+ result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
+ } else {
+ result->labels = nullptr;
+ result->nbs_labels = 0;
+ }
+
+ result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type());
+
+ const int64_t nshards = ndata/ndata_shard;
+ result->permutation.resize(nshards);
+ for (int64_t i = 0; i < nshards; ++i) {
+ result->permutation[i] = i;
+ }
+ return result;
+}
+
+void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
+ ggml_backend_buffer_free(dataset->buf);
+ ggml_free(dataset->ctx);
+ delete dataset;
+}
+
+struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
+ return dataset->data;
+}
+
+struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) {
+ return dataset->labels;
+}
+
+void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) {
+ GGML_ASSERT(idata <= dataset->ndata);
+
+ if (idata < 0) {
+ std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng);
+ return;
+ }
+
+ GGML_ASSERT(idata % dataset->ndata_shard == 0);
+ const int64_t ishard_max = idata / dataset->ndata_shard;
+ std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng);
+}
+
+void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) {
+ GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
+ GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
+ GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+
+ const size_t nb_data_batch = ggml_nbytes(data_batch);
+ GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
+ const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
+
+ if (labels_batch) {
+ const size_t nb_labels_batch = ggml_nbytes(labels_batch);
+ GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);
+ }
+
+ GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
+
+ for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
+ const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
+
+ const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;
+ ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);
+
+ if (!labels_batch) {
+ continue;
+ }
+
+ const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;
+ ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);
+ }
+}
+
+// ====== Model / Context ======
+
+struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
+ GGML_UNUSED(userdata);
+
+ ggml_opt_optimizer_params result;
+
+ result.adamw.alpha = 0.001f;
+ result.adamw.beta1 = 0.9f;
+ result.adamw.beta2 = 0.999f;
+ result.adamw.eps = 1e-8f;
+ result.adamw.wd = 0.0f;
+
+ return result;
+}
+
+struct ggml_opt_params ggml_opt_default_params(
+ ggml_backend_sched_t backend_sched,
+ struct ggml_context * ctx_compute,
+ struct ggml_tensor * inputs,
+ struct ggml_tensor * outputs,
+ enum ggml_opt_loss_type loss_type) {
+ return {
+ /*backend_sched =*/ backend_sched,
+ /*ctx_compute =*/ ctx_compute,
+ /*inputs =*/ inputs,
+ /*logits =*/ outputs,
+ /*loss_type =*/ loss_type,
+ /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT,
+ /*opt_period =*/ 1,
+ /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params,
+ /*get_opt_pars_ud =*/ nullptr,
+ };
+}
+
+static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_map, ggml_context * ctx, ggml_tensor * tensor) {
+ if (!tensor) {
+ return nullptr;
+ }
+
+ if (tensor_map.find(tensor) != tensor_map.end()) {
+ return tensor_map[tensor];
+ }
+
+ ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor);
+ tensor_map[tensor] = new_tensor;
+
+ new_tensor->op = tensor->op;
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ new_tensor->nb[i] = tensor->nb[i];
+ }
+ new_tensor->flags = tensor->flags;
+ memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params));
+ strcpy(new_tensor->name, tensor->name);
+ new_tensor->data = tensor->data;
+ new_tensor->buffer = tensor->buffer;
+ new_tensor->extra = tensor->extra;
+ new_tensor->view_offs = tensor->view_offs;
+ new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src);
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]);
+ }
+
+ return new_tensor;
+}
+
+static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) {
+ std::map<ggml_tensor *, ggml_tensor *> tensor_map;
+
+ ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
+
+ for (int i = 0; i < graph->n_leafs; i++) {
+ ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i]));
+ }
+ for (int i = 0; i < graph->n_nodes; i++) {
+ ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i]));
+ }
+ for (int i = 0; i < graph->n_nodes; ++i) {
+ const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]);
+ const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]);
+ graph->grads[igrad_dst] = new_graph->grads[igrad_src];
+ graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src];
+ }
+
+ return new_graph;
+}
+
+static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
+ GGML_ASSERT(graph);
+ if (opt_ctx->allocated_graph == graph) {
+ return;
+ }
+
+ ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
+
+ {
+ ggml_init_params params = {
+ /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ ggml_free(opt_ctx->ctx_copy);
+ opt_ctx->ctx_copy = ggml_init(params);
+ }
+
+ opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
+
+ ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+ opt_ctx->allocated_graph = graph;
+}
+
+ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
+ ggml_opt_context_t result = new struct ggml_opt_context;
+ result->backend_sched = params.backend_sched;
+ result->allocated_graph = nullptr;
+ result->allocated_graph_copy = nullptr;
+ result->ctx_compute = params.ctx_compute;
+ result->ctx_copy = nullptr;
+ result->inputs = params.inputs;
+ result->outputs = params.outputs;
+ result->iter = 1;
+ result->opt_period = params.opt_period;
+ result->opt_i = 0;
+ result->get_opt_pars = params.get_opt_pars;
+ result->get_opt_pars_ud = params.get_opt_pars_ud;
+
+ GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
+ GGML_ASSERT(result->opt_period >= 1);
+
+ const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
+ (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
+
+ ggml_set_input(result->inputs);
+ ggml_set_output(result->outputs);
+
+ result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
+ ggml_build_forward_expand(result->gf, result->outputs);
+
+ int n_param = 0;
+ for (int i = 0; i < result->gf->n_nodes; ++i) {
+ if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
+ n_param++;
+ }
+ }
+
+ {
+ // The static context is used for:
+ // - gradients (1 tensor per param if using gradient accumulation)
+ // - optimizer momenta (2 tensors per param)
+ // - labels
+ // - loss + its gradient (up to 5 tensors)
+ // - pred
+ // - ncorrect (2 tensors).
+ const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+ const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
+ struct ggml_init_params params = {
+ /*.mem_size =*/ size_meta,
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ result->ctx_static = ggml_init(params);
+ }
+ {
+ // The static cpu context is used for:
+ // - optimizer parameters (1 for the entire context)
+ const size_t size_meta = 1 * ggml_tensor_overhead();
+ struct ggml_init_params params = {
+ /*.mem_size =*/ size_meta,
+ /*.mem_buffer =*/ nullptr,
+ /*.no_alloc =*/ true,
+ };
+ result->ctx_static_cpu = ggml_init(params);
+ }
+
+
+ switch (params.loss_type) {
+ case GGML_OPT_LOSS_TYPE_MEAN: {
+ result->labels = nullptr;
+ result->loss = ggml_sum(result->ctx_static, result->outputs);
+ ggml_set_name(result->loss, "loss_sum");
+ const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
+ result->loss = ggml_scale(result->ctx_static, result->loss, scale);
+ ggml_set_name(result->loss, "loss_mean");
+ result->loss_per_datapoint = true;
+ break;
+ }
+ case GGML_OPT_LOSS_TYPE_SUM: {
+ result->labels = nullptr;
+ result->loss = ggml_sum(result->ctx_static, result->outputs);
+ ggml_set_name(result->loss, "loss_sum");
+ result->loss_per_datapoint = false;
+ break;
+ }
+ case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
+ result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
+ ggml_set_input(result->labels);
+ ggml_set_name(result->labels, "labels");
+ result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
+ ggml_set_name(result->loss, "loss_cross_entropy");
+ if (result->opt_period > 1) {
+ result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
+ ggml_set_name(result->loss, "loss_cross_entropy_scaled");
+ }
+ result->loss_per_datapoint = true;
+ break;
+ }
+ case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
+ result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
+ ggml_set_input(result->labels);
+ ggml_set_name(result->labels, "labels");
+ result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
+ ggml_set_name(result->loss, "loss_error");
+ result->loss = ggml_sqr(result->ctx_static, result->loss);
+ ggml_set_name(result->loss, "loss_squared_error");
+ result->loss = ggml_sum(result->ctx_static, result->loss);
+ ggml_set_name(result->loss, "loss_sum_squared_error");
+ const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
+ result->loss = ggml_scale(result->ctx_static, result->loss, scale);
+ ggml_set_name(result->loss, "loss_mean_squared_error");
+ result->loss_per_datapoint = true;
+ break;
+ }
+ }
+ ggml_set_output(result->loss);
+ ggml_set_loss(result->loss);
+ ggml_build_forward_expand(result->gf, result->loss);
+
+ result->pred = ggml_argmax(result->ctx_static, result->outputs);
+ ggml_set_name(result->pred, "pred");
+ ggml_set_output(result->pred);
+ ggml_build_forward_expand(result->gf, result->pred);
+
+ if (result->labels) {
+ result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
+ ggml_set_name(result->ncorrect, "ncorrect");
+ ggml_set_output(result->ncorrect);
+ ggml_build_forward_expand(result->gf, result->ncorrect);
+ } else {
+ result->ncorrect = nullptr;
+ }
+
+ if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
+ result->gb_grad = nullptr;
+ result->gb_opt = nullptr;
+
+ result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
+ result->buf_static_cpu = nullptr;
+
+ ggml_opt_alloc_graph(result, result->gf);
+
+ return result;
+ }
+
+ // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
+ result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
+ ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
+
+ if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
+ result->gb_opt = nullptr;
+
+ result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
+ result->buf_static_cpu = nullptr;
+
+ ggml_opt_alloc_graph(result, result->gb_grad);
+ ggml_graph_reset(result->gb_grad);
+
+ return result;
+ }
+
+ GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
+
+ // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
+ result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
+
+ result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
+ ggml_set_input(result->adamw_params);
+ ggml_set_name(result->adamw_params, "adamw_params");
+
+ for (int i = result->gf->n_nodes-1; i >= 0; --i) {
+ struct ggml_tensor * node = result->gb_opt->nodes[i];
+ struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
+
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+ struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
+ struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
+ struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
+ ggml_build_forward_expand(result->gb_opt, opt_step);
+ }
+ }
+
+ result->buf_static = ggml_backend_alloc_ctx_tensors(
+ result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
+
+ result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
+
+ ggml_opt_alloc_graph(result, result->gb_opt);
+ ggml_graph_reset(result->gb_opt);
+
+ return result;
+}
+
+void ggml_opt_free(ggml_opt_context_t opt_ctx) {
+ if (opt_ctx == nullptr) {
+ return;
+ }
+ ggml_backend_buffer_free(opt_ctx->buf_static);
+ ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
+ ggml_free(opt_ctx->ctx_static);
+ ggml_free(opt_ctx->ctx_static_cpu);
+ delete opt_ctx;
+}
+
+void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
+ if (optimizer) {
+ ggml_graph_reset(opt_ctx->gb_opt);
+ opt_ctx->iter = 1;
+ } else {
+ ggml_graph_reset(opt_ctx->gb_grad);
+ }
+}
+
+struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
+ return opt_ctx->inputs;
+}
+
+struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) {
+ return opt_ctx->outputs;
+}
+
+struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) {
+ return opt_ctx->labels;
+}
+
+struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) {
+ return opt_ctx->loss;
+}
+
+struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) {
+ return opt_ctx->pred;
+}
+
+struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) {
+ return opt_ctx->ncorrect;
+}
+
+struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) {
+ return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node);
+}
+
+// ====== Optimization Result ======
+
+ggml_opt_result_t ggml_opt_result_init() {
+ return new ggml_opt_result;
+}
+
+void ggml_opt_result_free(ggml_opt_result_t result) {
+ delete result;
+}
+
+void ggml_opt_result_reset(ggml_opt_result_t result) {
+ result->ndata = 0;
+ result->loss.clear();
+ result->pred.clear();
+ result->ncorrect = 0;
+}
+
+void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) {
+ *ndata = result->ndata;
+}
+
+void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) {
+ const int64_t nbatches = result->loss.size(); // Number of physical batches.
+
+ if (nbatches == 0) {
+ *loss = 0.0;
+ *unc = NAN;
+ return;
+ }
+
+ double sum = 0.0;
+ double sum_squared = 0.0;
+
+ for (const float & loss : result->loss) {
+ // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.
+ const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss;
+ sum += loss_scaled;
+ sum_squared += loss_scaled*loss_scaled;
+ }
+
+ const double mean = sum/nbatches;
+ *loss = result->loss_per_datapoint ? mean : sum;
+
+ if (!unc) {
+ return;
+ }
+
+ if (nbatches < 2) {
+ *unc = NAN;
+ return;
+ }
+
+ const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1)
+ *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));
+}
+
+void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) {
+ for (size_t i = 0; i < result->pred.size(); ++i) {
+ pred[i] = result->pred[i];
+ }
+}
+
+void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) {
+ *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN;
+
+ if (!unc) {
+ return;
+ }
+
+ *unc = result->ncorrect >= 0 && result->ndata >= 2 ?
+ sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN;
+}
+
+// ====== Computation ======
+
+static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
+ if (graph != opt_ctx->gf) {
+ struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+ GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
+ GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+ GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+ GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+ GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+ GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
+ GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
+ GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
+
+ // beta1, beta2 after applying warmup
+ const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+ const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+ float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
+ adamw_par_data[0] = opt_pars.adamw.alpha;
+ adamw_par_data[1] = opt_pars.adamw.beta1;
+ adamw_par_data[2] = opt_pars.adamw.beta2;
+ adamw_par_data[3] = opt_pars.adamw.eps;
+ adamw_par_data[4] = opt_pars.adamw.wd;
+ adamw_par_data[5] = beta1h;
+ adamw_par_data[6] = beta2h;
+ }
+
+ ggml_opt_alloc_graph(opt_ctx, graph);
+ ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+ opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
+
+ if (!result) {
+ return;
+ }
+
+ if (result->ndata == 0) {
+ result->loss_per_datapoint = opt_ctx->loss_per_datapoint;
+ result->opt_period = opt_ctx->opt_period;
+ } else {
+ GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);
+ GGML_ASSERT(result->opt_period == opt_ctx->opt_period);
+ }
+
+ const int64_t ndata = opt_ctx->outputs->ne[1];
+ GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported");
+ result->ndata += ndata;
+
+ GGML_ASSERT(ggml_is_scalar(opt_ctx->loss));
+ GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32);
+ float loss;
+ ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
+ result->loss.push_back(loss);
+
+ GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
+ std::vector<int32_t> pred(ndata);
+ ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
+ result->pred.insert(result->pred.end(), pred.begin(), pred.end());
+
+ if (!opt_ctx->labels || result->ncorrect < 0) {
+ result->ncorrect = -1;
+ return;
+ }
+
+ GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect));
+ GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64);
+ int64_t ncorrect;
+ ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect));
+ result->ncorrect += ncorrect;
+}
+
+void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
+ ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
+}
+
+void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
+ if (opt_ctx->opt_period == 1) {
+ ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
+ return;
+ }
+
+ const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+ if (opt_i_next == 0) {
+ ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
+ ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
+ } else {
+ ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
+ }
+ opt_ctx->opt_i = opt_i_next;
+}
+
+// ====== High-Level Functions ======
+
+void ggml_opt_epoch(
+ ggml_opt_context_t opt_ctx,
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result_train,
+ ggml_opt_result_t result_eval,
+ int64_t idata_split,
+ ggml_opt_epoch_callback callback_train,
+ ggml_opt_epoch_callback callback_eval) {
+ struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
+ struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
+ struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
+ GGML_ASSERT(data->ne[0] == inputs->ne[0]);
+
+ const int64_t ndata = data->ne[1];
+ const int64_t ndata_batch = inputs->ne[1];
+
+ GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);
+ const int64_t nbatches = ndata/ndata_batch;
+
+ idata_split = idata_split < 0 ? ndata : idata_split;
+ GGML_ASSERT(idata_split % ndata_batch == 0);
+ const int64_t ibatch_split = idata_split / ndata_batch;
+
+ int64_t ibatch = 0;
+ int64_t t_loop_start = ggml_time_us();
+ for (; ibatch < ibatch_split; ++ibatch) {
+ ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
+ ggml_opt_forward_backward(opt_ctx, result_train);
+ if (callback_train) {
+ callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
+ }
+ }
+ t_loop_start = ggml_time_us();
+ for (; ibatch < nbatches; ++ibatch) {
+ ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
+ ggml_opt_forward(opt_ctx, result_eval);
+ if (callback_eval) {
+ callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
+ }
+ }
+}
+
+void ggml_opt_epoch_callback_progress_bar(
+ bool train,
+ ggml_opt_context_t opt_ctx,
+ ggml_opt_dataset_t dataset,
+ ggml_opt_result_t result,
+ int64_t ibatch,
+ int64_t ibatch_max,
+ int64_t t_start_us) {
+ fprintf(stderr, "%s[", train ? "train: " : "val: ");
+
+ constexpr int64_t bar_length = 25;
+ for (int64_t j = 0; j < bar_length; ++j) {
+ const int64_t ibatch_j = ibatch_max * j/bar_length;
+ if (ibatch_j < ibatch) {
+ fprintf(stderr, "=");
+ } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
+ fprintf(stderr, ">");
+ } else {
+ fprintf(stderr, " ");
+ }
+ }
+
+ const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1];
+ const int64_t idata = ibatch*batch_size;
+ const int64_t idata_max = ibatch_max*batch_size;
+
+ double loss;
+ double loss_unc;
+ ggml_opt_result_loss(result, &loss, &loss_unc);
+
+ double accuracy;
+ double accuracy_unc;
+ ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);
+
+ const int64_t t_ibatch_us = ggml_time_us() - t_start_us;
+ int64_t t_ibatch_s = t_ibatch_us / 1000000;
+ const int64_t t_ibatch_h = t_ibatch_s / 3600;
+ t_ibatch_s -= t_ibatch_h * 3600;
+ const int64_t t_ibatch_m = t_ibatch_s / 60;
+ t_ibatch_s -= t_ibatch_m * 60;
+
+ const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;
+ int64_t t_eta_s = t_eta_us / 1000000;
+ const int64_t t_eta_h = t_eta_s / 3600;
+ t_eta_s -= t_eta_h * 3600;
+ const int64_t t_eta_m = t_eta_s / 60;
+ t_eta_s -= t_eta_m * 60;
+
+ fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
+ "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
+ idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
+ t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
+ if (ibatch == ibatch_max) {
+ fprintf(stderr, "\n");
+ }
+ fflush(stderr);
+
+ GGML_UNUSED(dataset);
+}
+
+void ggml_opt_fit(
+ ggml_backend_sched_t backend_sched,
+ ggml_context * ctx_compute,
+ ggml_tensor * inputs,
+ ggml_tensor * outputs,
+ ggml_opt_dataset_t dataset,
+ enum ggml_opt_loss_type loss_type,
+ ggml_opt_get_optimizer_params get_opt_pars,
+ int64_t nepoch,
+ int64_t nbatch_logical,
+ float val_split,
+ bool silent) {
+ ggml_time_init();
+ const int64_t t_start_us = ggml_time_us();
+
+ const int64_t ndata = ggml_opt_dataset_data(dataset)->ne[1];
+ const int64_t nbatch_physical = inputs->ne[1];
+ GGML_ASSERT(ndata % nbatch_logical == 0);
+ GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
+
+ const int64_t opt_period = nbatch_logical / nbatch_physical;
+ const int64_t nbatches_logical = ndata / nbatch_logical;
+
+ GGML_ASSERT(val_split >= 0.0f);
+ GGML_ASSERT(val_split < 1.0f);
+ const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical)
+ const int64_t idata_split = ibatch_split * nbatch_physical;
+
+ int64_t epoch = 1;
+
+ ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
+ params.opt_period = opt_period;
+ params.get_opt_pars = get_opt_pars;
+ params.get_opt_pars_ud = &epoch;
+ ggml_opt_context_t opt_ctx = ggml_opt_init(params);
+
+ // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
+ if (nbatch_logical < ndata) {
+ ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).
+ }
+
+ ggml_opt_result_t result_train = ggml_opt_result_init();
+ ggml_opt_result_t result_val = ggml_opt_result_init();
+
+ ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar;
+
+ for (; epoch <= nepoch; ++epoch) {
+ if (nbatch_logical < idata_split) {
+ ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
+ }
+
+ ggml_opt_result_reset(result_train);
+ ggml_opt_result_reset(result_val);
+
+ if (!silent) {
+ fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch);
+ }
+ ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);
+ if (!silent) {
+ fprintf(stderr, "\n");
+ }
+ }
+
+ if (!silent) {
+ int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000;
+ const int64_t t_total_h = t_total_s / 3600;
+ t_total_s -= t_total_h * 3600;
+ const int64_t t_total_m = t_total_s / 60;
+ t_total_s -= t_total_m * 60;
+ fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
+ }
+
+ ggml_opt_free(opt_ctx);
+ ggml_opt_result_free(result_train);
+ ggml_opt_result_free(result_val);
+}
/*.op =*/ GGML_OP_NONE,
/*.op_params =*/ { 0 },
/*.flags =*/ 0,
- /*.grad =*/ NULL,
/*.src =*/ { NULL },
/*.view_src =*/ view_src,
/*.view_offs =*/ view_offs,
/*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
/*.name =*/ { 0 },
/*.extra =*/ NULL,
- ///*.padding =*/ { 0 },
+ /*.padding =*/ { 0 },
};
#ifdef __clang__
GGML_ASSERT(mask);
}
- bool is_node = false;
-
// permute(0, 2, 1, 3)
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
float params[] = { scale, max_bias, logit_softcap };
ggml_set_op_params(result, params, sizeof(params));
- result->op = GGML_OP_FLASH_ATTN_EXT;
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->op = GGML_OP_FLASH_ATTN_EXT;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
GGML_ASSERT(ne2 % kvne2 == 0);
- bool is_node = false;
-
- if (q->grad || k->grad || v->grad) {
- // when using this operation (in backwards pass) these grads are set.
- // we don't want to create (big) grad of our result, so is_node is false.
- is_node = false;
- }
-
// store gradients of q, k and v as continuous tensors concatenated in result.
// note: v and gradv are actually transposed, i.e. v->ne[0] != D.
const int64_t elem_q = ggml_nelements(q);
int32_t masked_i = masked ? 1 : 0;
ggml_set_op_params(result, &masked_i, sizeof(masked_i));
- result->op = GGML_OP_FLASH_ATTN_BACK;
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->op = GGML_OP_FLASH_ATTN_BACK;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * grad,
- float alpha,
- float beta1,
- float beta2,
- float eps,
- float wd) {
+ struct ggml_tensor * m,
+ struct ggml_tensor * v,
+ struct ggml_tensor * adamw_params) {
GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
GGML_ASSERT(ggml_are_same_shape(a, grad));
- GGML_ASSERT(alpha > 0.0f);
- GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
- GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
- GGML_ASSERT(eps >= 0.0f);
- GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
+ GGML_ASSERT(ggml_are_same_shape(a, m));
+ GGML_ASSERT(ggml_are_same_shape(a, v));
+ GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_nelements(adamw_params) == 7);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
- const int64_t iter = 1;
- memcpy(&result->op_params[0], &iter, sizeof(int64_t));
- ggml_set_op_params_f32(result, 2, alpha);
- ggml_set_op_params_f32(result, 3, beta1);
- ggml_set_op_params_f32(result, 4, beta2);
- ggml_set_op_params_f32(result, 5, eps);
- ggml_set_op_params_f32(result, 6, wd);
-
result->op = GGML_OP_OPT_STEP_ADAMW;
result->src[0] = a;
result->src[1] = grad;
- result->src[2] = ggml_dup_tensor(ctx, grad);
- result->src[3] = ggml_dup_tensor(ctx, grad);
+ result->src[2] = m;
+ result->src[3] = v;
+ result->src[4] = adamw_params;
return result;
}
GGML_FREE(map);
}
-// gradient checkpointing
-
-static struct ggml_tensor * ggml_recompute_graph_node(
- struct ggml_context * ctx,
- struct ggml_cgraph * graph,
- struct hash_map * replacements,
- struct ggml_tensor * node) {
-
- if (node == NULL) {
- return NULL;
- }
-
- if (node->flags & GGML_TENSOR_FLAG_PARAM) {
- return node;
- }
-
- if (!ggml_hash_contains(&graph->visited_hash_set, node)) {
- return node;
- }
-
- int count_children = 0;
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
- if (node->src[k]) {
- ++count_children;
- }
- }
-
- if (count_children == 0) {
- return node;
- }
-
- size_t i = ggml_hash_find(&replacements->set, node);
- GGML_ASSERT(i != GGML_HASHSET_FULL); // assert that not full
- if (replacements->set.keys[i] == node) {
- return replacements->vals[i];
- }
-
- struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, GGML_MAX_DIMS, node->ne);
-
- // insert clone into replacements
- GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
- replacements->set.keys[i] = node;
- replacements->vals[i] = clone;
-
- clone->op = node->op;
- clone->grad = node->grad;
- clone->flags = node->flags;
- clone->extra = node->extra;
- for (int k = 0; k < GGML_MAX_DIMS; ++k) {
- clone->nb[k] = node->nb[k];
- }
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
- clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
- }
- if (node->view_src != NULL) {
- clone->data = (node->view_src->data == NULL)
- ? NULL // view_src not yet allocated
- : (char *) node->view_src->data // view_src already allocated
- + node->view_offs;
- clone->view_src = node->view_src;
- clone->view_offs = node->view_offs;
- }
-
- GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
- GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
- memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
- ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
-
- return clone;
-}
-
-void ggml_build_backward_gradient_checkpointing(
- struct ggml_context * ctx,
- struct ggml_cgraph * gf,
- struct ggml_cgraph * gb,
- struct ggml_cgraph * gb_tmp,
- struct ggml_tensor * * checkpoints,
- int n_checkpoints) {
- ggml_graph_cpy(gf, gb_tmp);
- ggml_build_backward_expand(ctx, gf, gb_tmp, false);
-
- if (n_checkpoints <= 0) {
- ggml_graph_cpy(gb_tmp, gb);
- return;
- }
-
- struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints);
-
- // insert checkpoints in replacements
- for (int i = 0; i < n_checkpoints; ++i) {
- size_t k = ggml_hash_find(&replacements->set, checkpoints[i]);
- GGML_ASSERT(k != GGML_HASHSET_FULL); // assert that not full
- GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite
- replacements->set.keys[k] = checkpoints[i];
- replacements->vals[k] = checkpoints[i];
- }
-
- ggml_graph_cpy(gf, gb);
- // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
- // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
- // by recomputing them from checkpoints
- for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
- struct ggml_tensor * node = gb_tmp->nodes[i];
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
- // insert new tensors recomputing src, reusing already made replacements,
- // remember replacements: remember new tensors with mapping from corresponding gf nodes
- // recurse for input tensors,
- // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
- node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
- }
- // insert rewritten backward node with replacements made into resulting backward graph gb
- ggml_build_forward_expand(gb, node);
- }
-
- ggml_hash_map_free(replacements);
-}
-
// utility functions to change gradients
// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
// else if a is in zero_table, replace a
// else, just add/subtract/etc. the gradients
-static struct ggml_tensor * ggml_add_or_set(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- struct ggml_hash_set * zero_table,
- struct ggml_hash_set * acc_table) {
- if (ggml_hash_contains(acc_table, a)) {
- struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
- const size_t insert_result = ggml_hash_insert(acc_table, ret);
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
- return ret;
- }
- if (ggml_hash_contains(zero_table, a)) {
- return b;
+static void ggml_add_or_set(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * cgraph,
+ size_t isrc,
+ struct ggml_tensor * tensor) {
+ if (cgraph->grads[isrc]) {
+ cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+ } else {
+ cgraph->grads[isrc] = tensor;
}
- return ggml_add_impl(ctx, a, b, false);
+ ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
}
-static struct ggml_tensor * ggml_acc_or_set(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- const size_t nb1,
- const size_t nb2,
- const size_t nb3,
- const size_t offset,
- struct ggml_hash_set * zero_table,
- struct ggml_hash_set * acc_table) {
- if (ggml_hash_contains(acc_table, a)) {
- struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
- const size_t insert_result = ggml_hash_insert(acc_table, ret);
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
- return ret;
- }
- if (ggml_hash_contains(zero_table, a)) {
- struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
- return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+static void ggml_acc_or_set(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * cgraph,
+ size_t isrc,
+ struct ggml_tensor * src,
+ struct ggml_tensor * tensor,
+ const size_t nb1,
+ const size_t nb2,
+ const size_t nb3,
+ const size_t offset) {
+ if (cgraph->grads[isrc]) {
+ cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
+ } else {
+ struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
+ cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
}
- return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+ ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
}
-static struct ggml_tensor * ggml_add1_or_set(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- struct ggml_hash_set * zero_table,
- struct ggml_hash_set * acc_table) {
- if (ggml_hash_contains(acc_table, a)) {
- struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
- const size_t insert_result = ggml_hash_insert(acc_table, ret);
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
- return ret;
- }
- if (ggml_hash_contains(zero_table, a)) {
- return ggml_repeat(ctx, b, a);
+static void ggml_add1_or_set(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * cgraph,
+ size_t isrc,
+ struct ggml_tensor * src,
+ struct ggml_tensor * tensor) {
+ if (cgraph->grads[isrc]) {
+ cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+ } else {
+ cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
}
- return ggml_add1_impl(ctx, a, b, false);
+ ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
}
-static struct ggml_tensor * ggml_sub_or_set(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b,
- struct ggml_hash_set * zero_table,
- struct ggml_hash_set * acc_table) {
- if (ggml_hash_contains(acc_table, a)) {
- struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
- const size_t insert_result = ggml_hash_insert(acc_table, ret);
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
- return ret;
- }
- if (ggml_hash_contains(zero_table, a)) {
- return ggml_neg(ctx, b);
+static void ggml_sub_or_set(
+ struct ggml_context * ctx,
+ struct ggml_cgraph * cgraph,
+ size_t isrc,
+ struct ggml_tensor * tensor) {
+ if (cgraph->grads[isrc]) {
+ cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+ } else {
+ cgraph->grads[isrc] = ggml_neg(ctx, tensor);
}
- return ggml_sub_impl(ctx, a, b, false);
+ ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
}
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
+static void ggml_compute_backward(
+ struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
+ struct ggml_tensor * tensor = cgraph->nodes[i];
+ struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
+
+ if (!grad) {
+ return;
+ }
+
struct ggml_tensor * src0 = tensor->src[0];
struct ggml_tensor * src1 = tensor->src[1];
struct ggml_tensor * src2 = tensor->src[2];
+ struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
+ const size_t isrc0 = ggml_hash_find(hash_set, src0);
+ const size_t isrc1 = ggml_hash_find(hash_set, src1);
+ const size_t isrc2 = ggml_hash_find(hash_set, src2);
+ const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
+ const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
+ const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
switch (tensor->op) {
- case GGML_OP_DUP:
- {
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
- }
- } break;
- case GGML_OP_ADD:
- {
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
- }
- if (src1->grad) {
- if (ggml_are_same_shape(src0, src1)) {
- src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
- } else {
- src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
- }
- }
- } break;
- case GGML_OP_ADD1:
- {
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
- }
- if (src1->grad) {
- src1->grad = ggml_add_or_set(ctx,
- src1->grad,
- ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_ACC:
- {
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
- }
- if (src1->grad) {
- const size_t nb1 = ((int32_t *) tensor->op_params)[0];
- const size_t nb2 = ((int32_t *) tensor->op_params)[1];
- const size_t nb3 = ((int32_t *) tensor->op_params)[2];
- const size_t offset = ((int32_t *) tensor->op_params)[3];
-
- struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
- tensor->grad,
- src1->grad->ne[0],
- src1->grad->ne[1],
- src1->grad->ne[2],
- src1->grad->ne[3],
- nb1, nb2, nb3, offset);
-
- src1->grad =
- ggml_add_or_set(ctx,
- src1->grad,
- ggml_reshape(ctx,
- ggml_cont(ctx, tensor_grad_view),
- src1->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_SUB:
- {
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
- }
- if (src1->grad) {
- src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
- }
- } break;
- case GGML_OP_MUL:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_mul(ctx, src1, tensor->grad),
- zero_table, acc_table);
- }
- if (src1->grad) {
- src1->grad =
- ggml_add_or_set(ctx,
- src1->grad,
- ggml_mul(ctx, src0, tensor->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_DIV:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_div(ctx, tensor->grad, src1),
- zero_table, acc_table);
- }
- if (src1->grad) {
- src1->grad =
- ggml_sub_or_set(ctx,
- src1->grad,
- ggml_mul(ctx,
- tensor->grad,
- ggml_div(ctx, tensor, src1)),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_SQR:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_scale(ctx,
- ggml_mul(ctx, src0, tensor->grad),
- 2.0f),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_SQRT:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_scale(ctx,
- ggml_div(ctx,
- tensor->grad,
- tensor),
- 0.5f),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_LOG:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_div(ctx,
- tensor->grad,
- src0),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_SIN:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_mul(ctx,
- tensor->grad,
- ggml_cos(ctx, src0)),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_COS:
- {
- if (src0->grad) {
- src0->grad =
- ggml_sub_or_set(ctx,
- src0->grad,
- ggml_mul(ctx,
- tensor->grad,
- ggml_sin(ctx, src0)),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_SUM:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add1_or_set(ctx,
- src0->grad,
- tensor->grad,
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_SUM_ROWS:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_repeat(ctx,
- tensor->grad,
- src0->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_MEAN:
- case GGML_OP_ARGMAX:
- case GGML_OP_COUNT_EQUAL:
- {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- case GGML_OP_REPEAT:
- {
- // necessary for llama
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_repeat_back(ctx, tensor->grad, src0->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_REPEAT_BACK:
- {
- if (src0->grad) {
- // TODO: test this
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_repeat(ctx, tensor->grad, src0->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_CONCAT:
- {
- GGML_ABORT("fatal error"); // TODO: implement
- }
- case GGML_OP_SILU_BACK:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ case GGML_OP_DUP: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
}
- case GGML_OP_NORM:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_ADD: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
}
- case GGML_OP_RMS_NORM:
- {
- // necessary for llama
- if (src0->grad) {
- float eps;
- memcpy(&eps, tensor->op_params, sizeof(float));
-
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
- zero_table, acc_table);
+ if (src1_needs_grads) {
+ struct ggml_tensor * tmp = grad;
+ if (!ggml_are_same_shape(src0, src1)) {
+ tmp = ggml_repeat_back(ctx, tmp, src1);
}
- } break;
- case GGML_OP_RMS_NORM_BACK:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ ggml_add_or_set(ctx, cgraph, isrc1, tmp);
}
- case GGML_OP_GROUP_NORM:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_ADD1: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
}
- case GGML_OP_MUL_MAT:
- {
- // https://cs231n.github.io/optimization-2/#staged
- // # forward pass
- // s0 = np.random.randn(5, 10)
- // s1 = np.random.randn(10, 3)
- // t = s0.dot(s1)
-
- // # now suppose we had the gradient on t from above in the circuit
- // dt = np.random.randn(*t.shape) # same shape as t
- // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
- // ds1 = t.T.dot(dt)
-
- // tensor.shape [m,p,qq,rr]
- // src0.shape [n,m,q1,r1]
- // src1.shape [n,p,qq,rr]
-
- // necessary for llama
- if (src0->grad) {
- struct ggml_tensor * s1_tg =
- ggml_out_prod(ctx, // [n,m,qq,rr]
- src1, // [n,p,qq,rr]
- tensor->grad); // [m,p,qq,rr]
- const int64_t qq = s1_tg->ne[2];
- const int64_t rr = s1_tg->ne[3];
- const int64_t q1 = src0->ne[2];
- const int64_t r1 = src0->ne[3];
- const bool ne2_broadcasted = qq > q1;
- const bool ne3_broadcasted = rr > r1;
- if (ne2_broadcasted || ne3_broadcasted) {
- // sum broadcast repetitions of s1_tg into shape of src0
- s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
- }
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad, // [n,m,q1,r1]
- s1_tg, // [n,m,q1,r1]
- zero_table, acc_table);
- }
- if (src1->grad) {
- src1->grad =
- ggml_add_or_set(ctx,
- src1->grad, // [n,p,qq,rr]
- // ggml_mul_mat(ctx, // [n,p,qq,rr]
- // ggml_cont(ctx, // [m,n,q1,r1]
- // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
- // tensor->grad), // [m,p,qq,rr]
-
- // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
- // // avoid transpose of src0, rather transpose smaller tensor->grad
- // // and then use ggml_out_prod
- ggml_out_prod(ctx, // [n,p,qq,rr]
- src0, // [n,m,q1,r1]
- ggml_transpose(ctx, // [p,m,qq,rr]
- tensor->grad)), // [m,p,qq,rr]
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_MUL_MAT_ID:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ if (src1_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
}
- case GGML_OP_OUT_PROD:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_ACC: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
}
- case GGML_OP_SCALE:
- {
- // necessary for llama
- if (src0->grad) {
- float s;
- memcpy(&s, tensor->op_params, sizeof(float));
-
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_scale_impl(ctx, tensor->grad, s, false),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_SET:
- {
- const size_t nb1 = ((int32_t *) tensor->op_params)[0];
- const size_t nb2 = ((int32_t *) tensor->op_params)[1];
- const size_t nb3 = ((int32_t *) tensor->op_params)[2];
- const size_t offset = ((int32_t *) tensor->op_params)[3];
-
- struct ggml_tensor * tensor_grad_view = NULL;
-
- if (src0->grad || src1->grad) {
- GGML_ASSERT(src0->type == tensor->type);
- GGML_ASSERT(tensor->grad->type == tensor->type);
- GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
-
- tensor_grad_view = ggml_view_4d(ctx,
- tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
- nb1, nb2, nb3, offset);
- }
+ if (src1_needs_grads) {
+ const size_t nb1 = ((int32_t *) tensor->op_params)[0];
+ const size_t nb2 = ((int32_t *) tensor->op_params)[1];
+ const size_t nb3 = ((int32_t *) tensor->op_params)[2];
+ const size_t offset = ((int32_t *) tensor->op_params)[3];
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_acc_impl(ctx,
- tensor->grad,
- ggml_neg(ctx, tensor_grad_view),
- nb1, nb2, nb3, offset, false),
- zero_table, acc_table);
- }
+ struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
+ grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+ nb1, nb2, nb3, offset);
- if (src1->grad) {
- src1->grad =
- ggml_add_or_set(ctx,
- src1->grad,
- ggml_reshape(ctx,
- ggml_cont(ctx, tensor_grad_view),
- src1->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_CPY:
- {
- // necessary for llama
- // cpy overwrites value of src1 by src0 and returns view(src1)
- // the overwriting is mathematically equivalent to:
- // tensor = src0 * 1 + src1 * 0
- if (src0->grad) {
- // dsrc0 = dtensor * 1
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
- }
- if (src1->grad) {
- // dsrc1 = dtensor * 0 -> noop
- }
- } break;
- case GGML_OP_CONT:
- {
- // same as cpy
- if (src0->grad) {
- GGML_ASSERT(ggml_is_contiguous(src0->grad));
- GGML_ASSERT(ggml_is_contiguous(tensor->grad));
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
- }
- } break;
- case GGML_OP_RESHAPE:
- {
- // necessary for llama
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx, src0->grad,
- ggml_reshape(ctx,
- ggml_is_contiguous(tensor->grad)
- ? tensor->grad
- : ggml_cont(ctx, tensor->grad),
- src0->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_VIEW:
- {
- // necessary for llama
- if (src0->grad) {
- size_t offset;
-
- memcpy(&offset, tensor->op_params, sizeof(offset));
-
- size_t nb1 = tensor->nb[1];
- size_t nb2 = tensor->nb[2];
- size_t nb3 = tensor->nb[3];
-
- if (src0->type != src0->grad->type) {
- // gradient is typically F32, but src0 could be other type
- size_t ng = ggml_element_size(src0->grad);
- size_t n0 = ggml_element_size(src0);
- GGML_ASSERT(offset % n0 == 0);
- GGML_ASSERT(nb1 % n0 == 0);
- GGML_ASSERT(nb2 % n0 == 0);
- GGML_ASSERT(nb3 % n0 == 0);
- offset = (offset / n0) * ng;
- nb1 = (nb1 / n0) * ng;
- nb2 = (nb2 / n0) * ng;
- nb3 = (nb3 / n0) * ng;
- }
-
- src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
- }
- } break;
- case GGML_OP_PERMUTE:
- {
- // necessary for llama
- if (src0->grad) {
- int32_t * axes = (int32_t *) tensor->op_params;
- int axis0 = axes[0] & 0x3;
- int axis1 = axes[1] & 0x3;
- int axis2 = axes[2] & 0x3;
- int axis3 = axes[3] & 0x3;
- int axes_backward[4] = {0,0,0,0};
- axes_backward[axis0] = 0;
- axes_backward[axis1] = 1;
- axes_backward[axis2] = 2;
- axes_backward[axis3] = 3;
- src0->grad =
- ggml_add_or_set(ctx, src0->grad,
- ggml_permute(ctx,
- tensor->grad,
- axes_backward[0],
- axes_backward[1],
- axes_backward[2],
- axes_backward[3]),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_TRANSPOSE:
- {
- // necessary for llama
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx, src0->grad,
- ggml_transpose(ctx, tensor->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_GET_ROWS:
- {
- // necessary for llama (only for tokenizer)
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx, src0->grad,
- // last ggml_get_rows_back argument src0->grad is only
- // necessary to setup correct output shape
- ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
- zero_table, acc_table);
- }
- if (src1->grad) {
- // noop
- }
- } break;
- case GGML_OP_GET_ROWS_BACK:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
}
- case GGML_OP_DIAG:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_SUB: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
}
- case GGML_OP_DIAG_MASK_INF:
- {
- // necessary for llama
- if (src0->grad) {
- const int n_past = ((int32_t *) tensor->op_params)[0];
- src0->grad =
- ggml_add_or_set(ctx, src0->grad,
- /* ggml_diag_mask_inf_impl() shouldn't be here */
- /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
- ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_DIAG_MASK_ZERO:
- {
- // necessary for llama
- if (src0->grad) {
- const int n_past = ((int32_t *) tensor->op_params)[0];
- src0->grad =
- ggml_add_or_set(ctx, src0->grad,
- ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_SOFT_MAX:
- {
- // necessary for llama
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx, src0->grad,
- ggml_soft_max_back(ctx, tensor->grad, tensor),
- zero_table, acc_table);
- }
- GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented");
- } break;
- case GGML_OP_SOFT_MAX_BACK:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ if (src1_needs_grads) {
+ ggml_sub_or_set(ctx, cgraph, isrc1, grad);
}
- case GGML_OP_ROPE:
- {
- // necessary for llama
- if (src0->grad) {
- //const int n_past = ((int32_t *) tensor->op_params)[0];
- const int n_dims = ((int32_t *) tensor->op_params)[1];
- const int mode = ((int32_t *) tensor->op_params)[2];
- //const int n_ctx = ((int32_t *) tensor->op_params)[3];
- const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-
- memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
-
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_rope_back(ctx,
- tensor->grad,
- src1,
- src2,
- n_dims,
- mode,
- n_ctx_orig,
- freq_base,
- freq_scale,
- ext_factor,
- attn_factor,
- beta_fast,
- beta_slow),
- zero_table, acc_table);
- }
- GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented");
- } break;
- case GGML_OP_ROPE_BACK:
- {
- if (src0->grad) {
- //const int n_past = ((int32_t *) tensor->op_params)[0];
- const int n_dims = ((int32_t *) tensor->op_params)[1];
- const int mode = ((int32_t *) tensor->op_params)[2];
- //const int n_ctx = ((int32_t *) tensor->op_params)[3];
- const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-
- memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
- memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
- memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
- memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
- memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
-
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_rope_impl(ctx,
- tensor->grad,
- src1,
- src2,
- n_dims,
- mode,
- n_ctx_orig,
- freq_base,
- freq_scale,
- ext_factor,
- attn_factor,
- beta_fast,
- beta_slow,
- false),
- zero_table, acc_table);
+ } break;
+ case GGML_OP_MUL: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+ }
+ if (src1_needs_grads) {
+ struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
+ if (!ggml_are_same_shape(src0, src1)) {
+ tmp = ggml_repeat_back(ctx, tmp, src1);
}
- } break;
- case GGML_OP_CLAMP:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ ggml_add_or_set(ctx, cgraph, isrc1, tmp);
}
- case GGML_OP_CONV_TRANSPOSE_1D:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_DIV: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1));
}
- case GGML_OP_IM2COL:
- {
- if (src1->grad) {
- const int32_t s0 = ggml_get_op_params_i32(tensor, 0);
- const int32_t s1 = ggml_get_op_params_i32(tensor, 1);
- const int32_t p0 = ggml_get_op_params_i32(tensor, 2);
- const int32_t p1 = ggml_get_op_params_i32(tensor, 3);
- const int32_t d0 = ggml_get_op_params_i32(tensor, 4);
- const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
- const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
-
- src1->grad = ggml_add_or_set(ctx,
- src1->grad,
- ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_IM2COL_BACK:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ if (src1_needs_grads) {
+ ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1)));
}
- case GGML_OP_CONV_TRANSPOSE_2D:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_SQR: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f));
}
- case GGML_OP_POOL_1D:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_SQRT: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f));
}
- case GGML_OP_POOL_2D:
- {
- if (src0->grad) {
- const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
- const int32_t k0 = ggml_get_op_params_i32(tensor, 1);
- const int32_t k1 = ggml_get_op_params_i32(tensor, 2);
- const int32_t s0 = ggml_get_op_params_i32(tensor, 3);
- const int32_t s1 = ggml_get_op_params_i32(tensor, 4);
- const int32_t p0 = ggml_get_op_params_i32(tensor, 5);
- const int32_t p1 = ggml_get_op_params_i32(tensor, 6);
-
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
- zero_table, acc_table);
- }
- } break;
- case GGML_OP_POOL_2D_BACK:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_LOG: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0));
}
- case GGML_OP_UPSCALE:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_SIN: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0)));
}
- case GGML_OP_PAD:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_COS: {
+ if (src0_needs_grads) {
+ ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0)));
}
- case GGML_OP_ARANGE:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_SUM: {
+ if (src0_needs_grads) {
+ ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
}
- case GGML_OP_TIMESTEP_EMBEDDING:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_SUM_ROWS: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
}
- case GGML_OP_ARGSORT:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_MEAN: {
+ if (src0_needs_grads) {
+ ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
}
- case GGML_OP_LEAKY_RELU:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_REPEAT: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
}
- case GGML_OP_FLASH_ATTN_EXT:
- {
- GGML_ABORT("FA backward pass not adapted after rework");
- struct ggml_tensor * flash_grad = NULL;
- if (src0->grad || src1->grad || tensor->src[2]->grad) {
- int32_t t = ggml_get_op_params_i32(tensor, 0);
- GGML_ASSERT(t == 0 || t == 1);
- bool masked = t != 0;
- flash_grad =
- ggml_flash_attn_back(ctx,
- src0,
- src1,
- tensor->src[2],
- tensor->grad,
- masked);
+ } break;
+ case GGML_OP_REPEAT_BACK: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
+ }
+ } break;
+ case GGML_OP_RMS_NORM: {
+ if (src0_needs_grads) {
+ float eps;
+ memcpy(&eps, tensor->op_params, sizeof(float));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
+ }
+ } break;
+ case GGML_OP_MUL_MAT: {
+ // https://cs231n.github.io/optimization-2/#staged
+ // # forward pass
+ // s0 = np.random.randn(5, 10)
+ // s1 = np.random.randn(10, 3)
+ // t = s0.dot(s1)
+
+ // # now suppose we had the gradient on t from above in the circuit
+ // dt = np.random.randn(*t.shape) # same shape as t
+ // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
+ // ds1 = t.T.dot(dt)
+
+ // tensor.shape [m,p,qq,rr]
+ // src0.shape [n,m,q1,r1]
+ // src1.shape [n,p,qq,rr]
+
+ if (src0_needs_grads) {
+ struct ggml_tensor * s1_tg =
+ ggml_out_prod(ctx, // [n,m,qq,rr]
+ src1, // [n,p,qq,rr]
+ grad); // [m,p,qq,rr]
+ const int64_t qq = s1_tg->ne[2];
+ const int64_t rr = s1_tg->ne[3];
+ const int64_t q1 = src0->ne[2];
+ const int64_t r1 = src0->ne[3];
+ const bool ne2_broadcasted = qq > q1;
+ const bool ne3_broadcasted = rr > r1;
+ if (ne2_broadcasted || ne3_broadcasted) {
+ // sum broadcast repetitions of s1_tg into shape of src0
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
}
+ ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+ }
+ if (src1_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc1,
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
+ // ggml_cont(ctx, // [m,n,q1,r1]
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+ // grad), // [m,p,qq,rr]
+
+ // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+ // avoid transpose of src0, rather transpose smaller tensor->grad
+ // and then use ggml_out_prod
+ ggml_out_prod(ctx, // [n,p,qq,rr]
+ src0, // [n,m,q1,r1]
+ ggml_transpose(ctx, // [p,m,qq,rr]
+ grad))); // [m,p,qq,rr]
+ }
+ } break;
+ case GGML_OP_SCALE: {
+ if (src0_needs_grads) {
+ float s;
+ memcpy(&s, tensor->op_params, sizeof(float));
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
+ }
+ } break;
+ case GGML_OP_SET: {
+ const size_t nb1 = ((const int32_t *) tensor->op_params)[0];
+ const size_t nb2 = ((const int32_t *) tensor->op_params)[1];
+ const size_t nb3 = ((const int32_t *) tensor->op_params)[2];
+ const size_t offset = ((const int32_t *) tensor->op_params)[3];
+
+ struct ggml_tensor * tensor_grad_view = NULL;
+
+ if (src0_needs_grads || src1_needs_grads) {
+ GGML_ASSERT(src0->type == tensor->type);
+ GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type);
+ GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);
+
+ tensor_grad_view = ggml_view_4d(ctx,
+ grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+ nb1, nb2, nb3, offset);
+ }
- const int64_t elem_q = ggml_nelements(src0);
- const int64_t elem_k = ggml_nelements(src1);
- const int64_t elem_v = ggml_nelements(src2);
-
- enum ggml_type result_type = flash_grad->type;
- GGML_ASSERT(ggml_blck_size(result_type) == 1);
- const size_t tsize = ggml_type_size(result_type);
-
- const size_t offs_q = 0;
- const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
- const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
-
- if (src0->grad) {
- struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
- struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- grad_q,
- zero_table, acc_table);
- }
- if (src1->grad) {
- struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
- struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
- src1->grad = ggml_add_or_set(ctx,
- src1->grad,
- grad_k,
- zero_table, acc_table);
- }
- if (src2->grad) {
- struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
- struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
- src2->grad = ggml_add_or_set(ctx,
- src2->grad,
- grad_v,
- zero_table, acc_table);
+ if (src0_needs_grads) {
+ struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view);
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));
+ }
+
+ if (src1_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
+ }
+ } break;
+ case GGML_OP_CPY: {
+ // cpy overwrites value of src1 by src0 and returns view(src1)
+ // the overwriting is mathematically equivalent to:
+ // tensor = src0 * 1 + src1 * 0
+ if (src0_needs_grads) {
+ // dsrc0 = dtensor * 1
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
+ }
+ if (src1_needs_grads) {
+ // dsrc1 = dtensor * 0 -> noop
+ }
+ } break;
+ case GGML_OP_CONT: {
+ // same as cpy
+ if (src0_needs_grads) {
+ GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
+ GGML_ASSERT(ggml_is_contiguous(grad));
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
+ }
+ } break;
+ case GGML_OP_RESHAPE: {
+ if (src0_needs_grads) {
+ struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad);
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0));
+ }
+ } break;
+ case GGML_OP_VIEW: {
+ if (src0_needs_grads) {
+ size_t offset;
+
+ memcpy(&offset, tensor->op_params, sizeof(offset));
+
+ size_t nb1 = tensor->nb[1];
+ size_t nb2 = tensor->nb[2];
+ size_t nb3 = tensor->nb[3];
+
+ if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {
+ // gradient is typically F32, but src0 could be other type
+ size_t ng = ggml_element_size(cgraph->grads[isrc0]);
+ size_t n0 = ggml_element_size(src0);
+ GGML_ASSERT(offset % n0 == 0);
+ GGML_ASSERT(nb1 % n0 == 0);
+ GGML_ASSERT(nb2 % n0 == 0);
+ GGML_ASSERT(nb3 % n0 == 0);
+ offset = (offset / n0) * ng;
+ nb1 = (nb1 / n0) * ng;
+ nb2 = (nb2 / n0) * ng;
+ nb3 = (nb3 / n0) * ng;
}
- } break;
- case GGML_OP_FLASH_ATTN_BACK:
- {
- GGML_ABORT("fatal error"); // not supported
+
+ ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
}
- case GGML_OP_SSM_CONV:
- case GGML_OP_SSM_SCAN:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
+ } break;
+ case GGML_OP_PERMUTE: {
+ if (src0_needs_grads) {
+ const int32_t * axes = (const int32_t *) tensor->op_params;
+ const int axis0 = axes[0] & 0x3;
+ const int axis1 = axes[1] & 0x3;
+ const int axis2 = axes[2] & 0x3;
+ const int axis3 = axes[3] & 0x3;
+ int axb[4] = {0,0,0,0}; // axes backward
+ axb[axis0] = 0;
+ axb[axis1] = 1;
+ axb[axis2] = 2;
+ axb[axis3] = 3;
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));
}
+ } break;
+ case GGML_OP_TRANSPOSE: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad));
+ }
+ } break;
+ case GGML_OP_GET_ROWS: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
+ }
+ if (src1_needs_grads) {
+ // noop
+ }
+ } break;
+ case GGML_OP_DIAG_MASK_INF: {
+ if (src0_needs_grads) {
+ /* ggml_diag_mask_inf_impl() shouldn't be here */
+ /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
+ const int n_past = ((const int32_t *) tensor->op_params)[0];
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
+ }
+ } break;
+ case GGML_OP_DIAG_MASK_ZERO: {
+ if (src0_needs_grads) {
+ const int n_past = ((const int32_t *) tensor->op_params)[0];
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
+ }
+ } break;
+ case GGML_OP_SOFT_MAX: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
+ }
+ GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
+ } break;
+ case GGML_OP_ROPE: {
+ if (src0_needs_grads) {
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
+ const int n_dims = ((const int32_t *) tensor->op_params)[1];
+ const int mode = ((const int32_t *) tensor->op_params)[2];
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
+ const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+ memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
+
+ ggml_add_or_set(ctx, cgraph, isrc0,
+ ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
+ freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
+ }
+ GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
+ } break;
+ case GGML_OP_IM2COL: {
+ if (src1_needs_grads) {
+ const int32_t s0 = ggml_get_op_params_i32(tensor, 0);
+ const int32_t s1 = ggml_get_op_params_i32(tensor, 1);
+ const int32_t p0 = ggml_get_op_params_i32(tensor, 2);
+ const int32_t p1 = ggml_get_op_params_i32(tensor, 3);
+ const int32_t d0 = ggml_get_op_params_i32(tensor, 4);
+ const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
+ const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
+
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
+ }
+ } break;
+ case GGML_OP_POOL_2D: {
+ if (src0_needs_grads) {
+ const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
+ const int32_t k0 = ggml_get_op_params_i32(tensor, 1);
+ const int32_t k1 = ggml_get_op_params_i32(tensor, 2);
+ const int32_t s0 = ggml_get_op_params_i32(tensor, 3);
+ const int32_t s1 = ggml_get_op_params_i32(tensor, 4);
+ const int32_t p0 = ggml_get_op_params_i32(tensor, 5);
+ const int32_t p1 = ggml_get_op_params_i32(tensor, 6);
+
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));
+ }
+ } break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
- case GGML_OP_UNARY:
- {
- switch (ggml_get_unary_op(tensor)) {
- case GGML_UNARY_OP_ABS:
- {
- if (src0->grad) {
- src0->grad =
- ggml_add_or_set(ctx,
- src0->grad,
- ggml_mul(ctx,
- ggml_sgn(ctx, src0),
- tensor->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_UNARY_OP_SGN:
- {
- if (src0->grad) {
- // noop
- }
- } break;
- case GGML_UNARY_OP_NEG:
- {
- if (src0->grad) {
- src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
- }
- } break;
- case GGML_UNARY_OP_STEP:
- {
- if (src0->grad) {
- // noop
- }
- } break;
- case GGML_UNARY_OP_TANH:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
- }
- case GGML_UNARY_OP_ELU:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
- }
- case GGML_UNARY_OP_RELU:
- {
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_mul(ctx,
- ggml_step(ctx, src0),
- tensor->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_UNARY_OP_SIGMOID:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
- }
- case GGML_UNARY_OP_GELU:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
- }
- case GGML_UNARY_OP_GELU_QUICK:
- {
- GGML_ABORT("fatal error"); // TODO: not implemented
- }
- case GGML_UNARY_OP_SILU:
- {
- // necessary for llama
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_silu_back(ctx, src0, tensor->grad),
- zero_table, acc_table);
- }
- } break;
- case GGML_UNARY_OP_EXP:
- {
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_mul(ctx, tensor, tensor->grad),
- zero_table, acc_table);
- }
- } break;
- default:
- GGML_ABORT("fatal error");
- }
- } break;
- case GGML_OP_GET_REL_POS:
- case GGML_OP_ADD_REL_POS:
- case GGML_OP_RWKV_WKV6:
- case GGML_OP_MAP_UNARY:
- case GGML_OP_MAP_BINARY:
- case GGML_OP_MAP_CUSTOM1_F32:
- case GGML_OP_MAP_CUSTOM2_F32:
- case GGML_OP_MAP_CUSTOM3_F32:
- case GGML_OP_MAP_CUSTOM1:
- case GGML_OP_MAP_CUSTOM2:
- case GGML_OP_MAP_CUSTOM3:
- {
- GGML_ABORT("fatal error"); // not supported
- }
- case GGML_OP_CROSS_ENTROPY_LOSS:
- {
- if (src0->grad) {
- src0->grad = ggml_add_or_set(ctx,
- src0->grad,
- ggml_cross_entropy_loss_back(ctx,
- src0,
- src1,
- tensor->grad),
- zero_table, acc_table);
- }
- GGML_ASSERT(!src1->grad && "backward pass for labels not implemented");
- } break;
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
- {
- GGML_ABORT("fatal error"); // not supported
+ case GGML_OP_UNARY: {
+ switch (ggml_get_unary_op(tensor)) {
+ case GGML_UNARY_OP_ABS: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad));
+ }
+ } break;
+ case GGML_UNARY_OP_SGN: {
+ // noop
+ } break;
+ case GGML_UNARY_OP_NEG: {
+ if (src0_needs_grads) {
+ ggml_sub_or_set(ctx, cgraph, isrc0, grad);
+ }
+ } break;
+ case GGML_UNARY_OP_STEP: {
+ // noop
+ } break;
+ case GGML_UNARY_OP_RELU: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad));
+ }
+ } break;
+ case GGML_UNARY_OP_SILU: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
+ }
+ } break;
+ case GGML_UNARY_OP_EXP: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
+ }
+ } break;
+ default: {
+ fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
+ __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
+ GGML_ABORT("fatal error");
+ } break;
}
- case GGML_OP_OPT_STEP_ADAMW:
- {
- GGML_ABORT("fatal error"); // not supported
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS: {
+ if (src0_needs_grads) {
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
}
- case GGML_OP_NONE:
- {
- // nop
- } break;
+ GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
+ } break;
+ case GGML_OP_NONE: {
+ // noop
+ } break;
case GGML_OP_COUNT:
- {
- GGML_ABORT("fatal error");
- }
+ default: {
+ fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
+ GGML_ABORT("fatal error");
+ } break;
}
- for (int i = 0; i < GGML_MAX_SRC; ++i) {
- if (tensor->src[i] && tensor->src[i]->grad) {
- GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
- }
- }
+ GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));
+ GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));
+ GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
}
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
- if (node->grad == NULL) {
- // this usually happens when we generate intermediate nodes from constants in the backward pass
- // it can also happen during forward pass, if the user performs computations with constants
- if (node->op != GGML_OP_NONE) {
- //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
- }
- }
-
// check if already visited
if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
return;
ggml_build_forward_impl(cgraph, tensor, true);
}
-void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate) {
- GGML_ASSERT(gf->n_nodes > 0);
- GGML_ASSERT(gf->grads);
+void ggml_build_backward_expand(
+ struct ggml_context * ctx_static,
+ struct ggml_context * ctx_compute,
+ struct ggml_cgraph * cgraph,
+ bool accumulate) {
+ GGML_ASSERT(cgraph->n_nodes > 0);
+ GGML_ASSERT(cgraph->grads);
+ GGML_ASSERT(cgraph->grad_accs);
+
+ const int n_nodes_f = cgraph->n_nodes;
- for (int i = 0; i < gf->n_nodes; ++i) {
- struct ggml_tensor * node = gf->nodes[i];
+ const size_t hash_size = ggml_hash_size(2*cgraph->size);
+ memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *));
+ memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
+ bool * grads_needed = calloc(hash_size, sizeof(bool));
+
+ {
+ bool any_params = false;
+ bool any_loss = false;
+ for (int i = 0; i < n_nodes_f; ++i) {
+ struct ggml_tensor * node = cgraph->nodes[i];
+ any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM);
+ any_loss = any_loss || (node->flags & GGML_TENSOR_FLAG_LOSS);
+ }
+ GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?");
+ GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?");
+ }
+
+ for (int i = 0; i < n_nodes_f; ++i) {
+ struct ggml_tensor * node = cgraph->nodes[i];
if (node->type == GGML_TYPE_I32) {
continue;
}
- bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
+ bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
bool ignore_src[GGML_MAX_SRC] = {false};
switch (node->op) {
// gradients in node->src[0] for one reason or another have no effect on output gradients
break;
}
for (int j = 0; j < GGML_MAX_SRC; ++j) {
- if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) {
+ if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
continue;
}
GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
- needs_grad = true;
+ node_needs_grad = true;
break;
}
- if (!needs_grad) {
+ if (!node_needs_grad) {
continue;
}
GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
- // create a new tensor with the same type and shape as the node and set it as grad
- node->grad = ggml_dup_tensor(ctx, node);
- }
-
- // keep tables of original gradients for replacement/accumulation logic
- struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
- struct ggml_hash_set acc_table = ggml_hash_set_new(gf->size);
- for (int i = 0; i < gf->n_nodes; i++) {
- struct ggml_tensor * node = gf->nodes[i];
-
- if (node->grad) {
- {
- const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
- }
-
- // only gradients of trainable parameters should be accumulated
- if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
- const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
- }
+ const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+ if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
+ cgraph->grads[igrad] = ggml_dup_tensor(ctx_static, node);
+ cgraph->grad_accs[igrad] = cgraph->grads[igrad];
}
+ grads_needed[igrad] = true;
}
- for (int i = gf->n_nodes - 1; i >= 0; i--) {
- struct ggml_tensor * node = gf->nodes[i];
-
+ for (int i = n_nodes_f - 1; i >= 0; --i) {
// inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
// use allocator to automatically make inplace operations
- if (node->grad) {
- ggml_compute_backward(ctx, node, &zero_table, &acc_table);
- }
+ ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
}
- for (int i = 0; i < gf->n_nodes; i++) {
- struct ggml_tensor * node = gf->nodes[i];
-
- if (node->flags & GGML_TENSOR_FLAG_PARAM) {
- GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
- ggml_build_forward_expand(gb, node->grad);
- }
- }
-
- ggml_hash_set_free(&zero_table);
- ggml_hash_set_free(&acc_table);
-}
-
-void ggml_build_opt_adamw(
- struct ggml_context * ctx,
- struct ggml_cgraph * gf,
- struct ggml_cgraph * gb,
- float alpha,
- float beta1,
- float beta2,
- float eps,
- float wd) {
- for (int i = 0; i < gf->n_nodes; i++) {
- struct ggml_tensor * node = gf->nodes[i];
-
- if (node->flags & GGML_TENSOR_FLAG_PARAM) {
- GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
- struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
- ggml_build_forward_expand(gb, opt_step);
- }
- }
+ free(grads_needed);
}
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
if (grads) {
- incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
+ incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
+ incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
}
incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
void * p = cgraph + 1;
- struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+ struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+ struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+
ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
// check that we allocated the correct amount of memory
/*.n_leafs =*/ 0,
/*.nodes =*/ nodes_ptr,
/*.grads =*/ grads_ptr,
+ /*.grad_accs =*/ grad_accs_ptr,
/*.leafs =*/ leafs_ptr,
/*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
/*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
};
ggml_hash_set_reset(&cgraph->visited_hash_set);
+ if (grads) {
+ memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *));
+ memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
+ }
return cgraph;
}
/*.n_leafs =*/ 0,
/*.nodes =*/ cgraph0->nodes + i0,
/*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
+ /*.grad_accs =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL,
/*.leafs =*/ NULL,
/*.hash_table =*/ { 0, NULL, NULL },
/*.order =*/ cgraph0->order,
dst->nodes[i] = src->nodes[i];
}
- if (src->grads) {
- GGML_ASSERT(dst->grads != NULL);
- for (int i = 0; i < src->n_nodes; ++i) {
- dst->grads[i] = src->grads[i];
- }
- }
-
for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
// copy all hashset keys (tensors) that are in use
if (ggml_bitset_get(src->visited_hash_set.used, i)) {
ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
}
}
+
+ if (src->grads) {
+ GGML_ASSERT(dst->grads != NULL);
+ GGML_ASSERT(dst->grad_accs != NULL);
+ for (int i = 0; i < src->n_nodes; ++i) {
+ const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
+ const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+ dst->grads[igrad_dst] = src->grads[igrad_src];
+ dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
+ }
+ }
}
struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
GGML_ASSERT(cgraph->grads != NULL);
for (int i = 0; i < cgraph->n_nodes; i++) {
- struct ggml_tensor * node = cgraph->nodes[i];
+ struct ggml_tensor * node = cgraph->nodes[i];
+ struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);
+
+ if (node->op == GGML_OP_OPT_STEP_ADAMW) {
+ // clear momenta
+ if (node->src[2]->data) {
+ ggml_set_zero(node->src[2]);
+ }
+ if (node->src[3]->data) {
+ ggml_set_zero(node->src[3]);
+ }
+ }
// initial gradients of loss should be 1, 0 otherwise
- if (node->grad) {
+ if (grad_acc) {
if (node->flags & GGML_TENSOR_FLAG_LOSS) {
- GGML_ASSERT(node->grad->buffer);
- GGML_ASSERT(node->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_is_scalar(node));
+ GGML_ASSERT(grad_acc->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_scalar(grad_acc));
const float onef = 1.0f;
- ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
+ if (grad_acc->buffer) {
+ ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));
+ } else {
+ GGML_ASSERT(grad_acc->data);
+ *((float *) grad_acc->data) = onef;
+ }
} else {
- ggml_set_zero(node->grad);
+ ggml_set_zero(grad_acc);
}
}
-
- GGML_ASSERT(node);
- if (node->op == GGML_OP_OPT_STEP_ADAMW) {
- // set iteration to 1 and clear momenta
- ggml_set_op_params_i32(node, 0, 1);
- ggml_set_zero(node->src[2]);
- ggml_set_zero(node->src[3]);
- }
}
}
cgraph->n_nodes++;
}
-struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
+struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) {
for (int i = 0; i < cgraph->n_leafs; i++) {
struct ggml_tensor * leaf = cgraph->leafs[i];
return NULL;
}
+struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+ const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+ return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL;
+}
+
+struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+ const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+ return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL;
+}
+
void ggml_graph_print(const struct ggml_cgraph * cgraph) {
GGML_LOG_INFO("=== GRAPH ===\n");
GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
i,
node->ne[0], node->ne[1], node->ne[2],
- ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ");
+ ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" :
+ ggml_graph_get_grad(cgraph, node) ? "g" : " ");
}
GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * parent = cgraph->nodes[i];
+ struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent);
- if (parent->grad == node) {
+ if (grad == node) {
return parent;
}
}
for (int i = 0; i < gb->n_nodes; i++) {
struct ggml_tensor * node = gb->nodes[i];
+ struct ggml_tensor * grad = ggml_graph_get_grad(gb, node);
if (ggml_graph_get_parent(gb, node) != NULL) {
continue;
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
snprintf(color, sizeof(color), "yellow");
- } else if (node->grad) {
+ } else if (grad) {
if (ggml_graph_find(gf, node)) {
snprintf(color, sizeof(color), "green");
} else {
fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
}
- if (node->grad) {
- fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(node->grad->op));
+ if (grad) {
+ fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(grad->op));
} else {
fprintf(fp, "\"; ]\n");
}