]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
llama/ggml: add LLM training support (llama/10544)
authorJohannes Gäßler <redacted>
Mon, 12 May 2025 12:44:49 +0000 (14:44 +0200)
committerGeorgi Gerganov <redacted>
Tue, 13 May 2025 10:02:19 +0000 (13:02 +0300)
* llama/ggml: add LLM training support

more compact progress bar

llama_save_model_to_file

llama_opt_param_filter

ggml_graph_dup force_grads

refactor ggml_opt, fix test-opt

* remove logits_all

* refactor CUDA implementation for ACC

* reset graph at beginning of opt period

include/ggml-opt.h
include/ggml.h
src/ggml-backend.cpp
src/ggml-cuda/acc.cu
src/ggml-cuda/sum.cu
src/ggml-opt.cpp
src/ggml.c
tests/test-backend-ops.cpp
tests/test-opt.cpp

index eb5eab9de6781be9eaca8cc196080954b51ebcb4..da0c24b46fed96739dab53670ed72b243ca7dd65 100644 (file)
@@ -37,13 +37,16 @@ extern "C" {
     // ====== Dataset ======
 
     GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
-            int64_t ne_datapoint, // number of elements per datapoint
-            int64_t ne_label,     // number of elements per label
-            int64_t ndata,        // total number of datapoints/labels
-            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+            enum ggml_type type_data,    // the type for the internal data tensor
+            enum ggml_type type_label,   // the type for the internal labels tensor
+            int64_t        ne_datapoint, // number of elements per datapoint
+            int64_t        ne_label,     // number of elements per label
+            int64_t        ndata,        // total number of datapoints/labels
+            int64_t        ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
     GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
 
     // get underlying tensors that store the data
+    GGML_API int64_t              ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
     GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
     GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label,     ndata]
 
@@ -56,13 +59,19 @@ extern "C" {
             struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]
             struct ggml_tensor * labels_batch, // shape = [ne_label,     ndata_batch]
             int64_t              ibatch);
+    GGML_API void ggml_opt_dataset_get_batch_host(
+            ggml_opt_dataset_t   dataset,
+            void               * data_batch,
+            size_t               nb_data_batch,
+            void               * labels_batch,
+            int64_t              ibatch);
 
     // ====== Model / Context ======
 
     enum ggml_opt_build_type {
-        GGML_OPT_BUILD_TYPE_FORWARD,
-        GGML_OPT_BUILD_TYPE_GRAD,
-        GGML_OPT_BUILD_TYPE_OPT,
+        GGML_OPT_BUILD_TYPE_FORWARD = 10,
+        GGML_OPT_BUILD_TYPE_GRAD    = 20,
+        GGML_OPT_BUILD_TYPE_OPT     = 30,
     };
 
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
@@ -81,20 +90,22 @@ extern "C" {
     // userdata can be used to pass arbitrary data
     typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
 
-    // returns the default optimizer params (constant)
+    // returns the default optimizer params (constant, hard-coded values)
     // userdata is not used
     GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
 
+    // casts userdata to ggml_opt_optimizer_params and returns it
+    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
+
     // parameters for initializing a new optimization context
     struct ggml_opt_params {
         ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
 
-        struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
-
-        // the forward graph is defined by inputs and outputs
-        // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
-        struct ggml_tensor * inputs;
-        struct ggml_tensor * outputs;
+        // by default the forward graph needs to be reconstructed for each eval
+        // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
+        struct ggml_context * ctx_compute;
+        struct ggml_tensor  * inputs;
+        struct ggml_tensor  * outputs;
 
         enum ggml_opt_loss_type  loss_type;
         enum ggml_opt_build_type build_type;
@@ -107,12 +118,9 @@ extern "C" {
 
     // get parameters for an optimization context with defaults set where possible
     // parameters for which no sensible defaults exist are supplied as arguments to this function
-    GGML_API ggml_opt_params ggml_opt_default_params(
-            ggml_backend_sched_t      backend_sched,
-            struct ggml_context     * ctx_compute,
-            struct ggml_tensor      * inputs,
-            struct ggml_tensor      * outputs,
-            enum ggml_opt_loss_type   loss_type);
+    GGML_API struct ggml_opt_params ggml_opt_default_params(
+            ggml_backend_sched_t    backend_sched,
+            enum ggml_opt_loss_type loss_type);
 
     GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
     GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
@@ -121,6 +129,7 @@ extern "C" {
     GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
 
     // get underlying tensors that store data
+    // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
     GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
     GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
     GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against
@@ -128,11 +137,12 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs
     GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
 
+    // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
     // ====== Optimization Result ======
 
-    GGML_API ggml_opt_result_t ggml_opt_result_init();
+    GGML_API ggml_opt_result_t ggml_opt_result_init(void);
     GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
     GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
 
@@ -144,11 +154,20 @@ extern "C" {
 
     // ====== Computation ======
 
-    // do forward pass, increment result if not NULL
-    GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // if not using static graphs, this function must be called prior to ggml_opt_alloc
+    GGML_API void ggml_opt_prepare_alloc(
+        ggml_opt_context_t    opt_ctx,
+        struct ggml_context * ctx_compute,
+        struct ggml_cgraph  * gf,
+        struct ggml_tensor  * inputs,
+        struct ggml_tensor  * outputs);
+
+    // allocate the next graph for evaluation, either forward or forward + backward
+    // must be called exactly once prior to calling ggml_opt_eval
+    GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
 
-    // do forward pass, increment result if not NULL, do backward pass
-    GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+    // do forward pass, increment result if not NULL, do backward pass if allocated
+    GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
 
     // ############################################################################
     // ## The high-level functions start here. They do not depend on any private ##
@@ -200,9 +219,9 @@ extern "C" {
     // fit model defined by inputs and outputs to dataset
     GGML_API void ggml_opt_fit(
             ggml_backend_sched_t            backend_sched,  // backend scheduler for constructing the compute graphs
-            ggml_context                  * ctx_compute,    // context with temporarily allocated tensors to calculate the outputs
-            ggml_tensor                   * inputs,         // input tensor with shape [ne_datapoint, ndata_batch]
-            ggml_tensor                   * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+            struct ggml_context           * ctx_compute,    // context with temporarily allocated tensors to calculate the outputs
+            struct ggml_tensor            * inputs,         // input tensor with shape [ne_datapoint, ndata_batch]
+            struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,      // loss to minimize
             ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
index c518366d58a7a3e7230feab31a255be4a13b47d1..e91dedf14a1cbbcf7d554e4b12d95604ec8f168c 100644 (file)
@@ -768,7 +768,7 @@ extern "C" {
     // Tensor flags
     GGML_API void ggml_set_input(struct ggml_tensor * tensor);
     GGML_API void ggml_set_output(struct ggml_tensor * tensor);
-    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_param(struct ggml_tensor * tensor);
     GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
     //
@@ -938,7 +938,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_repeat_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
 
     // concat a and b along dim
     // used in stable-diffusion
@@ -2049,15 +2049,14 @@ extern "C" {
 
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(
-        struct ggml_context * ctx_static,  // context for static gradients (loss + gradient accumulation)
-        struct ggml_context * ctx_compute, // context for gradient computation
-        struct ggml_cgraph  * cgraph,
-        bool                  accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
+        struct ggml_context *  ctx,        // context for gradient computation
+        struct ggml_cgraph  *  cgraph,
+        struct ggml_tensor  ** grad_accs);
 
     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph       (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
-    GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
     GGML_API void                 ggml_graph_cpy       (struct ggml_cgraph * src, struct ggml_cgraph * dst);
     GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void                 ggml_graph_clear     (struct ggml_cgraph * cgraph);
index 6f69d895f170d039a1dcea8f89403c26e64ab8b5..b30b4cb386f9fee7e1c4956ec65808ed99cd0f86 100644 (file)
@@ -1111,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
             const int node_backend_id = tensor_backend_id(node);
 
-            assert(node_backend_id != -1); // all nodes should be assigned by now
+            assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
 
             // check if we should start a new split based on the sources of the current node
             bool need_new_split = false;
index 96bfe1c9d81470761639bb30dc6da5cddbe9e610..e084607c029a65e2201e7af43f7a491415c65037 100644 (file)
@@ -1,47 +1,61 @@
 #include "acc.cuh"
 
-static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, int offset) {
-    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
+    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+
     if (i >= ne) {
         return;
     }
-    int src1_idx = i - offset;
-    int oz = src1_idx / nb2;
-    int oy = (src1_idx - (oz * nb2)) / nb1;
-    int ox = src1_idx % nb1;
-    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
-        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
-    } else {
-        dst[i] = x[i];
+
+    int64_t src1_idx = i - offset;
+
+    int64_t tmp = src1_idx;
+    const int64_t i13 = tmp / s13;
+    tmp -= i13 * s13;
+    const int64_t i12 = tmp / s12;
+    tmp -= i12 * s12;
+    const int64_t i11 = tmp / s11;
+    tmp -= i11 * s11;
+    const int64_t i10 = tmp;
+
+    float val = x[i];
+    if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
+        val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
     }
+    dst[i] = val;
 }
 
-static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
-    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
-    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
+    const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
 }
 
 void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *)  dst->data;
+
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
 
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
+    GGML_ASSERT(ggml_is_contiguously_allocated(dst));
+
+    const int64_t s1     = dst->op_params[0] / sizeof(float);
+    const int64_t s2     = dst->op_params[1] / sizeof(float);
+    const int64_t s3     = dst->op_params[2] / sizeof(float);
+    const int64_t offset = dst->op_params[3] / sizeof(float);
 
-    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
+    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);
 }
index f9589080a0c3b98dff0373ed0a28bce0d59087e8..eb3d7cdba98a7ae4b36433e14fe28975b4b3931a 100644 (file)
@@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguously_allocated(src0));
 
     const float * src0_d = (const float *) src0->data;
     float * dst_d = (float *) dst->data;
index 7c3e24103a250ae47fd2cfee653eed1f85d338be..58d77578f458dafa9c97a6099f0618f716c7f601 100644 (file)
@@ -28,16 +28,19 @@ struct ggml_opt_dataset {
 };
 
 struct ggml_opt_context {
-    ggml_backend_sched_t    backend_sched        = nullptr;
-    ggml_cgraph           * allocated_graph      = nullptr;
-    ggml_cgraph           * allocated_graph_copy = nullptr;
-    struct ggml_context   * ctx_static           = nullptr;
-    struct ggml_context   * ctx_static_cpu       = nullptr;
-    struct ggml_context   * ctx_compute          = nullptr;
-    struct ggml_context   * ctx_copy             = nullptr;
-    ggml_backend_buffer_t   buf_static           = nullptr;
-    ggml_backend_buffer_t   buf_static_cpu       = nullptr;
-    std::mt19937            rng;
+    ggml_backend_sched_t       backend_sched        = nullptr;
+    ggml_cgraph              * allocated_graph      = nullptr;
+    ggml_cgraph              * allocated_graph_copy = nullptr;
+    struct ggml_context      * ctx_static           = nullptr;
+    struct ggml_context      * ctx_cpu              = nullptr;
+    struct ggml_context      * ctx_compute          = nullptr;
+    struct ggml_context      * ctx_copy             = nullptr;
+    ggml_backend_buffer_t      buf_static           = nullptr;
+    ggml_backend_buffer_t      buf_cpu              = nullptr;
+    std::mt19937               rng;
+    enum ggml_opt_loss_type    loss_type;
+    enum ggml_opt_build_type   build_type;
+    enum ggml_opt_build_type   build_type_alloc;
 
     struct ggml_tensor * inputs  = nullptr;
     struct ggml_tensor * outputs = nullptr;
@@ -50,6 +53,11 @@ struct ggml_opt_context {
     struct ggml_cgraph * gf      = nullptr;
     struct ggml_cgraph * gb_grad = nullptr;
     struct ggml_cgraph * gb_opt  = nullptr;
+    bool static_graphs           = false;
+    bool eval_ready              = false;
+    std::vector<struct ggml_tensor *> grad_accs;
+    std::vector<struct ggml_tensor *> grad_m;
+    std::vector<struct ggml_tensor *> grad_v;
 
     int64_t iter               = 1;
     int32_t opt_period         = 1;
@@ -73,7 +81,13 @@ struct ggml_opt_result {
 
 // ====== Dataset ======
 
-ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
+ggml_opt_dataset_t ggml_opt_dataset_init(
+        enum ggml_type type_data,
+        enum ggml_type type_label,
+        int64_t        ne_datapoint,
+        int64_t        ne_label,
+        int64_t        ndata,
+        int64_t        ndata_shard) {
     GGML_ASSERT(ne_datapoint >  0);
     GGML_ASSERT(ne_label     >= 0);
     GGML_ASSERT(ndata        >  0);
@@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
         result->ctx = ggml_init(params);
     }
 
-    result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
+    result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
     result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
 
     if (ne_label > 0) {
-        result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
+        result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
         result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
     } else {
         result->labels = nullptr;
@@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
     delete dataset;
 }
 
+int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
+    return dataset->ndata;
+}
+
 struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
     return dataset->data;
 }
@@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
     GGML_ASSERT(   data_batch && ggml_is_contiguous(data_batch));
     GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
     GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+    GGML_ASSERT(                   data_batch->type == dataset->data->type);
+    GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);
 
     const size_t nb_data_batch = ggml_nbytes(data_batch);
     GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
@@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
     }
 }
 
+void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
+    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
+
+    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
+
+    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
+
+    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
+        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
+
+        const char * ptr_data       = (const char *) dataset->data->data + ishard      *dataset->nbs_data;
+        char       * ptr_data_batch = (char       *) data_batch          + ishard_batch*dataset->nbs_data;
+        memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
+
+        if (!labels_batch) {
+            continue;
+        }
+
+        const char * ptr_labels       = (const char *) dataset->labels->data + ishard      *dataset->nbs_labels;
+        char       * ptr_labels_batch = (char       *) labels_batch          + ishard_batch*dataset->nbs_labels;
+        memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
+    }
+}
+
 // ====== Model / Context ======
 
 struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
@@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     return result;
 }
 
+struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
+    return *((struct ggml_opt_optimizer_params *) userdata);
+}
+
 struct ggml_opt_params ggml_opt_default_params(
         ggml_backend_sched_t      backend_sched,
-        struct ggml_context     * ctx_compute,
-        struct ggml_tensor      * inputs,
-        struct ggml_tensor      * outputs,
         enum ggml_opt_loss_type   loss_type) {
     return {
         /*backend_sched   =*/ backend_sched,
-        /*ctx_compute     =*/ ctx_compute,
-        /*inputs          =*/ inputs,
-        /*logits          =*/ outputs,
+        /*ctx_compute     =*/ nullptr,
+        /*inputs          =*/ nullptr,
+        /*logits          =*/ nullptr,
         /*loss_type       =*/ loss_type,
         /*build_type      =*/ GGML_OPT_BUILD_TYPE_OPT,
         /*opt_period      =*/ 1,
@@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
     return dst;
 }
 
-static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
-    GGML_ASSERT(graph);
-    if (opt_ctx->allocated_graph == graph) {
-        return;
-    }
-
-    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
-
-    {
-        ggml_init_params params = {
-            /*.mem_size   =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
-            /*.mem_buffer =*/ nullptr,
-            /*.no_alloc   =*/ true,
-        };
-        ggml_free(opt_ctx->ctx_copy);
-        opt_ctx->ctx_copy = ggml_init(params);
-    }
-
-    opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
-
-    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
-    opt_ctx->allocated_graph = graph;
-}
-
-ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
-    ggml_opt_context_t result = new struct ggml_opt_context;
-    result->backend_sched   = params.backend_sched;
-    result->ctx_compute     = params.ctx_compute;
-    result->inputs          = params.inputs;
-    result->outputs         = params.outputs;
-    result->opt_period      = params.opt_period;
-    result->get_opt_pars    = params.get_opt_pars;
-    result->get_opt_pars_ud = params.get_opt_pars_ud;
-
-    GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
-    GGML_ASSERT(result->opt_period >= 1);
-
-    const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
-        (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
+static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
+    GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
+    GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
 
-    ggml_set_input(result->inputs);
-    ggml_set_output(result->outputs);
+    const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
+        !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
 
-    result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
-    ggml_build_forward_expand(result->gf, result->outputs);
+    ggml_set_input(opt_ctx->inputs);
+    ggml_set_output(opt_ctx->outputs);
 
     int n_param = 0;
-    for (int i = 0; i < result->gf->n_nodes; ++i) {
-        if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
+    for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) {
+        const struct ggml_tensor * node = opt_ctx->gf->nodes[i];
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             n_param++;
         }
+        GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented");
     }
 
-    {
+    if (!opt_ctx->ctx_static) {
         // The static context is used for:
-        //   - gradients (1 tensor per param if using gradient accumulation)
+        //   - gradients (1 per loss, 1 tensor per param if using gradient accumulation)
         //   - optimizer momenta (2 tensors per param)
-        //   - labels
-        //   - loss + its gradient (up to 5 tensors)
-        //   - pred
-        //   - ncorrect (2 tensors).
-        const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
-        const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
+        //   - labels (if using static graphs)
+        //   - loss (if using static graphs, up to 5 tensors)
+        //   - pred (if using static graphs)
+        //   - ncorrect (if using static graphs, 2 tensors).
+        constexpr size_t n_loss = 1;
+        const size_t tensors_per_param = (accumulate ? 1 : 0) +
+            (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+        const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
+        const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
         struct ggml_init_params params = {
             /*.mem_size   =*/ size_meta,
             /*.mem_buffer =*/ nullptr,
             /*.no_alloc   =*/ true,
         };
-        result->ctx_static = ggml_init(params);
+        opt_ctx->ctx_static = ggml_init(params);
     }
+    GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc);
+
     {
-        // The static cpu context is used for:
-        //   - optimizer parameters (1 for the entire context)
+        // The cpu context is allocated statically if using static graphs, dynamically otherwise.
+        // It is used for:
+        //   - optimizer parameters (1 shared for all optimizer invocations)
         const size_t size_meta = 1 * ggml_tensor_overhead();
         struct ggml_init_params params = {
             /*.mem_size   =*/ size_meta,
             /*.mem_buffer =*/ nullptr,
             /*.no_alloc   =*/ true,
         };
-        result->ctx_static_cpu = ggml_init(params);
+        ggml_free(opt_ctx->ctx_cpu);
+        opt_ctx->ctx_cpu = ggml_init(params);
+
+        ggml_backend_buffer_free(opt_ctx->buf_cpu);
+        opt_ctx->buf_cpu = nullptr;
     }
 
+    struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute;
 
-    switch (params.loss_type) {
+    switch (opt_ctx->loss_type) {
         case GGML_OPT_LOSS_TYPE_MEAN: {
-            result->loss = ggml_sum(result->ctx_static, result->outputs);
-            ggml_set_name(result->loss, "loss_sum");
-            const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
-            result->loss = ggml_scale(result->ctx_static, result->loss, scale);
-            ggml_set_name(result->loss, "loss_mean");
-            result->loss_per_datapoint = true;
+            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
+            ggml_set_name(opt_ctx->loss, "loss_sum");
+            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
+            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
+            ggml_set_name(opt_ctx->loss, "loss_mean");
+            opt_ctx->loss_per_datapoint = true;
             break;
         }
         case GGML_OPT_LOSS_TYPE_SUM: {
-            result->loss = ggml_sum(result->ctx_static, result->outputs);
-            ggml_set_name(result->loss, "loss_sum");
-            result->loss_per_datapoint = false;
+            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs);
+            ggml_set_name(opt_ctx->loss, "loss_sum");
+            opt_ctx->loss_per_datapoint = false;
             break;
         }
         case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
-            result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
-            ggml_set_input(result->labels);
-            ggml_set_name(result->labels, "labels");
-            result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
-            ggml_set_name(result->loss, "loss_cross_entropy");
-            if (result->opt_period > 1) {
-                result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
-                ggml_set_name(result->loss, "loss_cross_entropy_scaled");
+            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
+            ggml_set_input(opt_ctx->labels);
+            ggml_set_name(opt_ctx->labels, "labels");
+            opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels);
+            ggml_set_name(opt_ctx->loss, "loss_cross_entropy");
+            if (opt_ctx->opt_period > 1) {
+                opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period);
+                ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled");
             }
-            result->loss_per_datapoint = true;
+            opt_ctx->loss_per_datapoint = true;
             break;
         }
         case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
-            result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
-            ggml_set_input(result->labels);
-            ggml_set_name(result->labels, "labels");
-            result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
-            ggml_set_name(result->loss, "loss_error");
-            result->loss = ggml_sqr(result->ctx_static, result->loss);
-            ggml_set_name(result->loss, "loss_squared_error");
-            result->loss = ggml_sum(result->ctx_static, result->loss);
-            ggml_set_name(result->loss, "loss_sum_squared_error");
-            const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
-            result->loss = ggml_scale(result->ctx_static, result->loss, scale);
-            ggml_set_name(result->loss, "loss_mean_squared_error");
-            result->loss_per_datapoint = true;
+            opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs);
+            ggml_set_input(opt_ctx->labels);
+            ggml_set_name(opt_ctx->labels, "labels");
+            opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels);
+            ggml_set_name(opt_ctx->loss, "loss_error");
+            opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss);
+            ggml_set_name(opt_ctx->loss, "loss_squared_error");
+            opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss);
+            ggml_set_name(opt_ctx->loss, "loss_sum_squared_error");
+            const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs));
+            opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale);
+            ggml_set_name(opt_ctx->loss, "loss_mean_squared_error");
+            opt_ctx->loss_per_datapoint = true;
             break;
         }
     }
-    ggml_set_output(result->loss);
-    ggml_set_loss(result->loss);
-    ggml_build_forward_expand(result->gf, result->loss);
-
-    result->pred = ggml_argmax(result->ctx_static, result->outputs);
-    ggml_set_name(result->pred, "pred");
-    ggml_set_output(result->pred);
-    ggml_build_forward_expand(result->gf, result->pred);
+    ggml_set_output(opt_ctx->loss);
+    ggml_set_loss(opt_ctx->loss);
+    ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss);
+
+    if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) {
+        opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs);
+        ggml_set_name(opt_ctx->pred, "pred");
+        ggml_set_output(opt_ctx->pred);
+        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred);
+
+        opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels));
+        ggml_set_name(opt_ctx->ncorrect, "ncorrect");
+        ggml_set_output(opt_ctx->ncorrect);
+        ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect);
+    }
 
-    if (result->labels) {
-        result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
-        ggml_set_name(result->ncorrect, "ncorrect");
-        ggml_set_output(result->ncorrect);
-        ggml_build_forward_expand(result->gf, result->ncorrect);
-    } else {
-        result->ncorrect = nullptr;
+    if (opt_ctx->buf_static) {
+        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
+            return;
+        }
+    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) {
+        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
+            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
+        return;
     }
 
-    if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
-        result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
-        return result;
+    if (opt_ctx->grad_accs.empty()) {
+        GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD);
+
+        const int n_nodes = opt_ctx->gf->n_nodes;
+        opt_ctx->grad_accs.resize(n_nodes);
+        for (int i = 0; i < n_nodes; ++i) {
+            ggml_tensor * node = opt_ctx->gf->nodes[i];
+            if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
+                opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+            } else {
+                opt_ctx->grad_accs[i] = nullptr;
+            }
+        }
+
+        if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+            opt_ctx->grad_m.resize(n_nodes);
+            opt_ctx->grad_v.resize(n_nodes);
+            for (int i = 0; i < n_nodes; ++i) {
+                ggml_tensor * node = opt_ctx->gf->nodes[i];
+                if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+                    opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+                    opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+                } else {
+                    opt_ctx->grad_m[i] = nullptr;
+                    opt_ctx->grad_v[i] = nullptr;
+                }
+            }
+        }
     }
 
     // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
-    result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
-    ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
+    opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
+    ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
 
-    if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
-        result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
-        ggml_graph_reset(result->gb_grad);
-        return result;
+    if (opt_ctx->buf_static) {
+        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
+            return;
+        }
+    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
+        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
+        ggml_graph_reset(opt_ctx->gb_grad);
     }
 
-    GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
+    GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
 
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
-    result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
+    opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
 
-    result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
-    ggml_set_input(result->adamw_params);
-    ggml_set_name(result->adamw_params, "adamw_params");
+    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
+    ggml_set_input(opt_ctx->adamw_params);
+    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
 
-    for (int i = result->gf->n_nodes-1; i >= 0; --i) {
-        struct ggml_tensor * node = result->gb_opt->nodes[i];
-        struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
+    for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
+        struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
 
-        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
-            struct ggml_tensor * m        = ggml_dup_tensor(result->ctx_static, node);
-            struct ggml_tensor * v        = ggml_dup_tensor(result->ctx_static, node);
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
-            ggml_build_forward_expand(result->gb_opt, opt_step);
+        if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+            struct ggml_tensor * m        = opt_ctx->grad_m[i];
+            struct ggml_tensor * v        = opt_ctx->grad_v[i];
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
+
+            ggml_set_name(m,        (std::string("AdamW m for ")    + std::string(node->name)).c_str());
+            ggml_set_name(v,        (std::string("AdamW v for ")    + std::string(node->name)).c_str());
+            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
+
+            ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
         }
     }
 
-    result->buf_static = ggml_backend_alloc_ctx_tensors(
-        result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
+    if (!opt_ctx->buf_static) {
+        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(
+            opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
+        ggml_graph_reset(opt_ctx->gb_opt);
+    }
 
-    result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
+    opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
+}
 
-    ggml_graph_reset(result->gb_opt);
+ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
+    ggml_opt_context_t result = new struct ggml_opt_context;
+    result->backend_sched    = params.backend_sched;
+    result->ctx_compute      = params.ctx_compute;
+    result->loss_type        = params.loss_type;
+    result->build_type       = params.build_type;
+    result->build_type_alloc = params.build_type;
+    result->inputs           = params.inputs;
+    result->outputs          = params.outputs;
+    result->opt_period       = params.opt_period;
+    result->get_opt_pars     = params.get_opt_pars;
+    result->get_opt_pars_ud  = params.get_opt_pars_ud;
+
+    GGML_ASSERT(result->opt_period >= 1);
+
+    result->static_graphs = result->ctx_compute;
+
+    if (!result->static_graphs) {
+        GGML_ASSERT(!result->inputs);
+        GGML_ASSERT(!result->outputs);
+        return result;
+    }
+
+    GGML_ASSERT(result->inputs);
+    GGML_ASSERT(result->outputs);
+
+    result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
+    ggml_build_forward_expand(result->gf, result->outputs);
+
+    ggml_opt_build(result);
 
     return result;
 }
@@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
         return;
     }
     ggml_backend_buffer_free(opt_ctx->buf_static);
-    ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
+    ggml_backend_buffer_free(opt_ctx->buf_cpu);
     ggml_free(opt_ctx->ctx_static);
-    ggml_free(opt_ctx->ctx_static_cpu);
+    ggml_free(opt_ctx->ctx_cpu);
     delete opt_ctx;
 }
 
@@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
 
 // ====== Computation ======
 
-static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
-    if (graph != opt_ctx->gf) {
+void ggml_opt_prepare_alloc(
+        ggml_opt_context_t    opt_ctx,
+        struct ggml_context * ctx_compute,
+        struct ggml_cgraph  * gf,
+        struct ggml_tensor  * inputs,
+        struct ggml_tensor  * outputs) {
+    GGML_ASSERT(!opt_ctx->static_graphs);
+    opt_ctx->ctx_compute = ctx_compute;
+    opt_ctx->gf          = gf;
+    opt_ctx->inputs      = inputs;
+    opt_ctx->outputs     = outputs;
+}
+
+void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
+    GGML_ASSERT(!opt_ctx->eval_ready);
+    if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
+        ggml_graph_reset(opt_ctx->gb_grad);
+    }
+    if (backward) {
+        const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+        opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
+    } else {
+        opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
+    }
+
+    if (!opt_ctx->static_graphs) {
+        ggml_opt_build(opt_ctx);
+    }
+
+    struct ggml_cgraph * graph = nullptr;
+    switch (opt_ctx->build_type) {
+        case GGML_OPT_BUILD_TYPE_FORWARD: {
+            graph = opt_ctx->gf;
+        } break;
+        case GGML_OPT_BUILD_TYPE_GRAD: {
+            graph = opt_ctx->gb_grad;
+        } break;
+        case GGML_OPT_BUILD_TYPE_OPT: {
+            graph = opt_ctx->gb_opt;
+        } break;
+    }
+    GGML_ASSERT(graph);
+
+    if (opt_ctx->allocated_graph == graph) {
+        opt_ctx->eval_ready = true;
+        return;
+    }
+
+    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
+
+    if (opt_ctx->static_graphs) {
+        ggml_init_params params = {
+            /*.mem_size   =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ggml_free(opt_ctx->ctx_copy);
+        opt_ctx->ctx_copy = ggml_init(params);
+
+        opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
+    } else {
+        opt_ctx->allocated_graph_copy = graph;
+    }
+
+    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+    opt_ctx->allocated_graph = graph;
+
+    opt_ctx->eval_ready = true;
+}
+
+void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
+    GGML_ASSERT(opt_ctx->eval_ready);
+    if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
         struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
 
         GGML_ASSERT(opt_pars.adamw.alpha >  0.0f);
@@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
         adamw_par_data[6] = beta2h;
     }
 
-    ggml_opt_alloc_graph(opt_ctx, graph);
     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
     opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
+    opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+
+    if (!opt_ctx->static_graphs) {
+        opt_ctx->gf                   = nullptr;
+        opt_ctx->gb_grad              = nullptr;
+        opt_ctx->gb_opt               = nullptr;
+        opt_ctx->allocated_graph      = nullptr;
+        opt_ctx->allocated_graph_copy = nullptr;
+    }
+
+    opt_ctx->eval_ready = false;
 
     if (!result) {
         return;
@@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
     ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
     result->loss.push_back(loss);
 
-    GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
-    std::vector<int32_t> pred(ndata);
-    ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
-    result->pred.insert(result->pred.end(), pred.begin(), pred.end());
+    if (opt_ctx->pred) {
+        GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
+        std::vector<int32_t> pred(ndata);
+        ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
+        result->pred.insert(result->pred.end(), pred.begin(), pred.end());
+    }
 
-    if (!opt_ctx->labels || result->ncorrect < 0) {
+    if (!opt_ctx->ncorrect || result->ncorrect < 0) {
         result->ncorrect = -1;
         return;
     }
@@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
     result->ncorrect += ncorrect;
 }
 
-void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
-    ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
-}
-
-void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
-    if (opt_ctx->opt_period == 1) {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
-        return;
-    }
-
-    const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
-    if (opt_i_next == 0) {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
-        ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
-    } else {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
-    }
-    opt_ctx->opt_i = opt_i_next;
-}
-
 // ====== High-Level Functions ======
 
 void ggml_opt_epoch(
@@ -700,16 +860,18 @@ void ggml_opt_epoch(
     int64_t ibatch = 0;
     int64_t t_loop_start = ggml_time_us();
     for (; ibatch < ibatch_split; ++ibatch) {
+        ggml_opt_alloc(opt_ctx, /*backward =*/ true);
         ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
-        ggml_opt_forward_backward(opt_ctx, result_train);
+        ggml_opt_eval(opt_ctx, result_train);
         if (callback_train) {
             callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
         }
     }
     t_loop_start = ggml_time_us();
     for (; ibatch < nbatches; ++ibatch) {
+        ggml_opt_alloc(opt_ctx, /*backward =*/ false);
         ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
-        ggml_opt_forward(opt_ctx, result_eval);
+        ggml_opt_eval(opt_ctx, result_eval);
         if (callback_eval) {
             callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
         }
@@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar(
         int64_t            t_start_us) {
     fprintf(stderr, "%s[", train ? "train: " : "val:   ");
 
-    constexpr int64_t bar_length = 25;
+    // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
+    constexpr int64_t bar_length = 8;
+    const int64_t ibatch8 = 8 * ibatch;
     for (int64_t j = 0; j < bar_length; ++j) {
-        const int64_t ibatch_j = ibatch_max * j/bar_length;
-        if (ibatch_j < ibatch) {
-            fprintf(stderr, "=");
-        } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
-            fprintf(stderr, ">");
+        if        (ibatch_max * (8*j + 8) / bar_length < ibatch8) {
+            fprintf(stderr, "\u2588"); // full block
+        } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) {
+            fprintf(stderr, "\u2589"); // 7/8 filled
+        } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258A"); // 6/8 filled
+        } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258B"); // 5/8 filled
+        } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258C"); // 4/8 filled
+        } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258D"); // 3/8 filled
+        } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258E"); // 2/8 filled
+        } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) {
+            fprintf(stderr, "\u258F"); // 1/8 filled
         } else {
             fprintf(stderr, " ");
         }
@@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar(
     const int64_t t_eta_m = t_eta_s / 60;
     t_eta_s -= t_eta_m * 60;
 
-    fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
-            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
+    fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
+            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
             idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
             t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
     if (ibatch == ibatch_max) {
@@ -806,7 +981,10 @@ void ggml_opt_fit(
 
     int64_t epoch = 1;
 
-    ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
+    ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
+    params.ctx_compute     = ctx_compute;
+    params.inputs          = inputs;
+    params.outputs         = outputs;
     params.opt_period      = opt_period;
     params.get_opt_pars    = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
index bc673292b37a373104e5cf39607ccbcdd146d2f9..8a6546240f46f901f5c5d98edc69eab1d80c58ef 100644 (file)
@@ -5499,7 +5499,7 @@ static void ggml_compute_backward(
             // tensor = src0 * 1 + src1 * 0
             if (src0_needs_grads) {
                 // dsrc0 = dtensor * 1
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
             }
             if (src1_needs_grads) {
                 // dsrc1 = dtensor * 0 -> noop
@@ -5780,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
 }
 
 void ggml_build_backward_expand(
-        struct ggml_context * ctx_static,
-        struct ggml_context * ctx_compute,
-        struct ggml_cgraph  * cgraph,
-        bool                  accumulate) {
+        struct ggml_context *  ctx,
+        struct ggml_cgraph  *  cgraph,
+        struct ggml_tensor  ** grad_accs) {
     GGML_ASSERT(cgraph->n_nodes > 0);
     GGML_ASSERT(cgraph->grads);
     GGML_ASSERT(cgraph->grad_accs);
@@ -5856,21 +5855,24 @@ void ggml_build_backward_expand(
         GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
-        const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
-        GGML_ASSERT(igrad != GGML_HASHSET_FULL);
-        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
-        if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
-            cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
-            cgraph->grads[igrad]     = cgraph->grad_accs[igrad];
-            ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
+        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
+        if (grad_accs && grad_accs[i]) {
+            cgraph->grad_accs[ihash] = grad_accs[i];
+            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
+        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+            // loss tensors always need a gradient accumulator
+            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
         }
-        grads_needed[igrad] = true;
+        grads_needed[ihash] = true;
     }
 
     for (int i = n_nodes_f - 1; i >= 0; --i) {
         // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
-        ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
+        ggml_compute_backward(ctx, cgraph, i, grads_needed);
     }
 
     free(grads_needed);
@@ -6016,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     }
 }
 
-struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
     ggml_graph_cpy(cgraph, result);
     return result;
 }
@@ -6036,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    if (!cgraph) {
+        return;
+    }
     GGML_ASSERT(cgraph->grads != NULL);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -6345,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
 }
 
-void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
-    GGML_UNUSED(ctx); // TODO: remove this parameter
+void ggml_set_param(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_NONE);
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 }
 
index 9ec24d9f23c5bc93b1b1e98e890e1186632358f7..543db93402190edca4a6dc4243ff8c7550ba8ab8 100644 (file)
@@ -823,7 +823,7 @@ struct test_case {
 
         ggml_build_forward_expand(gf, out);
         ggml_graph_cpy(gf, gb);
-        ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false);
+        ggml_build_backward_expand(ctx.get(), gb, nullptr);
         if (expect.size() != 1 || expect[0] != 0.0f) {
             GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
             for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
@@ -1026,7 +1026,7 @@ struct test_example : public test_case {
         // Step 3: return the output tensor.
         return out;
     }
-    // In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
+    // In order to also check the gradients for your op, add calls like ggml_set_param(a)
     // immediately after you create the tensors.
     // This is optional and only makes sense if a backward pass has actually been implemented for the new op.
 };
@@ -1058,7 +1058,7 @@ struct test_unary : public test_case {
             auto ne = ne_a; ne[0] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
             if (grad_supported) {
-                ggml_set_param(ctx, a);
+                ggml_set_param(a);
             }
             ggml_set_name(a, "a");
 
@@ -1067,7 +1067,7 @@ struct test_unary : public test_case {
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
             if (grad_supported) {
-                ggml_set_param(ctx, a);
+                ggml_set_param(a);
             }
             ggml_set_name(a, "a");
         }
@@ -1133,7 +1133,7 @@ struct test_get_rows : public test_case {
 
         const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);
         if (grad_supported) {
-            ggml_set_param(ctx, in);
+            ggml_set_param(in);
             // rows is a constant input -> no gradients
         }
 
@@ -1322,7 +1322,7 @@ struct test_repeat : public test_case {
         ggml_set_name(target, "target");
 
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         ggml_tensor * out = ggml_repeat(ctx, src, target);
@@ -1406,7 +1406,7 @@ struct test_dup : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         if (_use_permute) {
@@ -1442,7 +1442,7 @@ struct test_set : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         auto ne_dst = ne;
@@ -1450,7 +1450,7 @@ struct test_set : public test_case {
             ne_dst[i] *= 2;
         }
         ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
-        ggml_set_param(ctx, dst);
+        ggml_set_param(dst);
         ggml_set_name(dst, "dst");
 
         size_t offset = 0;
@@ -1498,7 +1498,7 @@ struct test_cpy : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         if (_src_use_permute) {
@@ -1536,7 +1536,7 @@ struct test_cont : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         src = ggml_transpose(ctx, src);
@@ -1583,8 +1583,8 @@ struct test_bin_bcast : public test_case {
         // The backward pass supports broadcasting only for GGML_ADD:
         const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
         if (grad_supported) {
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            ggml_set_param(a);
+            ggml_set_param(b);
         }
 
         ggml_tensor * out = op(ctx, a, b);
@@ -1632,11 +1632,11 @@ struct test_add1 : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
-        // ggml_set_param(ctx, b); // TODO: implement
+        // ggml_set_param(b); // TODO: implement
         ggml_set_name(b, "b");
 
         ggml_tensor * out = ggml_add1(ctx, a, b);
@@ -1667,7 +1667,7 @@ struct test_scale : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_scale(ctx, a, scale);
@@ -1762,7 +1762,7 @@ struct test_rms_norm : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         if (v) {
@@ -2028,9 +2028,9 @@ struct test_mul_mat : public test_case {
             b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
             if (!ggml_is_quantized(type_a)) {
                 if (bs[1] == 1 && nr[1] == 1) {
-                    ggml_set_param(ctx, a);
+                    ggml_set_param(a);
                 }
-                ggml_set_param(ctx, b);
+                ggml_set_param(b);
             }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
@@ -2040,22 +2040,29 @@ struct test_mul_mat : public test_case {
             ggml_set_name(a, "a_permuted");
             ggml_set_name(b, "b_permuted");
         } else {
-
             if (v) {
                 a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0],       bs[1]);
                 b = ggml_new_tensor_4d(ctx, type_b, k*2, n, bs[0]*nr[0], bs[1]*nr[1]);
 
+                if (!ggml_is_quantized(type_a)) {
+                    if (bs[1] == 1 && nr[1] == 1) {
+                        ggml_set_param(a);
+                    }
+                    ggml_set_param(b);
+                }
+
                 a = ggml_view_4d(ctx, a, k, m, bs[0],       bs[1],       a->nb[1], a->nb[2], a->nb[3], 0);
                 b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0);
             } else {
                 a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0],       bs[1]);
                 b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-            }
-            if (!ggml_is_quantized(type_a)) {
-                if (bs[1] == 1 && nr[1] == 1) {
-                    ggml_set_param(ctx, a);
+
+                if (!ggml_is_quantized(type_a)) {
+                    if (bs[1] == 1 && nr[1] == 1) {
+                        ggml_set_param(a);
+                    }
+                    ggml_set_param(b);
                 }
-                ggml_set_param(ctx, b);
             }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
@@ -2204,7 +2211,7 @@ struct test_sqr : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sqr(ctx, a);
@@ -2233,7 +2240,7 @@ struct test_sqrt : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sqrt(ctx, a);
@@ -2273,7 +2280,7 @@ struct test_log : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_log(ctx, a);
@@ -2309,7 +2316,7 @@ struct test_sin : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sin(ctx, a);
@@ -2352,7 +2359,7 @@ struct test_cos : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_cos(ctx, a);
@@ -2432,7 +2439,7 @@ struct test_diag_mask_inf : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
@@ -2471,7 +2478,7 @@ struct test_soft_max : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * mask = nullptr;
@@ -2553,7 +2560,7 @@ struct test_rope : public test_case {
             auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
             if (forward) {
-                ggml_set_param(ctx, a);
+                ggml_set_param(a);
             }
             ggml_set_name(a, "a");
 
@@ -2562,7 +2569,7 @@ struct test_rope : public test_case {
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
             if (forward) {
-                ggml_set_param(ctx, a);
+                ggml_set_param(a);
             }
             ggml_set_name(a, "a");
         }
@@ -2676,7 +2683,7 @@ struct test_pool2d : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
-        ggml_set_param(ctx, input);
+        ggml_set_param(input);
         ggml_set_name(input, "input");
 
         ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
@@ -2752,7 +2759,7 @@ struct test_im2col : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
-        ggml_set_param(ctx, input);
+        ggml_set_param(input);
         ggml_set_name(input, "input");
 
         ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
@@ -2929,7 +2936,7 @@ struct test_sum : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sum(ctx, a);
@@ -2958,7 +2965,7 @@ struct test_sum_rows : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sum_rows(ctx, a);
@@ -2983,7 +2990,7 @@ struct test_mean : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_mean(ctx, a);
@@ -3129,11 +3136,11 @@ struct test_acc : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
-        ggml_set_param(ctx, b);
+        ggml_set_param(b);
         ggml_set_name(b, "b");
 
         ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
@@ -3370,7 +3377,7 @@ struct test_cross_entropy_loss : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, logits);
+        ggml_set_param(logits);
         ggml_set_name(logits, "logits");
 
         ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -3452,7 +3459,7 @@ struct test_opt_step_adamw : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
-        ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
+        ggml_set_param(a); // Despite tensor a having gradients the output tensor will not.
         ggml_set_name(a, "a");
 
         ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
index 1bc160511357186c827287e6a7a9c9f59cc47764..558f877210e7d734078ae0fea327f3a19b4086f3 100644 (file)
@@ -57,7 +57,8 @@ static helper_ctx_data helper_get_ctx_data(
         enum ggml_opt_loss_type loss_type          = GGML_OPT_LOSS_TYPE_SUM) {
     std::vector<ggml_opt_dataset_t> datasets(ndata);
     for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
-        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard);
+        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
+            GGML_TYPE_F32, GGML_TYPE_F32, ne_datapoint, ne_label, ndata, ndata_shard);
 
         float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
         float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
@@ -74,7 +75,8 @@ static helper_ctx_data helper_get_ctx_data(
         datasets[ndata_shard-1] = dataset;
     }
 
-    ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1);
+    ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(
+        GGML_TYPE_F32, GGML_TYPE_F32, 1, 0, ndata, /*ndata_shard =*/ 1);
 
     float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised));
 
@@ -113,7 +115,7 @@ static helper_ctx_data helper_get_ctx_data(
 
     struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
     ggml_set_name(weights, "weights");
-    ggml_set_param(ctx_static, weights);
+    ggml_set_param(weights);
 
     struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights);
 
@@ -127,8 +129,11 @@ static helper_ctx_data helper_get_ctx_data(
     GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
     const int32_t opt_period = nbatch_logical / nbatch_physical;
 
-    struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
-    opt_params.opt_period = opt_period;
+    struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, loss_type);
+    opt_params.ctx_compute = ctx_compute;
+    opt_params.inputs      = inputs;
+    opt_params.outputs     = outputs;
+    opt_params.opt_period  = opt_period;
     if (!optimizer_defaults) {
         opt_params.get_opt_pars = helper_get_test_opt_pars;
     }
@@ -264,8 +269,9 @@ static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_ba
 
     for (int idata = 0; idata < ndata; ++idata) {
         const float idataf = idata;
+        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
         ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-        ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+        ggml_opt_eval(cd.opt_ctx, cd.result);
         ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
     }
 
@@ -334,8 +340,9 @@ static std::pair<int, int> test_forward_backward(
     } else {
         for (int idata = 0; idata < ndata; ++idata) {
             const float idataf = idata;
+            ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
             ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-            ggml_opt_forward(cd.opt_ctx, cd.result);
+            ggml_opt_eval(cd.opt_ctx, cd.result);
             ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
         }
     }
@@ -367,7 +374,8 @@ static std::pair<int, int> test_forward_backward(
     float w0;
     ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
     for (int i = 0; i < 10; ++i) {
-        ggml_opt_forward_backward(cd.opt_ctx, nullptr);
+        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+        ggml_opt_eval(cd.opt_ctx, cd.result);
     }
     ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
 
@@ -387,8 +395,9 @@ static std::pair<int, int> test_forward_backward(
     } else {
         for (int idata = 0; idata < ndata; ++idata) {
             const float idataf = idata;
+            ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
             ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-            ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+            ggml_opt_eval(cd.opt_ctx, cd.result);
             ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
         }
     }
@@ -492,14 +501,16 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
             int idata = 0;
             for (; idata < idata_split; ++idata) {
                 const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
                 ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_opt_eval(cd.opt_ctx, cd.result);
                 ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
             }
             for (; idata < ndata; ++idata) {
                 const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
                 ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-                ggml_opt_forward(cd.opt_ctx, cd.result2);
+                ggml_opt_eval(cd.opt_ctx, cd.result2);
                 ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
             }
         }
@@ -573,7 +584,6 @@ static std::pair<int, int> test_gradient_accumulation(
 
     struct helper_ctx_data cd = helper_get_ctx_data(
         backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);
-    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
 
     std::vector<float> grad_history(ndata);
     for (int64_t idata = 0; idata < ndata; ++idata) {
@@ -584,15 +594,17 @@ static std::pair<int, int> test_gradient_accumulation(
         if (nbatch_physical == 1) {
             for (int idata = 0; idata < ndata; ++idata) {
                 const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
                 ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float));
-                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_opt_eval(cd.opt_ctx, cd.result);
                 ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float));
             }
         } else if (nbatch_physical == 2) {
             for (int idata = 0; idata < ndata; idata += 2) {
                 const float idataf[2] = {float(idata + 0), float(idata + 1)};
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
                 ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float));
-                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_opt_eval(cd.opt_ctx, cd.result);
 
                 grad_history[idata + 0] = 0.0f;
                 ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float));
@@ -617,7 +629,7 @@ static std::pair<int, int> test_gradient_accumulation(
                 }
                 subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol);
                 subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol);
-                subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0, atol);
             } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
                 if (nbatch_physical == 1) {
                     subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol);
@@ -630,7 +642,7 @@ static std::pair<int, int> test_gradient_accumulation(
                 }
                 subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol);
                 subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol);
-                subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0/ndata, atol);
             } else {
                 GGML_ASSERT(false);
             }
@@ -692,7 +704,8 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
     std::mt19937 gen(12345);
     std::normal_distribution<float> nd{0.0f, 0.1f};
 
-    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression);
+    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
+        GGML_TYPE_F32, GGML_TYPE_F32, 1, 1, ndata_regression, ndata_regression);
 
     float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
     float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
@@ -733,15 +746,14 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
 
     struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
     ggml_set_name(a, "a");
-    ggml_set_param(ctx_static, a);
+    ggml_set_param(a);
 
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
     ggml_set_name(b, "b");
-    ggml_set_param(ctx_static, b);
+    ggml_set_param(b);
 
     struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b);
     ggml_set_name(f, "f");
-    ggml_set_param(ctx_static, f);
 
     ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
     const float a0 = 1.0f;