]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml: new optimization interface (#988)
authorJohannes Gäßler <redacted>
Sat, 16 Nov 2024 12:49:35 +0000 (13:49 +0100)
committerGitHub <redacted>
Sat, 16 Nov 2024 12:49:35 +0000 (13:49 +0100)
* ggml: new optimization interface

remove test2.c, test3.c

store adamw params in tensor

move grads from tensor to graph

* avoid segfault upon API misuse

* add ggml-opt.h to public headers

* remove dependence of ggml-opt.cpp on ggml-cpu.h

30 files changed:
CMakeLists.txt
examples/CMakeLists.txt
examples/mnist/README.md
examples/mnist/mnist-common.cpp
examples/mnist/mnist-common.h
examples/mnist/mnist-eval.cpp
examples/mnist/mnist-train.cpp
include/ggml-backend.h
include/ggml-opt.h [new file with mode: 0644]
include/ggml.h
src/CMakeLists.txt
src/ggml-alloc.c
src/ggml-backend.cpp
src/ggml-cpu/ggml-cpu.c
src/ggml-cuda/opt-step-adamw.cu
src/ggml-impl.h
src/ggml-metal/ggml-metal.m
src/ggml-opt.cpp [new file with mode: 0644]
src/ggml.c
tests/CMakeLists.txt
tests/test-backend-ops.cpp
tests/test-grad0.cpp [deleted file]
tests/test-mul-mat0.c
tests/test-opt.cpp
tests/test1.c [deleted file]
tests/test1.zig [deleted file]
tests/test2.c [deleted file]
tests/test2.zig [deleted file]
tests/test3.c [deleted file]
tests/test3.zig [deleted file]

index 4fb78e59fa72c770fb816c469933564399a7d663..fd949982645470e6013bce1eab1c04f8ec6c2f57 100644 (file)
@@ -228,6 +228,7 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-cann.h
     include/ggml-cuda.h
     include/ggml-kompute.h
+    include/ggml-opt.h
     include/ggml-metal.h
     include/ggml-rpc.h
     include/ggml-sycl.h
index 79c7f2442b95a75b385ef288375ff9178f7ffb27..b273a1addb9add9d40bec3a054a1ae708702e527 100644 (file)
@@ -20,7 +20,7 @@ target_include_directories(common-ggml PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
 
 add_subdirectory(gpt-2)
 add_subdirectory(gpt-j)
-add_subdirectory(mnist)
+add_subdirectory(mnist)
 add_subdirectory(sam)
 add_subdirectory(yolo)
 add_subdirectory(simple)
index 9e0966f441974c2e7428b00f1259af91627bbce4..10df02dbb425fbd865f3ce987f0aac72da999f88 100644 (file)
@@ -18,7 +18,7 @@ $ python3 mnist-train-fc.py mnist-fc-f32.gguf
 
 ...
 
-Test loss: 0.066051+-0.011630, Test accuracy: 98.07+-0.14%
+Test loss: 0.066377+-0.010468, Test accuracy: 97.94+-0.14%
 
 Model tensors saved to mnist-fc-f32.gguf:
 fc1.weight       (500, 784)
@@ -61,22 +61,21 @@ ________________________________________________________
 ________________________________________________________
 ________________________________________________________
 ________________________________________________________
-mnist_graph_eval: trying to load a ggml graph from mnist-fc-f32.gguf
-ggml_graph_import: invalid magic number, got 46554747
-mnist_graph_eval: could not load a ggml graph from mnist-fc-f32.gguf
 ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
 ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
 ggml_cuda_init: found 1 CUDA devices:
   Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
-mnist_model: using CPU backend
+mnist_model: using CUDA0 (NVIDIA GeForce RTX 3090) as primary backend
+mnist_model: unsupported operations will be executed on the following fallback backends (in order of priority):
+mnist_model:  - CPU (AMD Ryzen 9 5950X 16-Core Processor)
 mnist_model_init_from_file: loading model weights from 'mnist-fc-f32.gguf'
 mnist_model_init_from_file: model arch is mnist-fc
 mnist_model_init_from_file: successfully loaded weights from mnist-fc-f32.gguf
-main: loaded model in 13.03 ms
-mnist_model_eval: model evaluation on 10000 images took 95.02 ms, 9.50 us/image
+main: loaded model in 109.44 ms
+mnist_model_eval: model evaluation on 10000 images took 76.92 ms, 7.69 us/image
 main: predicted digit is 3
-main: test_loss=0.066051+-0.009343
-main: test_acc=98.07+-0.14%
+main: test_loss=0.066379+-0.009101
+main: test_acc=97.94+-0.14%
 ```
 
 In addition to the evaluation on the test set the GGML evaluation also prints a random image from the test set as well as the model prediction for said image.
@@ -87,10 +86,6 @@ $ ../../build/bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-im
 ```
 
 It can then be evaluated with the same binary as above.
-When training a model with GGML the computation graph for the forward pass is also exported to `mnist-fc-f32.ggml`.
-Compared to the GGUF (which only contains the weights) this file also contains the model architecture.
-As long as the input and output tensors are well-defined an exported GGML graph is fully agnostic w.r.t. the model architecture.
-It can be evaluated using the `mnist-eval` binary by substituting the argument for the GGUF file.
 
 ## Convolutional network
 
@@ -101,8 +96,8 @@ $ python3 mnist-train-cnn.py mnist-cnn-f32.gguf
 
 ...
 
-Test loss: 0.045483
-Test accuracy: 98.56%
+Test loss: 0.047947
+Test accuracy: 98.46%
 GGUF model saved to 'mnist-cnn-f32.gguf'
 ```
 
@@ -139,25 +134,24 @@ ________________________________________________________
 ________________________________________________________
 ________________________________________________________
 ________________________________________________________
-mnist_graph_eval: trying to load a ggml graph from mnist-cnn-f32.gguf
-ggml_graph_import: invalid magic number, got 46554747
-mnist_graph_eval: could not load a ggml graph from mnist-cnn-f32.gguf
 ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
 ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
 ggml_cuda_init: found 1 CUDA devices:
   Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
-mnist_model: using CPU backend
+mnist_model: using CUDA0 (NVIDIA GeForce RTX 3090) as primary backend
+mnist_model: unsupported operations will be executed on the following fallback backends (in order of priority):
+mnist_model:  - CPU (AMD Ryzen 9 5950X 16-Core Processor)
 mnist_model_init_from_file: loading model weights from 'mnist-cnn-f32.gguf'
 mnist_model_init_from_file: model arch is mnist-cnn
 mnist_model_init_from_file: successfully loaded weights from mnist-cnn-f32.gguf
-main: loaded model in 11.88 ms
-mnist_model_eval: model evaluation on 10000 images took 1074.09 ms, 107.41 us/image
+main: loaded model in 91.99 ms
+mnist_model_eval: model evaluation on 10000 images took 267.61 ms, 26.76 us/image
 main: predicted digit is 1
-main: test_loss=0.045483+-0.006884
-main: test_acc=98.56+-0.12%
+main: test_loss=0.047955+-0.007029
+main: test_acc=98.46+-0.12%
 ```
 
-Like with the fully connected network the convolutional network can also be trained on the CPU using GGML:
+Like with the fully connected network the convolutional network can also be trained using GGML:
 
 ``` bash
 $ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
@@ -165,11 +159,12 @@ $ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-
 
 As always, the evaluation is done using `mnist-eval` and like with the fully connected network the GGML graph is exported to `mnist-cnn-f32.ggml`.
 
-## CUDA
+## Hardware Acceleration
 
-The fully connected model can be trained and evaluated using CUDA.
-`mnist-train` and `mnist-eval` accept an additional, optional argument behind those listed so far to specify the backend.
-The default is `CPU`, by specifying `CUDA0` the first available CUDA device can be used instead (make sure to compile GGML with CUDA cupport).
+Both the training and evaluation code is agnostic in terms of hardware as long as the corresponding GGML backend has implemented the necessary operations.
+A specific backend can be selected by appending the above commands with a backend name.
+The compute graphs then schedule the operations to preferentially use the specified backend.
+Note that if a backend does not implement some of the necessary operations a CPU fallback is used instead which may result in bad performance.
 
 ## Web demo
 
index 0c4ca07d01297d654aafd7cf98c9fcf60ae724ca..c848823b033c07b71565f169fc79ebe3069b19fc 100644 (file)
@@ -1,6 +1,7 @@
+#include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
-#include "ggml.h"
+#include "ggml-opt.h"
 
 #include "mnist-common.h"
 
@@ -14,7 +15,7 @@
 #include <string>
 #include <utility>
 
-bool mnist_image_load(const std::string & fname, mnist_dataset & dataset) {
+bool mnist_image_load(const std::string & fname, ggml_opt_dataset_t dataset) {
     auto fin = std::ifstream(fname, std::ios::binary);
     if (!fin) {
         fprintf(stderr, "failed to open images file %s\n", fname.c_str());
@@ -23,12 +24,14 @@ bool mnist_image_load(const std::string & fname, mnist_dataset & dataset) {
     fin.seekg(16);
 
     uint8_t image[MNIST_NINPUT];
-    float * buf = ggml_get_data_f32(dataset.data);
+    struct ggml_tensor * images = ggml_opt_dataset_data(dataset);
+    float * buf = ggml_get_data_f32(images);
 
-    for (int iex = 0; iex < dataset.nex; ++iex) {
+    GGML_ASSERT(images->ne[0] == MNIST_NINPUT);
+    for (int64_t iex = 0; iex < images->ne[1]; ++iex) {
         fin.read((char *) image, sizeof(image));
 
-        for (int i = 0; i < MNIST_NINPUT; ++i) {
+        for (int64_t i = 0; i < MNIST_NINPUT; ++i) {
             buf[iex*MNIST_NINPUT + i] = image[i] / 255.0f; // Normalize to [0, 1]
         }
     }
@@ -36,11 +39,14 @@ bool mnist_image_load(const std::string & fname, mnist_dataset & dataset) {
     return true;
 }
 
-void mnist_image_print(FILE * stream, mnist_dataset & dataset, const int iex) {
-    const float * image = ggml_get_data_f32(dataset.data) + iex*MNIST_NINPUT;
+void mnist_image_print(FILE * stream, ggml_opt_dataset_t dataset, const int iex) {
+    struct ggml_tensor * images = ggml_opt_dataset_data(dataset);
+    GGML_ASSERT(images->ne[0] == MNIST_NINPUT);
+    GGML_ASSERT(iex < images->ne[1]);
+    const float * image = ggml_get_data_f32(images) + iex*MNIST_NINPUT;
 
-    for (int row = 0; row < MNIST_HW; row++) {
-        for (int col = 0; col < MNIST_HW; col++) {
+    for (int64_t row = 0; row < MNIST_HW; row++) {
+        for (int64_t col = 0; col < MNIST_HW; col++) {
             const int rgb = roundf(255.0f * image[row*MNIST_HW + col]);
 #ifdef _WIN32
             fprintf(stream, "%s", rgb >= 220 ? "##" : "__");                // Represented via text.
@@ -52,7 +58,7 @@ void mnist_image_print(FILE * stream, mnist_dataset & dataset, const int iex) {
     }
 }
 
-bool mnist_label_load(const std::string & fname, mnist_dataset & dataset) {
+bool mnist_label_load(const std::string & fname, ggml_opt_dataset_t dataset) {
     auto fin = std::ifstream(fname, std::ios::binary);
     if (!fin) {
         fprintf(stderr, "failed to open labels file %s\n", fname.c_str());
@@ -61,12 +67,14 @@ bool mnist_label_load(const std::string & fname, mnist_dataset & dataset) {
     fin.seekg(8);
 
     uint8_t label;
-    float * buf = ggml_get_data_f32(dataset.labels);
+    struct ggml_tensor * labels = ggml_opt_dataset_labels(dataset);
+    float * buf = ggml_get_data_f32(labels);
 
-    for (int iex = 0; iex < dataset.nex; ++iex) {
+    GGML_ASSERT(labels->ne[0] == MNIST_NCLASSES);
+    for (int64_t iex = 0; iex < labels->ne[1]; ++iex) {
         fin.read((char *) &label, sizeof(label));
 
-        for (int i = 0; i < MNIST_NCLASSES; ++i) {
+        for (int64_t i = 0; i < MNIST_NCLASSES; ++i) {
             buf[iex*MNIST_NCLASSES + i] = i == label ? 1.0f : 0.0f;
         }
     }
@@ -74,99 +82,6 @@ bool mnist_label_load(const std::string & fname, mnist_dataset & dataset) {
     return true;
 }
 
-mnist_eval_result mnist_graph_eval(const std::string & fname, const float * images, const float * labels, const int nex, const int nthreads) {
-    fprintf(stderr, "%s: trying to load a ggml graph from %s\n", __func__, fname.c_str());
-    mnist_eval_result result;
-
-    struct ggml_context * ctx_data = nullptr;
-    struct ggml_context * ctx_eval = nullptr;
-
-    struct ggml_cgraph * gf;
-    {
-        const int64_t t_start_us = ggml_time_us();
-
-        gf = ggml_graph_import(fname.c_str(), &ctx_data, &ctx_eval);
-
-        const int64_t t_total_us = ggml_time_us() - t_start_us;
-        const double t_total_ms = 1e-3*t_total_us;
-        if (gf) {
-            fprintf(stderr, "%s: graph import took %.2lf ms\n", __func__, t_total_ms);
-        }
-    }
-
-    if (!gf) {
-        fprintf(stderr, "%s: could not load a ggml graph from %s\n", __func__, fname.c_str());
-        return result;
-    }
-    fprintf(stderr, "%s: successfully loaded a ggml graph from %s\n", __func__, fname.c_str());
-
-    const size_t buf_size = 100 * 1024*1024;
-    void * buf_compute = malloc(buf_size);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf_compute,
-        /*.no_alloc   =*/ false,
-    };
-
-    struct ggml_context * ctx_compute = ggml_init(params);
-
-    struct ggml_tensor * images_batch = ggml_graph_get_tensor(gf, "images");
-    GGML_ASSERT(images_batch);
-    GGML_ASSERT(images_batch->ne[0] == MNIST_NINPUT || (images_batch->ne[0] == MNIST_HW && images_batch->ne[1] == MNIST_HW));
-
-    struct ggml_tensor * labels_batch = ggml_graph_get_tensor(gf, "labels");
-    GGML_ASSERT(labels_batch);
-    GGML_ASSERT(labels_batch->ne[0] == MNIST_NCLASSES);
-    GGML_ASSERT(labels_batch->ne[2] == 1);
-    GGML_ASSERT(labels_batch->ne[3] == 1);
-
-    const int nbatch = labels_batch->ne[1];
-    GGML_ASSERT(nex % nbatch == 0);
-
-    struct ggml_tensor * logits_batch = ggml_graph_get_tensor(gf, "logits");
-    GGML_ASSERT(logits_batch);
-    GGML_ASSERT(logits_batch->ne[0] == MNIST_NCLASSES);
-    GGML_ASSERT(logits_batch->ne[1] == nbatch);
-    GGML_ASSERT(logits_batch->ne[2] == 1);
-    GGML_ASSERT(logits_batch->ne[3] == 1);
-
-    GGML_ASSERT(images_batch->ne[1] == logits_batch->ne[1] || images_batch->ne[3] == logits_batch->ne[1]);
-
-    struct ggml_tensor * loss = ggml_graph_get_tensor(gf, "loss");
-
-    {
-        const int64_t t_start_us = ggml_time_us();
-
-        for (int iex0; iex0 < nex; iex0 += nbatch) {
-            memcpy(images_batch->data, images + iex0*MNIST_NINPUT,   ggml_nbytes(images_batch));
-            memcpy(labels_batch->data, labels + iex0*MNIST_NCLASSES, ggml_nbytes(labels_batch));
-            ggml_graph_compute_with_ctx(ctx_compute, gf, nthreads);
-
-            for (int iexb = 0; iexb < nbatch; ++iexb) {
-                const float * probs_data = ggml_get_data_f32(logits_batch) + iexb*MNIST_NCLASSES;
-
-                result.pred.push_back(std::max_element(probs_data, probs_data + MNIST_NCLASSES) - probs_data);
-            }
-
-            result.loss.push_back(*ggml_get_data_f32(loss));
-        }
-
-        const int64_t t_total_us = ggml_time_us() - t_start_us;
-        const double t_total_ms = 1e-3*t_total_us;
-        fprintf(stderr, "%s: model evaluation on %d images took %.2lf ms, %.2lf us/image\n",
-                __func__, nex, t_total_ms, (double) t_total_us/nex);
-    }
-
-    ggml_free(ctx_data);
-    ggml_free(ctx_eval);
-    ggml_free(ctx_compute);
-    free(buf_compute);
-
-    result.success = true;
-    return result;
-}
-
 // Temporary util function for loading data from GGUF to a backend != CPU until GGML itself provides this functionality:
 bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf) {
     FILE * f = ggml_fopen(fname, "rb");
@@ -213,15 +128,15 @@ bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct g
     return true;
 }
 
-mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend) {
-    mnist_model model(backend);
+mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend, const int nbatch_logical, const int nbatch_physical) {
+    mnist_model model(backend, nbatch_logical, nbatch_physical);
     fprintf(stderr, "%s: loading model weights from '%s'\n", __func__, fname.c_str());
 
     struct gguf_context * ctx;
     {
         struct gguf_init_params params = {
             /*.no_alloc   =*/ true,
-            /*.ctx        =*/ &model.ctx_weight,
+            /*.ctx        =*/ &model.ctx_gguf,
         };
         ctx = gguf_init_from_file(fname.c_str(), params);
         if (!ctx) {
@@ -233,66 +148,66 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str
     fprintf(stderr, "%s: model arch is %s\n", __func__, model.arch.c_str());
 
     if (model.arch == "mnist-fc") {
-        model.fc1_weight = ggml_get_tensor(model.ctx_weight, "fc1.weight");
+        model.fc1_weight = ggml_get_tensor(model.ctx_gguf, "fc1.weight");
         GGML_ASSERT(model.fc1_weight->ne[0] == MNIST_NINPUT);
         GGML_ASSERT(model.fc1_weight->ne[1] == MNIST_NHIDDEN);
         GGML_ASSERT(model.fc1_weight->ne[2] == 1);
         GGML_ASSERT(model.fc1_weight->ne[3] == 1);
 
-        model.fc1_bias = ggml_get_tensor(model.ctx_weight, "fc1.bias");
+        model.fc1_bias = ggml_get_tensor(model.ctx_gguf, "fc1.bias");
         GGML_ASSERT(model.fc1_bias->ne[0] == MNIST_NHIDDEN);
         GGML_ASSERT(model.fc1_bias->ne[1] == 1);
         GGML_ASSERT(model.fc1_bias->ne[2] == 1);
         GGML_ASSERT(model.fc1_bias->ne[3] == 1);
 
-        model.fc2_weight = ggml_get_tensor(model.ctx_weight, "fc2.weight");
+        model.fc2_weight = ggml_get_tensor(model.ctx_gguf, "fc2.weight");
         GGML_ASSERT(model.fc2_weight->ne[0] == MNIST_NHIDDEN);
         GGML_ASSERT(model.fc2_weight->ne[1] == MNIST_NCLASSES);
         GGML_ASSERT(model.fc2_weight->ne[2] == 1);
         GGML_ASSERT(model.fc2_weight->ne[3] == 1);
 
-        model.fc2_bias = ggml_get_tensor(model.ctx_weight, "fc2.bias");
+        model.fc2_bias = ggml_get_tensor(model.ctx_gguf, "fc2.bias");
         GGML_ASSERT(model.fc2_bias->ne[0] == MNIST_NCLASSES);
         GGML_ASSERT(model.fc2_bias->ne[1] == 1);
         GGML_ASSERT(model.fc2_bias->ne[2] == 1);
         GGML_ASSERT(model.fc2_bias->ne[3] == 1);
     } else if (model.arch == "mnist-cnn") {
-        model.conv1_kernel = ggml_get_tensor(model.ctx_weight, "conv1.kernel");
+        model.conv1_kernel = ggml_get_tensor(model.ctx_gguf, "conv1.kernel");
         GGML_ASSERT(model.conv1_kernel->type == GGML_TYPE_F32);
         GGML_ASSERT(model.conv1_kernel->ne[0] == 3);
         GGML_ASSERT(model.conv1_kernel->ne[1] == 3);
         GGML_ASSERT(model.conv1_kernel->ne[2] == 1);
         GGML_ASSERT(model.conv1_kernel->ne[3] == MNIST_CNN_NCB);
 
-        model.conv1_bias = ggml_get_tensor(model.ctx_weight, "conv1.bias");
+        model.conv1_bias = ggml_get_tensor(model.ctx_gguf, "conv1.bias");
         GGML_ASSERT(model.conv1_bias->type == GGML_TYPE_F32);
         GGML_ASSERT(model.conv1_bias->ne[0] == 1);
         GGML_ASSERT(model.conv1_bias->ne[1] == 1);
         GGML_ASSERT(model.conv1_bias->ne[2] == MNIST_CNN_NCB);
         GGML_ASSERT(model.conv1_bias->ne[3] == 1);
 
-        model.conv2_kernel = ggml_get_tensor(model.ctx_weight, "conv2.kernel");
+        model.conv2_kernel = ggml_get_tensor(model.ctx_gguf, "conv2.kernel");
         GGML_ASSERT(model.conv2_kernel->type == GGML_TYPE_F32);
         GGML_ASSERT(model.conv2_kernel->ne[0] == 3);
         GGML_ASSERT(model.conv2_kernel->ne[1] == 3);
         GGML_ASSERT(model.conv2_kernel->ne[2] == MNIST_CNN_NCB);
         GGML_ASSERT(model.conv2_kernel->ne[3] == MNIST_CNN_NCB*2);
 
-        model.conv2_bias = ggml_get_tensor(model.ctx_weight, "conv2.bias");
+        model.conv2_bias = ggml_get_tensor(model.ctx_gguf, "conv2.bias");
         GGML_ASSERT(model.conv2_bias->type == GGML_TYPE_F32);
         GGML_ASSERT(model.conv2_bias->ne[0] == 1);
         GGML_ASSERT(model.conv2_bias->ne[1] == 1);
         GGML_ASSERT(model.conv2_bias->ne[2] == MNIST_CNN_NCB*2);
         GGML_ASSERT(model.conv2_bias->ne[3] == 1);
 
-        model.dense_weight = ggml_get_tensor(model.ctx_weight, "dense.weight");
+        model.dense_weight = ggml_get_tensor(model.ctx_gguf, "dense.weight");
         GGML_ASSERT(model.dense_weight->type == GGML_TYPE_F32);
         GGML_ASSERT(model.dense_weight->ne[0] == (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2));
         GGML_ASSERT(model.dense_weight->ne[1] == MNIST_NCLASSES);
         GGML_ASSERT(model.dense_weight->ne[2] == 1);
         GGML_ASSERT(model.dense_weight->ne[3] == 1);
 
-        model.dense_bias = ggml_get_tensor(model.ctx_weight, "dense.bias");
+        model.dense_bias = ggml_get_tensor(model.ctx_gguf, "dense.bias");
         GGML_ASSERT(model.dense_bias->type == GGML_TYPE_F32);
         GGML_ASSERT(model.dense_bias->ne[0] == MNIST_NCLASSES);
         GGML_ASSERT(model.dense_bias->ne[1] == 1);
@@ -301,19 +216,29 @@ mnist_model mnist_model_init_from_file(const std::string & fname, const std::str
     } else {
         fprintf(stderr, "%s: unknown model arch: %s\n", __func__, model.arch.c_str());
     }
-    model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend);
 
-    if(!load_from_gguf(fname.c_str(), model.ctx_weight, ctx)) {
+    model.buf_gguf = ggml_backend_alloc_ctx_tensors(model.ctx_gguf, model.backends[0]);
+
+    if(!load_from_gguf(fname.c_str(), model.ctx_gguf, ctx)) {
         fprintf(stderr, "%s: loading weights from %s failed\n", __func__, fname.c_str());
         exit(1);
     }
 
+    // The space in ctx_gguf exactly fits the model weights,
+    // the images (which also need to be statically allocated) need to be put in a different context.
+
+    model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH_PHYSICAL);
+    ggml_set_name(model.images, "images");
+    ggml_set_input(model.images);
+
+    model.buf_static = ggml_backend_alloc_ctx_tensors(model.ctx_static, model.backends[0]);
+
     fprintf(stderr, "%s: successfully loaded weights from %s\n", __func__, fname.c_str());
     return model;
 }
 
-mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend) {
-    mnist_model model(backend);
+mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend, const int nbatch_logical, const int nbatch_physical) {
+    mnist_model model(backend, nbatch_logical, nbatch_physical);
     model.arch = arch;
 
     std::random_device rd{};
@@ -324,10 +249,10 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string
     if (model.arch == "mnist-fc") {
         fprintf(stderr, "%s: initializing random weights for a fully connected model\n", __func__);
 
-        model.fc1_weight = ggml_new_tensor_2d(model.ctx_weight, GGML_TYPE_F32, MNIST_NINPUT,  MNIST_NHIDDEN);
-        model.fc1_bias   = ggml_new_tensor_1d(model.ctx_weight, GGML_TYPE_F32,                MNIST_NHIDDEN);
-        model.fc2_weight = ggml_new_tensor_2d(model.ctx_weight, GGML_TYPE_F32, MNIST_NHIDDEN, MNIST_NCLASSES);
-        model.fc2_bias   = ggml_new_tensor_1d(model.ctx_weight, GGML_TYPE_F32,                MNIST_NCLASSES);
+        model.fc1_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT,  MNIST_NHIDDEN);
+        model.fc1_bias   = ggml_new_tensor_1d(model.ctx_static, GGML_TYPE_F32,                MNIST_NHIDDEN);
+        model.fc2_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NHIDDEN, MNIST_NCLASSES);
+        model.fc2_bias   = ggml_new_tensor_1d(model.ctx_static, GGML_TYPE_F32,                MNIST_NCLASSES);
 
         ggml_set_name(model.fc1_weight, "fc1.weight");
         ggml_set_name(model.fc1_bias,   "fc1.bias");
@@ -339,12 +264,12 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string
         init_tensors.push_back(model.fc2_weight);
         init_tensors.push_back(model.fc2_bias);
     } else if (model.arch == "mnist-cnn") {
-        model.conv1_kernel = ggml_new_tensor_4d(model.ctx_weight, GGML_TYPE_F32, 3, 3, 1, MNIST_CNN_NCB);
-        model.conv1_bias   = ggml_new_tensor_3d(model.ctx_weight, GGML_TYPE_F32, 1, 1,    MNIST_CNN_NCB);
-        model.conv2_kernel = ggml_new_tensor_4d(model.ctx_weight, GGML_TYPE_F32, 3, 3, MNIST_CNN_NCB, MNIST_CNN_NCB*2);
-        model.conv2_bias   = ggml_new_tensor_3d(model.ctx_weight, GGML_TYPE_F32, 1, 1,                MNIST_CNN_NCB*2);
-        model.dense_weight = ggml_new_tensor_2d(model.ctx_weight, GGML_TYPE_F32, (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), MNIST_NCLASSES);
-        model.dense_bias   = ggml_new_tensor_1d(model.ctx_weight, GGML_TYPE_F32, MNIST_NCLASSES);
+        model.conv1_kernel = ggml_new_tensor_4d(model.ctx_static, GGML_TYPE_F32, 3, 3, 1, MNIST_CNN_NCB);
+        model.conv1_bias   = ggml_new_tensor_3d(model.ctx_static, GGML_TYPE_F32, 1, 1,    MNIST_CNN_NCB);
+        model.conv2_kernel = ggml_new_tensor_4d(model.ctx_static, GGML_TYPE_F32, 3, 3, MNIST_CNN_NCB, MNIST_CNN_NCB*2);
+        model.conv2_bias   = ggml_new_tensor_3d(model.ctx_static, GGML_TYPE_F32, 1, 1,                MNIST_CNN_NCB*2);
+        model.dense_weight = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), MNIST_NCLASSES);
+        model.dense_bias   = ggml_new_tensor_1d(model.ctx_static, GGML_TYPE_F32, MNIST_NCLASSES);
 
         ggml_set_name(model.conv1_kernel, "conv1.kernel");
         ggml_set_name(model.conv1_bias,   "conv1.bias");
@@ -363,7 +288,11 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string
         fprintf(stderr, "%s: unknown model arch: %s\n", __func__, model.arch.c_str());
     }
 
-    model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend);
+    model.images = ggml_new_tensor_2d(model.ctx_static, GGML_TYPE_F32, MNIST_NINPUT, MNIST_NBATCH_PHYSICAL);
+    ggml_set_name(model.images, "images");
+    ggml_set_input(model.images);
+
+    model.buf_static = ggml_backend_alloc_ctx_tensors(model.ctx_static, model.backends[0]);
 
     for (ggml_tensor * t : init_tensors) {
         GGML_ASSERT(t->type == GGML_TYPE_F32);
@@ -379,20 +308,13 @@ mnist_model mnist_model_init_random(const std::string & arch, const std::string
     return model;
 }
 
-void mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical) {
-    model.nbatch_logical  = nbatch_logical;
-    model.nbatch_physical = nbatch_physical;
-
+void mnist_model_build(mnist_model & model) {
     if (model.arch == "mnist-fc") {
         ggml_set_param(model.ctx_compute, model.fc1_weight);
         ggml_set_param(model.ctx_compute, model.fc1_bias);
         ggml_set_param(model.ctx_compute, model.fc2_weight);
         ggml_set_param(model.ctx_compute, model.fc2_bias);
 
-        model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, model.nbatch_physical);
-        ggml_set_name(model.images, "images");
-        ggml_set_input(model.images);
-
         ggml_tensor * fc1 = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
             ggml_mul_mat(model.ctx_compute, model.fc1_weight, model.images),
             model.fc1_bias));
@@ -407,12 +329,10 @@ void mnist_model_build(mnist_model & model, const int nbatch_logical, const int
         ggml_set_param(model.ctx_compute, model.dense_weight);
         ggml_set_param(model.ctx_compute, model.dense_bias);
 
-        model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, model.nbatch_physical);
-        ggml_set_name(model.images, "images");
-        ggml_set_input(model.images);
+        struct ggml_tensor * images_2D = ggml_reshape_4d(model.ctx_compute, model.images, MNIST_HW, MNIST_HW, 1, model.images->ne[1]);
 
         struct ggml_tensor * conv1_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
-            ggml_conv_2d(model.ctx_compute, model.conv1_kernel, model.images, 1, 1, 1, 1, 1, 1),
+            ggml_conv_2d(model.ctx_compute, model.conv1_kernel, images_2D, 1, 1, 1, 1, 1, 1),
             model.conv1_bias));
         GGML_ASSERT(conv1_out->ne[0] == MNIST_HW);
         GGML_ASSERT(conv1_out->ne[1] == MNIST_HW);
@@ -459,212 +379,35 @@ void mnist_model_build(mnist_model & model, const int nbatch_logical, const int
     GGML_ASSERT(model.logits->ne[1] == model.nbatch_physical);
     GGML_ASSERT(model.logits->ne[2] == 1);
     GGML_ASSERT(model.logits->ne[3] == 1);
-
-    model.probs = ggml_soft_max(model.ctx_compute, model.logits);
-    ggml_set_name(model.probs, "probs");
-    ggml_set_output(model.probs);
-    GGML_ASSERT(model.probs->type == GGML_TYPE_F32);
-    GGML_ASSERT(model.probs->ne[0] == MNIST_NCLASSES);
-    GGML_ASSERT(model.probs->ne[1] == model.nbatch_physical);
-    GGML_ASSERT(model.probs->ne[2] == 1);
-    GGML_ASSERT(model.probs->ne[3] == 1);
-
-    model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch_physical);
-    ggml_set_name(model.labels, "labels");
-    ggml_set_input(model.labels);
-
-    model.loss = ggml_cross_entropy_loss(model.ctx_compute, model.logits, model.labels);
-    ggml_set_name(model.loss, "loss");
-    ggml_set_output(model.loss);
-    ggml_set_loss(model.loss);
-    GGML_ASSERT(model.loss->type == GGML_TYPE_F32);
-    GGML_ASSERT(model.loss->ne[0] == 1);
-    GGML_ASSERT(model.loss->ne[1] == 1);
-    GGML_ASSERT(model.loss->ne[2] == 1);
-    GGML_ASSERT(model.loss->ne[3] == 1);
-
-    model.pred = ggml_argmax(model.ctx_compute, model.logits);
-    ggml_set_name(model.pred, "predictions");
-    ggml_set_output(model.pred);
-    GGML_ASSERT(model.pred->type == GGML_TYPE_I32);
-    GGML_ASSERT(model.pred->ne[0] == model.nbatch_physical);
-    GGML_ASSERT(model.pred->ne[1] == 1);
-    GGML_ASSERT(model.pred->ne[2] == 1);
-    GGML_ASSERT(model.pred->ne[3] == 1);
-
-    model.acc_count = ggml_count_equal(model.ctx_compute, model.pred, ggml_argmax(model.ctx_compute, model.labels));
-    ggml_set_name(model.acc_count, "accuracy_count");
-    ggml_set_output(model.acc_count);
-    GGML_ASSERT(model.acc_count->type == GGML_TYPE_I64);
-    GGML_ASSERT(model.acc_count->ne[0] == 1);
-    GGML_ASSERT(model.acc_count->ne[1] == 1);
-    GGML_ASSERT(model.acc_count->ne[2] == 1);
-    GGML_ASSERT(model.acc_count->ne[3] == 1);
 }
 
-mnist_eval_result mnist_model_eval(mnist_model & model, mnist_dataset & dataset) {
-    mnist_eval_result result;
-
-    struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute);
-    // The outputs are diverging branches of the graphs, therefore multiple calls to ggml_build_forward_expand are needed.
-    ggml_build_forward_expand(gf, model.loss);
-    ggml_build_forward_expand(gf, model.pred);
-    ggml_build_forward_expand(gf, model.acc_count);
+ggml_opt_result_t mnist_model_eval(mnist_model & model, ggml_opt_dataset_t dataset) {
+    ggml_opt_result_t result = ggml_opt_result_init();
 
-    model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
+    ggml_opt_params params = ggml_opt_default_params(model.backend_sched, model.ctx_compute, model.images, model.logits, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+    params.build_type = GGML_OPT_BUILD_TYPE_FORWARD;
+    ggml_opt_context_t opt_ctx = ggml_opt_init(params);
 
     {
         const int64_t t_start_us = ggml_time_us();
 
-        float                tmp_loss;
-        std::vector<int32_t> tmp_pred(model.nbatch_physical);
-        int64_t              tmp_acc_count;
-
-        GGML_ASSERT(sizeof(tmp_loss)                    == ggml_nbytes(model.loss));
-        GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred));
-        GGML_ASSERT(sizeof(tmp_acc_count)               == ggml_nbytes(model.acc_count));
-
-        GGML_ASSERT(dataset.nex % model.nbatch_physical == 0);
-        const int nbatches = dataset.nex/model.nbatch_physical;
-        for (int ibatch = 0; ibatch < nbatches; ++ibatch) {
-            dataset.get_batch(model.images, model.labels, ibatch);
-
-            ggml_backend_graph_compute(model.backend, gf);
-
-            ggml_backend_tensor_get(model.loss,      &tmp_loss,       0, ggml_nbytes(model.loss));
-            ggml_backend_tensor_get(model.pred,      tmp_pred.data(), 0, ggml_nbytes(model.pred));
-            ggml_backend_tensor_get(model.acc_count, &tmp_acc_count,  0, ggml_nbytes(model.acc_count));
-
-            result.loss.push_back(tmp_loss);
-            result.pred.insert(result.pred.end(), tmp_pred.begin(), tmp_pred.end());
-            result.ncorrect += tmp_acc_count;
-            result.ntotal   += model.nbatch_physical;
-        }
+        ggml_opt_epoch(opt_ctx, dataset, nullptr, result, /*idata_split =*/ 0, nullptr, nullptr);
 
         const int64_t t_total_us = ggml_time_us() - t_start_us;
         const double t_total_ms = 1e-3*t_total_us;
+        const int nex = ggml_opt_dataset_data(dataset)->ne[1];
         fprintf(stderr, "%s: model evaluation on %d images took %.2lf ms, %.2lf us/image\n",
-                __func__, (int)dataset.nex, t_total_ms, (double) t_total_us/dataset.nex);
+                __func__, nex, t_total_ms, (double) t_total_us/nex);
     }
 
-    result.success = true;
+    ggml_opt_free(opt_ctx);
+
     return result;
 }
 
-void mnist_model_train(mnist_model & model, mnist_dataset & dataset, const int nepoch, const float val_split) {
-    const int64_t t_start_us = ggml_time_us();
-
-    const int opt_period        = model.nbatch_logical / model.nbatch_physical;
-    const int nbatches_logical  = dataset.nex / model.nbatch_logical;
-    const int nbatches_physical = dataset.nex / model.nbatch_physical;
-    const int ibatch_split      = ((int)((1.0f - val_split)*nbatches_logical))*opt_period; // train <-> val split index (physical)
-    const int ishard_split      = ibatch_split * model.nbatch_physical/dataset.shard_size;
-
-    // gf == graph forward, forward pass only.
-    struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
-    // The outputs are diverging branches of the graphs, therefore multiple calls to ggml_build_forward_expand are needed.
-    ggml_build_forward_expand(gf, model.loss);
-    ggml_build_forward_expand(gf, model.pred);
-    ggml_build_forward_expand(gf, model.acc_count);
-
-    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
-    struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
-    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ opt_period > 1);
-
-    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
-    struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
-    ggml_build_opt_adamw(model.ctx_compute, gf, gb_opt, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);
-
-    model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
-    ggml_graph_reset(gb_opt); // Set gradients to zero, reset optimizer.
-
-    dataset.shuffle(-1); // Shuffle all data (train + validation).
-
-    for (int epoch = 0; epoch < nepoch; ++epoch) {
-        fprintf(stderr, "%s: epoch %02d start...", __func__, epoch);
-        const int64_t t_start_us = ggml_time_us();
-
-        dataset.shuffle(ishard_split); // Shuffle only the training data, keeping training and validation set separate.
-
-        int ibatch_physical = 0;
-
-        float                tmp_loss;
-        std::vector<int32_t> tmp_pred(model.nbatch_physical);
-        int64_t              tmp_acc_count;
-
-        GGML_ASSERT(sizeof(tmp_loss)                    == ggml_nbytes(model.loss));
-        GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred));
-        GGML_ASSERT(sizeof(tmp_acc_count)               == ggml_nbytes(model.acc_count));
-
-        mnist_eval_result result_train;
-        for (; ibatch_physical < ibatch_split; ++ibatch_physical) {
-            dataset.get_batch(model.images, model.labels, ibatch_physical);
-
-            // With a period of opt_period == nbatch_logical/nbatch_physical iterations:
-            if ((ibatch_physical + 1) % opt_period != 0) {
-                // For the first opt_period - 1 iterations, only calculate gradients and accumulate them:
-                ggml_backend_graph_compute(model.backend, gb_grad);
-            } else {
-                // For the last iteration, calculate gradients and also apply the optimizer:
-                ggml_backend_graph_compute(model.backend, gb_opt); // gb_opt contains all nodes of gb_grad so no extra call for gb_grad is needed.
-                ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer.
-            }
-
-            ggml_backend_tensor_get(model.loss,      &tmp_loss,       0, ggml_nbytes(model.loss));
-            ggml_backend_tensor_get(model.pred,      tmp_pred.data(), 0, ggml_nbytes(model.pred));
-            ggml_backend_tensor_get(model.acc_count, &tmp_acc_count,  0, ggml_nbytes(model.acc_count));
-
-            result_train.loss.push_back(tmp_loss);
-            result_train.pred.insert(result_train.pred.end(), tmp_pred.begin(), tmp_pred.end());
-            result_train.ncorrect += tmp_acc_count;
-            result_train.ntotal   += model.nbatch_physical;
-        }
-
-        mnist_eval_result result_val;
-        for (; ibatch_physical < nbatches_physical; ++ibatch_physical) {
-            dataset.get_batch(model.images, model.labels, ibatch_physical);
-
-            ggml_backend_graph_compute(model.backend, gf); // For the validation set, only the forward pass is needed.
-
-            ggml_backend_tensor_get(model.loss,      &tmp_loss,       0, ggml_nbytes(model.loss));
-            ggml_backend_tensor_get(model.pred,      tmp_pred.data(), 0, ggml_nbytes(model.pred));
-            ggml_backend_tensor_get(model.acc_count, &tmp_acc_count,  0, ggml_nbytes(model.acc_count));
-
-            result_val.loss.push_back(tmp_loss);
-            result_val.pred.insert(result_val.pred.end(), tmp_pred.begin(), tmp_pred.end());
-            result_val.ncorrect += tmp_acc_count;
-            result_val.ntotal   += model.nbatch_physical;
-        }
-
-        {
-            const double loss_mean = mnist_loss(result_train).first;
-            const double percent_correct = 100.0 * mnist_accuracy(result_train).first;
-
-            const int64_t t_epoch_us = ggml_time_us() - t_start_us;
-            const double t_epoch_s = 1e-6*t_epoch_us;
-            fprintf(stderr, "done, took %.2lfs, train_loss=%.6lf, train_acc=%.2f%%", t_epoch_s, loss_mean, percent_correct);
-        }
-
-        if (ibatch_split < nbatches_physical) {
-            const std::pair<double, double> loss = mnist_loss(result_val);
-            const std::pair<double, double> acc  = mnist_accuracy(result_val);
-
-            fprintf(stderr, ", val_loss=%.6lf+-%.6lf, val_acc=%.2f+-%.2f%%", loss.first, loss.second, 100.0*acc.first, 100.0*acc.second);
-        }
-        fprintf(stderr, "\n");
-    }
-
-    const int64_t t_total_us = ggml_time_us() - t_start_us;
-    const double t_total_s = 1e-6*t_total_us;
-    fprintf(stderr, "%s: training took %.2lfs\n", __func__, t_total_s);
-
-    if (ggml_backend_is_cpu(model.backend)) {
-        std::string fname = model.arch + "-f32.ggml";
-        fprintf(stderr, "%s: saving the GGML graph for the forward pass to %s\n", __func__, fname.c_str());
-        ggml_graph_export(gf, fname.c_str());
-    } else {
-        fprintf(stderr, "%s: not saving the GGML graph for the forward pass because this is only supported for the CPU backend\n", __func__);
-    }
+void mnist_model_train(mnist_model & model, ggml_opt_dataset_t dataset, const int nepoch, const float val_split) {
+    ggml_opt_fit(model.backend_sched, model.ctx_compute, model.images, model.logits, dataset,
+        GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, ggml_opt_get_default_optimizer_params, nepoch, model.nbatch_logical, val_split, false);
 }
 
 void mnist_model_save(mnist_model & model, const std::string & fname) {
@@ -703,34 +446,6 @@ void mnist_model_save(mnist_model & model, const std::string & fname) {
     gguf_free(gguf_ctx);
 }
 
-std::pair<double, double> mnist_loss(const mnist_eval_result & result) {
-    const size_t nbatches = result.loss.size();
-    GGML_ASSERT(nbatches >= 2);
-
-    double sum         = 0.0;
-    double sum_squared = 0.0;
-
-    for (const float & loss : result.loss) {
-        sum         += loss;
-        sum_squared += loss*loss;
-    }
-
-    const double mean        = sum/nbatches;
-    const double uncertainty = sqrt((sum_squared/nbatches - mean*mean) / (nbatches - 1));
-
-    return std::make_pair(mean, uncertainty);
-}
-
-std::pair<double, double> mnist_accuracy(const mnist_eval_result & result) {
-    GGML_ASSERT(result.ntotal >= result.ncorrect);
-    GGML_ASSERT(result.ntotal >= 2);
-
-    const double fraction_correct = ((double) result.ncorrect) / ((double) result.ntotal);
-    const double uncertainty      = sqrt(fraction_correct * (1.0 - fraction_correct) / (result.ncorrect - 1));
-
-    return std::make_pair(fraction_correct, uncertainty);
-}
-
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -738,15 +453,19 @@ extern "C" {
 int wasm_eval(uint8_t * digitPtr) {
     std::vector<float> digit(digitPtr, digitPtr + MNIST_NINPUT);
 
-    struct mnist_dataset dataset(1, 1);
-    memcpy(dataset.data->data, digitPtr, ggml_nbytes(dataset.data));
-    ggml_set_zero(dataset.labels); // The labels are not needed.
+    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(MNIST_NINPUT, MNIST_NCLASSES, 1, 1);
+    struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
+    memcpy(data->data, digitPtr, ggml_nbytes(data));
+    ggml_set_zero(ggml_opt_dataset_labels(dataset)); // The labels are not needed.
+
+    mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU", /*nbatch_logical =*/ 1, /*nbatch_physical =*/ 1);
+    mnist_model_build(model);
+    ggml_opt_result_t result = mnist_model_eval(model, dataset);
 
-    mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU");
-    mnist_model_build(model, 1, 1);
-    mnist_eval_result result = mnist_model_eval(model, dataset);
+    int32_t pred;
+    ggml_opt_result_pred(result, &pred);
 
-    return result.pred[0];
+    return pred;
 }
 
 int wasm_random_digit(char * digitPtr) {
index 22d59b98926d677fa8e5884b28fe28f09f2f25ed..090cf3715944634798357096e74e0b22b05df945 100644 (file)
@@ -9,11 +9,16 @@
 #include "ggml-backend.h"
 #include "ggml.h"
 #include "ggml-cpu.h"
+#include "ggml-opt.h"
 
-#define MNIST_NTRAIN          60000
-#define MNIST_NTEST           10000
-#define MNIST_NBATCH_LOGICAL   1000
-#define MNIST_NBATCH_PHYSICAL   500
+#define MNIST_NTRAIN 60000
+#define MNIST_NTEST  10000
+
+// Gradient accumulation can be achieved by setting the logical batch size to a multiple of the physical one.
+// The logical batch size determines how many datapoints are used for a gradient update.
+// The physical batch size determines how many datapoints are processed in parallel, larger values utilize compute better but need more memory.
+#define MNIST_NBATCH_LOGICAL  1000
+#define MNIST_NBATCH_PHYSICAL  500
 
 static_assert(MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL == 0, "MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL != 0");
 static_assert(MNIST_NTRAIN % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0");
@@ -28,77 +33,15 @@ static_assert(MNIST_NTEST  % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NB
 // NCB = number of channels base
 #define MNIST_CNN_NCB 8
 
-struct mnist_dataset {
-    struct ggml_context * ctx;
-    struct ggml_tensor  * data;
-    struct ggml_tensor  * labels;
-
-    int64_t nex;
-    int64_t shard_size;
-    size_t  nbs_data;
-    size_t  nbs_labels;
-
-    std::vector<int64_t> permutation;
-    std::mt19937 rng;
-
-    mnist_dataset(const int64_t nex, const int64_t shard_size) : nex(nex), shard_size(shard_size) {
-        const size_t nbytes_images = nex*MNIST_NINPUT  *sizeof(float) + ggml_tensor_overhead();
-        const size_t nbytes_labels = nex*MNIST_NCLASSES*sizeof(float) + ggml_tensor_overhead();
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ nbytes_images + nbytes_labels,
-            /*.mem_buffer =*/ nullptr,
-            /*.no_alloc   =*/ false,
-        };
-        ctx = ggml_init(params);
-
-        data   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, MNIST_HW, MNIST_HW, nex);
-        labels = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, MNIST_NCLASSES,     nex);
-
-        nbs_data   = ggml_nbytes(data)   * shard_size/nex;
-        nbs_labels = ggml_nbytes(labels) * shard_size/nex;
-
-        permutation.resize(nex/shard_size);
-        for (size_t i = 0; i < permutation.size(); ++i) {
-            permutation[i] = i;
-        }
-    }
-
-    ~mnist_dataset() {
-        ggml_free(ctx);
-    }
-
-    void shuffle(const size_t ishard_max) {
-        if (ishard_max < permutation.size()) {
-            std::shuffle(permutation.begin(), permutation.begin() + ishard_max, rng);
-            return;
-        }
-        std::shuffle(permutation.begin(), permutation.end(), rng);
-    }
-
-    void get_batch(struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, const int64_t ibatch) {
-        const int64_t shards_per_batch = ggml_nbytes(data_batch) / nbs_data;
-        for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
-            const int64_t ishard = permutation[ibatch*shards_per_batch + ishard_batch];
-
-            ggml_backend_tensor_set(data_batch,   (const char *)   data->data + ishard*nbs_data,   ishard_batch*nbs_data,   nbs_data);
-            ggml_backend_tensor_set(labels_batch, (const char *) labels->data + ishard*nbs_labels, ishard_batch*nbs_labels, nbs_labels);
-        }
-    }
-};
-
 struct mnist_model {
     std::string arch;
-    ggml_backend_t backend;
-    int nbatch_logical;
-    int nbatch_physical;
-
-    struct ggml_tensor  * images    = nullptr;
-    struct ggml_tensor  * labels    = nullptr;
-    struct ggml_tensor  * logits    = nullptr;
-    struct ggml_tensor  * probs     = nullptr;
-    struct ggml_tensor  * loss      = nullptr;
-    struct ggml_tensor  * pred      = nullptr;
-    struct ggml_tensor  * acc_count = nullptr;
+    ggml_backend_sched_t backend_sched;
+    std::vector<ggml_backend_t> backends;
+    const int nbatch_logical;
+    const int nbatch_physical;
+
+    struct ggml_tensor * images     = nullptr;
+    struct ggml_tensor * logits     = nullptr;
 
     struct ggml_tensor * fc1_weight = nullptr;
     struct ggml_tensor * fc1_bias   = nullptr;
@@ -112,28 +55,66 @@ struct mnist_model {
     struct ggml_tensor * dense_weight = nullptr;
     struct ggml_tensor * dense_bias   = nullptr;
 
-    struct ggml_context * ctx_weight  = nullptr;
+    struct ggml_context * ctx_gguf    = nullptr;
+    struct ggml_context * ctx_static  = nullptr;
     struct ggml_context * ctx_compute = nullptr;
-    ggml_backend_buffer_t buf_weight  = nullptr;
-    ggml_backend_buffer_t buf_compute = nullptr;
-
-    mnist_model(const std::string & backend_name) {
-        ggml_backend_dev_t dev = ggml_backend_dev_by_name(backend_name.c_str());
-        if (dev == nullptr) {
-            fprintf(stderr, "%s: ERROR: backend %s not found, available:\n", __func__, backend_name.c_str());
-            for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-                ggml_backend_dev_t this_dev = ggml_backend_dev_get(i);
-                fprintf(stderr, "  - %s (%s)\n", ggml_backend_dev_name(this_dev), ggml_backend_dev_description(this_dev));
+    ggml_backend_buffer_t buf_gguf    = nullptr;
+    ggml_backend_buffer_t buf_static  = nullptr;
+
+    mnist_model(const std::string & backend_name, const int nbatch_logical, const int nbatch_physical)
+            : nbatch_logical(nbatch_logical), nbatch_physical(nbatch_physical) {
+        std::vector<ggml_backend_dev_t> devices;
+        const int ncores_logical = std::thread::hardware_concurrency();
+        const int nthreads = std::min(ncores_logical, (ncores_logical + 4) / 2);
+
+        // Add primary backend:
+        if (!backend_name.empty()) {
+            ggml_backend_dev_t dev = ggml_backend_dev_by_name(backend_name.c_str());
+            if (dev == nullptr) {
+                fprintf(stderr, "%s: ERROR: backend %s not found, available:\n", __func__, backend_name.c_str());
+                for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                    ggml_backend_dev_t dev_i = ggml_backend_dev_get(i);
+                    fprintf(stderr, "  - %s (%s)\n", ggml_backend_dev_name(dev_i), ggml_backend_dev_description(dev_i));
+                }
+                exit(1);
             }
-            exit(1);
+
+            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+            GGML_ASSERT(backend);
+
+            if (ggml_backend_is_cpu(backend)) {
+                ggml_backend_cpu_set_n_threads(backend, nthreads);
+            }
+
+            backends.push_back(backend);
+            devices.push_back(dev);
         }
 
-        fprintf(stderr, "%s: using %s (%s) backend\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
+        // Add all available backends as fallback.
+        // A "backend" is a stream on a physical device so there is no problem with adding multiple backends for the same device.
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+
+            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+            GGML_ASSERT(backend);
 
-        backend = ggml_backend_dev_init(dev, NULL);
-        if (ggml_backend_is_cpu(backend)) {
-            const int ncores_logical = std::thread::hardware_concurrency();
-            ggml_backend_cpu_set_n_threads(backend, std::min(ncores_logical, (ncores_logical + 4)/2));
+            if (ggml_backend_is_cpu(backend)) {
+                ggml_backend_cpu_set_n_threads(backend, nthreads);
+            }
+
+            backends.push_back(backend);
+            devices.push_back(dev);
+        }
+
+        // The order of the backends passed to ggml_backend_sched_new determines which backend is given priority.
+        backend_sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), GGML_DEFAULT_GRAPH_SIZE, false);
+        fprintf(stderr, "%s: using %s (%s) as primary backend\n",
+                __func__, ggml_backend_name(backends[0]), ggml_backend_dev_description(devices[0]));
+        if (backends.size() >= 2) {
+            fprintf(stderr, "%s: unsupported operations will be executed on the following fallback backends (in order of priority):\n", __func__);
+            for (size_t i = 1; i < backends.size(); ++i) {
+                fprintf(stderr, "%s:  - %s (%s)\n", __func__, ggml_backend_name(backends[i]), ggml_backend_dev_description(devices[i]));
+            }
         }
 
         {
@@ -143,7 +124,7 @@ struct mnist_model {
                 /*.mem_buffer =*/ nullptr,
                 /*.no_alloc   =*/ true,
             };
-            ctx_weight = ggml_init(params);
+            ctx_static = ggml_init(params);
         }
 
         {
@@ -159,36 +140,26 @@ struct mnist_model {
     }
 
     ~mnist_model() {
-        ggml_free(ctx_weight);
+        ggml_free(ctx_gguf);
+        ggml_free(ctx_static);
         ggml_free(ctx_compute);
 
-        ggml_backend_buffer_free(buf_weight);
-        ggml_backend_buffer_free(buf_compute);
-        ggml_backend_free(backend);
+        ggml_backend_buffer_free(buf_gguf);
+        ggml_backend_buffer_free(buf_static);
+        ggml_backend_sched_free(backend_sched);
+        for (ggml_backend_t backend : backends) {
+            ggml_backend_free(backend);
+        }
     }
 };
 
-struct mnist_eval_result {
-    bool success = false;
-
-    std::vector<float>   loss;
-    std::vector<int32_t> pred;
-    int64_t              ncorrect = 0;
-    int64_t              ntotal   = 0;
-};
+bool mnist_image_load(const std::string & fname, ggml_opt_dataset_t dataset);
+void mnist_image_print(FILE * f, ggml_opt_dataset_t dataset, const int iex);
+bool mnist_label_load(const std::string & fname, ggml_opt_dataset_t dataset);
 
-bool mnist_image_load(const std::string & fname, mnist_dataset & dataset);
-void mnist_image_print(FILE * f, mnist_dataset & dataset, const int iex);
-bool mnist_label_load(const std::string & fname, mnist_dataset & dataset);
-
-mnist_eval_result mnist_graph_eval(const std::string & fname, const float * images, const float * labels, const int nex, const int nthreads);
-
-mnist_model       mnist_model_init_from_file(const std::string & fname, const std::string & backend);
-mnist_model       mnist_model_init_random(const std::string & arch, const std::string & backend);
-void              mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical);
-mnist_eval_result mnist_model_eval(mnist_model & model, mnist_dataset & dataset);
-void              mnist_model_train(mnist_model & model, mnist_dataset & dataset, const int nepoch, const float val_split);
+mnist_model       mnist_model_init_from_file(const std::string & fname, const std::string & backend, const int nbatch_logical, const int nbatch_physical);
+mnist_model       mnist_model_init_random(const std::string & arch, const std::string & backend, const int nbatch_logical, const int nbatch_physical);
+void              mnist_model_build(mnist_model & model);
+ggml_opt_result_t mnist_model_eval(mnist_model & model, ggml_opt_dataset_t dataset);
+void              mnist_model_train(mnist_model & model, ggml_opt_dataset_t dataset, const int nepoch, const float val_split);
 void              mnist_model_save(mnist_model & model, const std::string & fname);
-
-std::pair<double, double> mnist_loss(const mnist_eval_result & result);
-std::pair<double, double> mnist_accuracy(const mnist_eval_result & result);
index 125bbb8cb154fbb4fccf08e5a24209ad2fa6cb61..218742cfedd2f446d847e44f60887bd27e22d884 100644 (file)
@@ -1,4 +1,5 @@
 #include "ggml.h"
+#include "ggml-opt.h"
 
 #include "mnist-common.h"
 
@@ -24,7 +25,7 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
-    struct mnist_dataset dataset(/*nex =*/ MNIST_NTEST, /*shard_size =*/ MNIST_NBATCH_PHYSICAL);
+    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(MNIST_NINPUT, MNIST_NCLASSES, MNIST_NTEST, MNIST_NBATCH_PHYSICAL);
 
     if (!mnist_image_load(argv[2], dataset)) {
         return 1;
@@ -36,46 +37,31 @@ int main(int argc, char ** argv) {
     const int iex = rand() % MNIST_NTEST;
     mnist_image_print(stdout, dataset, iex);
 
-    const std::string backend = argc >= 5 ? argv[4] : "CPU";
-
-    mnist_eval_result result_eval;
-
-    if (backend == "CPU") {
-        const int ncores_logical = std::thread::hardware_concurrency();
-        result_eval = mnist_graph_eval(
-            argv[1], ggml_get_data_f32(dataset.data), ggml_get_data_f32(dataset.labels), MNIST_NTEST, std::min(ncores_logical, (ncores_logical + 4)/2));
-        if (result_eval.success) {
-            fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]);
-
-            std::pair<double, double> result_loss = mnist_loss(result_eval);
-            fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second);
-
-            std::pair<double, double> result_acc = mnist_accuracy(result_eval);
-            fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second);
-
-            return 0;
-        }
-    } else {
-        fprintf(stdout, "%s: not trying to load a GGML graph from %s because this is only supported for the CPU backend\n", __func__, argv[1]);
-    }
+    const std::string backend = argc >= 5 ? argv[4] : "";
 
     const int64_t t_start_us = ggml_time_us();
+    mnist_model model = mnist_model_init_from_file(argv[1], backend, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);
+    mnist_model_build(model);
+    const int64_t t_load_us = ggml_time_us() - t_start_us;
+    fprintf(stdout, "%s: loaded model in %.2lf ms\n", __func__, t_load_us / 1000.0);
 
-    mnist_model model = mnist_model_init_from_file(argv[1], backend);
+    ggml_opt_result_t result_eval = mnist_model_eval(model, dataset);
 
-    mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);
+    std::vector<int32_t> pred(MNIST_NTEST);
+    ggml_opt_result_pred(result_eval, pred.data());
+    fprintf(stdout, "%s: predicted digit is %d\n", __func__, pred[iex]);
 
-    const int64_t t_load_us = ggml_time_us() - t_start_us;
-
-    fprintf(stdout, "%s: loaded model in %.2lf ms\n", __func__, t_load_us / 1000.0);
-    result_eval = mnist_model_eval(model, dataset);
-    fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]);
+    double loss;
+    double loss_unc;
+    ggml_opt_result_loss(result_eval, &loss, &loss_unc);
+    fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, loss, loss_unc);
 
-    std::pair<double, double> result_loss = mnist_loss(result_eval);
-    fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second);
+    double accuracy;
+    double accuracy_unc;
+    ggml_opt_result_accuracy(result_eval, &accuracy, &accuracy_unc);
+    fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*accuracy, 100.0*accuracy_unc);
 
-    std::pair<double, double> result_acc = mnist_accuracy(result_eval);
-    fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second);
+    ggml_opt_result_free(result_eval);
 
     return 0;
 }
index 161dcf80fb864d4ca0eb0a729c34de1c42eb8498..a61dd05b0aff1d58a634d82b4df7ec2e18257e4b 100644 (file)
@@ -1,3 +1,4 @@
+#include "ggml-opt.h"
 #include "mnist-common.h"
 
 #include <cmath>
@@ -19,7 +20,7 @@ int main(int argc, char ** argv) {
     // The MNIST model is so small that the overhead from data shuffling is non-negligible, especially with CUDA.
     // With a shard size of 10 this overhead is greatly reduced at the cost of less shuffling (does not seem to have a significant impact).
     // A batch of 500 images then consists of 50 random shards of size 10 instead of 500 random shards of size 1.
-    struct mnist_dataset dataset(/*nex =*/ MNIST_NTRAIN, /*shard_size =*/ 10);
+    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(MNIST_NINPUT, MNIST_NCLASSES, MNIST_NTRAIN, /*ndata_shard =*/ 10);
 
     if (!mnist_image_load(argv[3], dataset)) {
         return 1;
@@ -28,9 +29,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    mnist_model model = mnist_model_init_random(argv[1], argc >= 6 ? argv[5] : "CPU");
+    mnist_model model = mnist_model_init_random(argv[1], argc >= 6 ? argv[5] : "", MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);
 
-    mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);
+    mnist_model_build(model);
 
     mnist_model_train(model, dataset, /*nepoch =*/ 30, /*val_split =*/ 0.05f);
 
index 0a65dbfcaf5e845540b232323d14a2afa6dc81f3..cef164764bb1a6a0915a6036dd221081652510bb 100644 (file)
@@ -86,7 +86,7 @@ extern "C" {
     GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 
-    // "offset" refers to the offset of the tensor data for setting/getting data
+    // "offset" refers to the offset in tensor->data for setting/getting data
     GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_memset(   struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
@@ -242,14 +242,20 @@ extern "C" {
         ggml_backend_sched_reserve(sched, reserve_graph);
 
         // compute
-        graph = build_graph(sched);
-        ggml_backend_sched_graph_compute(sched, graph);
+        graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation
+        for (int i = 0; i < 10; ++i) {
+            ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically
+        }
 
         // if there are graph inputs:
-        ggml_backend_sched_reset(sched);
-        ggml_backend_sched_alloc_graph(sched, graph);
-        ggml_backend_tensor_set(input_tensor, ...);
-        ggml_backend_sched_graph_compute(sched, graph);
+        graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called)
+        ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
+        ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it
+        ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors
+        ggml_backend_sched_graph_compute(sched, graph); // execute the graph
+
+        // as an alternative to the above it is also possible to assign the inputs to a dedicated context and
+        // allocate them statically via ggml_backend_alloc_ctx_tensors
     }
     */
 
@@ -264,7 +270,7 @@ extern "C" {
     //
     typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
 
-    // Initialize a backend scheduler
+    // Initialize a backend scheduler, backends with low index are given priority over backends with high index
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
     GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
 
@@ -289,7 +295,9 @@ extern "C" {
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
 
-    // Reset all assignments and allocators - must be called before changing the node backends
+    // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph.
+    // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers.
+    // The correct way to use this API is to discard the deallocated tensors and create new ones.
     GGML_API void                 ggml_backend_sched_reset(ggml_backend_sched_t sched);
 
     // Set a callback to be called for each resulting node during graph compute
diff --git a/include/ggml-opt.h b/include/ggml-opt.h
new file mode 100644 (file)
index 0000000..eb5eab9
--- /dev/null
@@ -0,0 +1,216 @@
+// This file contains functionality for training models using GGML.
+// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets.
+// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code.
+//
+// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <stdint.h>
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    struct ggml_opt_dataset;
+    struct ggml_opt_context;
+    struct ggml_opt_result;
+
+    typedef struct ggml_opt_dataset * ggml_opt_dataset_t;
+    typedef struct ggml_opt_context * ggml_opt_context_t;
+    typedef struct ggml_opt_result  * ggml_opt_result_t;
+
+    // ====== Loss ======
+
+    // built-in loss types, i.e. the built-in quantities minimized by the optimizer
+    // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value
+    enum ggml_opt_loss_type {
+        GGML_OPT_LOSS_TYPE_MEAN,
+        GGML_OPT_LOSS_TYPE_SUM,
+        GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
+        GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
+    };
+
+    // ====== Dataset ======
+
+    GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
+            int64_t ne_datapoint, // number of elements per datapoint
+            int64_t ne_label,     // number of elements per label
+            int64_t ndata,        // total number of datapoints/labels
+            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+    GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
+
+    // get underlying tensors that store the data
+    GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
+    GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label,     ndata]
+
+    // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative
+    GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata);
+
+    // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch
+    GGML_API void ggml_opt_dataset_get_batch(
+            ggml_opt_dataset_t   dataset,
+            struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]
+            struct ggml_tensor * labels_batch, // shape = [ne_label,     ndata_batch]
+            int64_t              ibatch);
+
+    // ====== Model / Context ======
+
+    enum ggml_opt_build_type {
+        GGML_OPT_BUILD_TYPE_FORWARD,
+        GGML_OPT_BUILD_TYPE_GRAD,
+        GGML_OPT_BUILD_TYPE_OPT,
+    };
+
+    // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
+    struct ggml_opt_optimizer_params {
+        // AdamW optimizer parameters
+        struct {
+            float alpha; // learning rate
+            float beta1;
+            float beta2;
+            float eps;   // epsilon for numerical stability
+            float wd;    // weight decay for AdamW, use 0.0f to disable
+        } adamw;
+    };
+
+    // callback to calculate optimizer parameters prior to a backward pass
+    // userdata can be used to pass arbitrary data
+    typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
+
+    // returns the default optimizer params (constant)
+    // userdata is not used
+    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
+
+    // parameters for initializing a new optimization context
+    struct ggml_opt_params {
+        ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
+
+        struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
+
+        // the forward graph is defined by inputs and outputs
+        // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
+        struct ggml_tensor * inputs;
+        struct ggml_tensor * outputs;
+
+        enum ggml_opt_loss_type  loss_type;
+        enum ggml_opt_build_type build_type;
+
+        int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
+
+        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+    };
+
+    // get parameters for an optimization context with defaults set where possible
+    // parameters for which no sensible defaults exist are supplied as arguments to this function
+    GGML_API ggml_opt_params ggml_opt_default_params(
+            ggml_backend_sched_t      backend_sched,
+            struct ggml_context     * ctx_compute,
+            struct ggml_tensor      * inputs,
+            struct ggml_tensor      * outputs,
+            enum ggml_opt_loss_type   loss_type);
+
+    GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
+    GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
+
+    // set gradients to zero, initilize loss, and optionally reset the optimizer
+    GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
+
+    // get underlying tensors that store data
+    GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
+    GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
+    GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against
+    GGML_API struct ggml_tensor * ggml_opt_loss(    ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
+    GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs
+    GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
+
+    GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
+
+    // ====== Optimization Result ======
+
+    GGML_API ggml_opt_result_t ggml_opt_result_init();
+    GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
+    GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
+
+    // get data from result, uncertainties are optional and can be ignored by passing NULL
+    GGML_API void ggml_opt_result_ndata(   ggml_opt_result_t result, int64_t * ndata);                  // writes 1 value, number of datapoints
+    GGML_API void ggml_opt_result_loss(    ggml_opt_result_t result, double  * loss,     double * unc); // writes 1 value
+    GGML_API void ggml_opt_result_pred(    ggml_opt_result_t result, int32_t * pred);                   // writes ndata values
+    GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double  * accuracy, double * unc); // writes 1 value
+
+    // ====== Computation ======
+
+    // do forward pass, increment result if not NULL
+    GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+
+    // do forward pass, increment result if not NULL, do backward pass
+    GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+
+    // ############################################################################
+    // ## The high-level functions start here. They do not depend on any private ##
+    // ## functions or structs and can be copied to and adapted for user code.   ##
+    // ############################################################################
+
+    // ====== Intended Usage ======
+    //
+    // 1. Select the appropriate loss for your problem.
+    // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them.
+    //    Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster).
+    // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors.
+    //    The first context should contain the model parameters and inputs and be allocated statically in user code.
+    //    The second context should contain all other tensors and will be (re)allocated automatically.
+    //    Due to this automated allocation the data of the second context is not defined when accessed in user code.
+    //    Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.
+    // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.
+
+    // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
+    typedef void (*ggml_opt_epoch_callback)(
+            bool               train,       // true after training evaluation, false after validation evaluation
+            ggml_opt_context_t opt_ctx,
+            ggml_opt_dataset_t dataset,
+            ggml_opt_result_t  result,      // result associated with the dataset subsection
+            int64_t            ibatch,      // number of batches that have been evaluated so far
+            int64_t            ibatch_max,  // total number of batches in this dataset subsection
+            int64_t            t_start_us); // time at which the evaluation on the dataset subsection was started
+
+    // do training on front of dataset, do evaluation only on back of dataset
+    GGML_API void ggml_opt_epoch(
+            ggml_opt_context_t      opt_ctx,
+            ggml_opt_dataset_t      dataset,
+            ggml_opt_result_t       result_train,   // result to increment during training, ignored if NULL
+            ggml_opt_result_t       result_eval,    // result to increment during evaluation, ignored if NULL
+            int64_t                 idata_split,    // data index at which to split training and evaluation
+            ggml_opt_epoch_callback callback_train,
+            ggml_opt_epoch_callback callback_eval);
+
+    // callback that prints a progress bar on stderr
+    GGML_API void ggml_opt_epoch_callback_progress_bar(
+            bool               train,
+            ggml_opt_context_t opt_ctx,
+            ggml_opt_dataset_t dataset,
+            ggml_opt_result_t  result,
+            int64_t            ibatch,
+            int64_t            ibatch_max,
+            int64_t            t_start_us);
+
+    // fit model defined by inputs and outputs to dataset
+    GGML_API void ggml_opt_fit(
+            ggml_backend_sched_t            backend_sched,  // backend scheduler for constructing the compute graphs
+            ggml_context                  * ctx_compute,    // context with temporarily allocated tensors to calculate the outputs
+            ggml_tensor                   * inputs,         // input tensor with shape [ne_datapoint, ndata_batch]
+            ggml_tensor                   * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+            ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
+            enum ggml_opt_loss_type         loss_type,      // loss to minimize
+            ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
+            int64_t                         nepoch,         // how many times the dataset should be iterated over
+            int64_t                         nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
+            float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
+            bool                            silent);        // whether or not info prints to stderr should be suppressed
+
+#ifdef  __cplusplus
+}
+#endif
index 3b3f6798a36a7f4f0686d6c50ef8ed75ea5b3cd6..69e6a24344b9783e15cc1c61b76f43bebab9951e 100644 (file)
@@ -602,7 +602,6 @@ extern "C" {
 
         int32_t flags;
 
-        struct ggml_tensor * grad;
         struct ggml_tensor * src[GGML_MAX_SRC];
 
         // source tensor and offset for views
@@ -615,7 +614,7 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        // char padding[4];
+        char padding[8];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -1985,28 +1984,20 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * grad,
-            float                 alpha,
-            float                 beta1,
-            float                 beta2,
-            float                 eps,
-            float                 wd); // weight decay
+            struct ggml_tensor  * m,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * adamw_params); // parameters such a the learning rate
 
     //
     // automatic differentiation
     //
 
-    GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate);
-
-    GGML_API void ggml_build_opt_adamw(
-            struct ggml_context * ctx,
-            struct ggml_cgraph  * gf,
-            struct ggml_cgraph  * gb,
-            float                 alpha,
-            float                 beta1,
-            float                 beta2,
-            float                 eps,
-            float                 wd); // weight decay
+    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API void ggml_build_backward_expand(
+        struct ggml_context * ctx_static,  // context for static gradients (loss + gradient accumulation)
+        struct ggml_context * ctx_compute, // context for gradient computation
+        struct ggml_cgraph  * cgraph,
+        bool                  accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
 
     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph       (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
@@ -2026,7 +2017,9 @@ extern "C" {
     GGML_API size_t ggml_graph_overhead(void);
     GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
 
-    GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
+    GGML_API struct ggml_tensor * ggml_graph_get_tensor  (const struct ggml_cgraph * cgraph, const char * name);
+    GGML_API struct ggml_tensor * ggml_graph_get_grad    (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
+    GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
 
     GGML_API void                 ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
     GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
@@ -2037,198 +2030,15 @@ extern "C" {
     // dump the graph into a file using the dot format
     GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
 
-    // build gradient checkpointing backward graph gb for gf using provided checkpoints
-    // gb_tmp will contain original backward graph with rewritten backward process nodes,
-    // but without the second forward pass nodes.
-    GGML_API void ggml_build_backward_gradient_checkpointing(
-            struct ggml_context   * ctx,
-            struct ggml_cgraph    * gf,
-            struct ggml_cgraph    * gb,
-            struct ggml_cgraph    * gb_tmp,
-            struct ggml_tensor  * * checkpoints,
-            int                     n_checkpoints);
-    //
-    // optimization
-    //
-
-    // optimization methods
-    enum ggml_opt_type {
-        GGML_OPT_TYPE_ADAM,
-        GGML_OPT_TYPE_LBFGS,
-    };
-
-    // linesearch methods
-    enum ggml_linesearch {
-        GGML_LINESEARCH_DEFAULT = 1,
-
-        GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
-        GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
-        GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
-    };
-
-    // optimization return values
-    enum ggml_opt_result {
-        GGML_OPT_RESULT_OK = 0,
-        GGML_OPT_RESULT_DID_NOT_CONVERGE,
-        GGML_OPT_RESULT_NO_CONTEXT,
-        GGML_OPT_RESULT_INVALID_WOLFE,
-        GGML_OPT_RESULT_FAIL,
-        GGML_OPT_RESULT_CANCEL,
-
-        GGML_LINESEARCH_FAIL = -128,
-        GGML_LINESEARCH_MINIMUM_STEP,
-        GGML_LINESEARCH_MAXIMUM_STEP,
-        GGML_LINESEARCH_MAXIMUM_ITERATIONS,
-        GGML_LINESEARCH_INVALID_PARAMETERS,
-    };
-
-    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
+    // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
     typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
 
     // Set callback for all future logging events.
     // If this is not called, or NULL is supplied, everything is output on stderr.
     GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data);
 
-    // optimization parameters
-    //
-    //   see ggml.c (ggml_opt_default_params) for default values
-    //
-    struct ggml_opt_params {
-        enum ggml_opt_type type;
-
-        size_t graph_size;
-
-        int n_threads;
-
-        // delta-based convergence test
-        //
-        //   if past == 0 - disabled
-        //   if past > 0:
-        //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
-        //
-        int past;
-        float delta;
-
-        // maximum number of iterations without improvement
-        //
-        //   if 0 - disabled
-        //   if > 0:
-        //     assume convergence if no cost improvement in this number of iterations
-        //
-        int max_no_improvement;
-
-        bool print_forward_graph;
-        bool print_backward_graph;
-
-        int n_gradient_accumulation;
-
-        // ADAM parameters
-        struct {
-            int n_iter;
-
-            float sched; // schedule multiplier (fixed, decay or warmup)
-            float decay; // weight decay for AdamW, use 0.0f to disable
-            int   decay_min_ndim; // minimum number of tensor dimension to apply weight decay
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float eps_f; // epsilon for convergence test
-            float eps_g; // epsilon for convergence test
-            float gclip; // gradient clipping
-        } adam;
-
-        // LBFGS parameters
-        struct {
-            int m; // number of corrections to approximate the inv. Hessian
-            int n_iter;
-            int max_linesearch;
-
-            float eps;      // convergence tolerance
-            float ftol;     // line search tolerance
-            float wolfe;
-            float min_step;
-            float max_step;
-
-            enum ggml_linesearch linesearch;
-        } lbfgs;
-    };
-
-    struct ggml_opt_context {
-        struct ggml_context * ctx;
-        struct ggml_opt_params params;
-
-        int iter;
-        int64_t nx; // number of parameter elements
-
-        bool just_initialized;
-
-        float loss_before;
-        float loss_after;
-
-        struct {
-            struct ggml_tensor * g;  // current gradient
-            struct ggml_tensor * m;  // first moment
-            struct ggml_tensor * v;  // second moment
-            struct ggml_tensor * pf; // past function values
-            float fx_best;
-            float fx_prev;
-            int n_no_improvement;
-        } adam;
-
-        struct {
-            struct ggml_tensor * x;    // current parameters
-            struct ggml_tensor * xp;   // previous parameters
-            struct ggml_tensor * g;    // current gradient
-            struct ggml_tensor * gp;   // previous gradient
-            struct ggml_tensor * d;    // search direction
-            struct ggml_tensor * pf;   // past function values
-            struct ggml_tensor * lmal; // the L-BFGS memory alpha
-            struct ggml_tensor * lmys; // the L-BFGS memory ys
-            struct ggml_tensor * lms;  // the L-BFGS memory s
-            struct ggml_tensor * lmy;  // the L-BFGS memory y
-            float fx_best;
-            float step;
-            int j;
-            int k;
-            int end;
-            int n_no_improvement;
-        } lbfgs;
-    };
-
     GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
 
-    GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
-
-    // optimize the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt(
-            struct ggml_context * ctx,
-            struct ggml_opt_params params,
-            struct ggml_tensor * f);
-
-    // initialize optimizer context
-    GGML_API void ggml_opt_init(
-            struct ggml_context     * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_opt_params    params,
-            int64_t                   nx);
-
-    // continue optimizing the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt_resume(
-            struct ggml_context * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_tensor * f);
-
-    // continue optimizing the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt_resume_g(
-            struct ggml_context * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_tensor * f,
-            struct ggml_cgraph * gf,
-            struct ggml_cgraph * gb,
-            ggml_opt_callback callback,
-            void * callback_data);
-
     //
     // quantization
     //
index 71934c6791e32d29f11bd2124b59c2a630903691..ae7d3abc8de32e94eb8a96729dd615a8136f5f96 100644 (file)
@@ -207,9 +207,11 @@ add_library(ggml-base
             ../include/ggml-alloc.h
             ../include/ggml-backend.h
             ../include/ggml-cpp.h
+            ../include/ggml-opt.h
             ggml.c
             ggml-alloc.c
             ggml-backend.cpp
+            ggml-opt.cpp
             ggml-threading.cpp
             ggml-threading.h
             ggml-quants.c
index 041de9e3efc6927568fe7d2525091a32dbe4b5e4..2b2240be85dd34f1129db1d2d3bf2d3bcc33e708 100644 (file)
@@ -466,18 +466,12 @@ static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {
     return ggml_gallocr_hash_get(galloc, t)->allocated;
 }
 
-static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) {
-    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
-    hn->buffer_id = buffer_id;
-    hn->offset = offset;
-    hn->allocated = true;
-}
-
 static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
     return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
 }
 
 static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+    GGML_ASSERT(buffer_id >= 0);
     struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
 
     if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
@@ -816,7 +810,11 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
 }
 
 static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
-    size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
+    size_t node_size = 0;
+    if (!node->data && !node->view_src) {
+        GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API
+        node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
+    }
     return talloc->size_max >= node_size;
 }
 
index e48877ba8fa5a898a10c7e4e2322234af1a5eb8d..634fe38eef8717bffd50162f31a359d27f98ba84 100644 (file)
@@ -279,7 +279,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz
     buf->iface.get_tensor(buf, tensor, data, offset, size);
 }
 
-GGML_API void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     if (size == 0) {
index 61f53cd0145a620992fe5c1c95700064d8d231a1..df648792949817c3a1962c29a6706d36c425fe28 100644 (file)
@@ -12216,11 +12216,16 @@ static void ggml_compute_forward_opt_step_adamw_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0        = dst->src[0];
-    const struct ggml_tensor * src0_grad   = dst->src[1];
-    const struct ggml_tensor * src0_grad_m = dst->src[2];
-    const struct ggml_tensor * src0_grad_v = dst->src[3];
+    const struct ggml_tensor * src0         = dst->src[0];
+    const struct ggml_tensor * src0_grad    = dst->src[1];
+    const struct ggml_tensor * src0_grad_m  = dst->src[2];
+    const struct ggml_tensor * src0_grad_v  = dst->src[3];
+    const struct ggml_tensor * adamw_params = dst->src[4];
+
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -12237,16 +12242,14 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    /* const float   gnorm = 1.0f; */
-    int64_t       iter;   memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
-    const float   alpha = ggml_get_op_params_f32(dst, 2);
-    const float   beta1 = ggml_get_op_params_f32(dst, 3);
-    const float   beta2 = ggml_get_op_params_f32(dst, 4);
-    const float   eps   = ggml_get_op_params_f32(dst, 5);
-    const float   wd    = ggml_get_op_params_f32(dst, 6);
-
-    const float beta1h  = alpha/(1.0f - powf(beta1, iter));
-    const float beta2h  =  1.0f/(1.0f - powf(beta2, iter));
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha  = adamw_params_ptr[0];
+    const float beta1  = adamw_params_ptr[1];
+    const float beta2  = adamw_params_ptr[2];
+    const float eps    = adamw_params_ptr[3];
+    const float wd     = adamw_params_ptr[4];
+    const float beta1h = adamw_params_ptr[5];
+    const float beta2h = adamw_params_ptr[6];
 
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
@@ -12270,17 +12273,9 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
+            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
         }
     }
-
-    ggml_barrier(params->threadpool);
-    if (ith != 0) {
-        return;
-    }
-
-    iter++;
-    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
 }
 
 static void ggml_compute_forward_opt_step_adamw(
index d6f13a9c62df29fa732655d0414db6409166a659..35154f29966525f24721416df4d3dc1175b122e7 100644 (file)
@@ -1,11 +1,11 @@
+#include "ggml-impl.h"
 #include "opt-step-adamw.cuh"
 
 #include <cstdint>
 
 static __global__ void opt_step_adamw_f32(
-    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k,
-    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
-    const float beta1h, const float beta2h) {
+    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v,
+    const float * __restrict__ pars, const int64_t k) {
 
     const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
 
@@ -13,6 +13,14 @@ static __global__ void opt_step_adamw_f32(
         return;
     }
 
+    const float alpha  = pars[0];
+    const float beta1  = pars[1];
+    const float beta2  = pars[2];
+    const float eps    = pars[3];
+    const float wd     = pars[4];
+    const float beta1h = pars[5];
+    const float beta2h = pars[6];
+
     const float gi = g[i];
     const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);
     const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
@@ -23,58 +31,48 @@ static __global__ void opt_step_adamw_f32(
     const float mh =       gmi*beta1h;
     const float vh = sqrtf(gvi*beta2h) + eps;
 
-    x[i] = x[i]*(1.0f - alpha*wd) - mh/vh;
+    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
 }
 
 static void opt_step_adamw_f32_cuda(
-    float * x, const float * g, float * g_m, float * g_v, const int64_t k,
-    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
-    const float beta1h, const float beta2h, cudaStream_t stream) {
+    float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) {
 
     const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
     const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
-    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, pars, k);
 }
 
 void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0        = dst->src[0];
-    const ggml_tensor * src0_grad   = dst->src[1];
-    const ggml_tensor * src0_grad_m = dst->src[2];
-    const ggml_tensor * src0_grad_v = dst->src[3];
-
-    GGML_ASSERT(src0->type        == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad->type   == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * src0_grad_m  = dst->src[2];
+    const ggml_tensor * src0_grad_v  = dst->src[3];
+    const ggml_tensor * adamw_params = dst->src[4];
+
+    GGML_ASSERT(src0->type         == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type    == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_m->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_v->type  == GGML_TYPE_F32);
+    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src0_grad));
     GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
     GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+    GGML_ASSERT(ggml_is_contiguous(adamw_params));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
 
-    float       * src0_d        = (float       *) src0->data;
-    const float * src0_grad_d   = (const float *) src0_grad->data;
-    float       * src0_grad_m_d = (float       *) src0_grad_m->data;
-    float       * src0_grad_v_d = (float       *) src0_grad_v->data;
+    float       * src0_d         = (float       *) src0->data;
+    const float * src0_grad_d    = (const float *) src0_grad->data;
+    float       * src0_grad_m_d  = (float       *) src0_grad_m->data;
+    float       * src0_grad_v_d  = (float       *) src0_grad_v->data;
+    const float * adamw_params_d = (const float *) adamw_params->data;
 
     cudaStream_t stream = ctx.stream();
 
     const int64_t ne = ggml_nelements(src0);
 
-    int64_t iter;  memcpy(&iter,  &dst->op_params[0], sizeof(int64_t));
-    float   alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float));
-    float   beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float));
-    float   beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float));
-    float   eps;   memcpy(&eps,   &dst->op_params[5], sizeof(float));
-    float   wd;    memcpy(&wd,    &dst->op_params[6], sizeof(float));
-
-    const float beta1h  = alpha/(1.0f - powf(beta1, iter));
-    const float beta2h  =  1.0f/(1.0f - powf(beta2, iter));
-
-    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream);
-
-    iter++;
-    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream);
 }
index aa4d2b85d7d7cf6c74d8aa705039e3dafd3a5d3e..92a64fe5a59303b2fae76f70b4503e8b842bf446 100644 (file)
@@ -196,7 +196,7 @@ void ggml_hash_set_reset(struct ggml_hash_set * hash_set);
 static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key);
 
 // returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted
-static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key);
+static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key);
 
 // returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
 static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key);
@@ -210,7 +210,7 @@ static inline size_t ggml_hash(const struct ggml_tensor * p) {
     return (size_t)(uintptr_t)p >> 4;
 }
 
-static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) {
+static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) {
     size_t h = ggml_hash(key) % hash_set->size;
 
     // linear probing
@@ -281,13 +281,14 @@ enum ggml_cgraph_eval_order {
 };
 
 struct ggml_cgraph {
-    int size;
-    int n_nodes;
-    int n_leafs;
-
-    struct ggml_tensor ** nodes;
-    struct ggml_tensor ** grads;
-    struct ggml_tensor ** leafs;
+    int size;    // maximum number of nodes/leafs/grads/grad_accs
+    int n_nodes; // number of nodes currently in use
+    int n_leafs; // number of leafs currently in use
+
+    struct ggml_tensor ** nodes;     // tensors with data that can change if the graph is evaluated
+    struct ggml_tensor ** grads;     // the outputs of these tensors are the gradients of the nodes
+    struct ggml_tensor ** grad_accs; // accumulators for node gradients
+    struct ggml_tensor ** leafs;     // tensors with constant data
 
     struct ggml_hash_set visited_hash_set;
 
index b4b5cfd26891d074a903656336725394bdb1c622..95b21fbf9c50396fc3341c92cdb1cd5efc2f7ee5 100644 (file)
@@ -3639,6 +3639,12 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
     return ctx->all_data;
 }
 
+static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    UNUSED(buffer);
+}
+
 static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);
 
@@ -3671,7 +3677,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
     /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_metal_buffer_get_base,
     /* .init_tensor     = */ NULL,
-    /* .memset_tensor   = */ NULL,
+    /* .memset_tensor   = */ ggml_backend_metal_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_metal_buffer_cpy_tensor,
diff --git a/src/ggml-opt.cpp b/src/ggml-opt.cpp
new file mode 100644 (file)
index 0000000..808aa0d
--- /dev/null
@@ -0,0 +1,867 @@
+#include "ggml-opt.h"
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <inttypes.h>
+#include <map>
+#include <random>
+#include <vector>
+
+struct ggml_opt_dataset {
+    struct ggml_context   * ctx;
+    ggml_backend_buffer_t   buf;
+    struct ggml_tensor    * data;
+    struct ggml_tensor    * labels;
+
+    int64_t ndata;
+    int64_t ndata_shard;
+    size_t  nbs_data;
+    size_t  nbs_labels;
+
+    std::vector<int64_t> permutation;
+};
+
+struct ggml_opt_context {
+    ggml_backend_sched_t    backend_sched;
+    ggml_cgraph           * allocated_graph;
+    ggml_cgraph           * allocated_graph_copy;
+    struct ggml_context   * ctx_static;
+    struct ggml_context   * ctx_static_cpu;
+    struct ggml_context   * ctx_compute;
+    struct ggml_context   * ctx_copy;
+    ggml_backend_buffer_t   buf_static;
+    ggml_backend_buffer_t   buf_static_cpu;
+    std::mt19937            rng;
+
+    struct ggml_tensor * inputs;
+    struct ggml_tensor * outputs;
+    struct ggml_tensor * labels;
+
+    struct ggml_tensor * loss;
+    struct ggml_tensor * pred;
+    struct ggml_tensor * ncorrect;
+
+    struct ggml_cgraph * gf;
+    struct ggml_cgraph * gb_grad;
+    struct ggml_cgraph * gb_opt;
+
+    int64_t iter;
+    int32_t opt_period;
+    int32_t opt_i;
+    bool    loss_per_datapoint;
+
+    ggml_opt_get_optimizer_params get_opt_pars;
+    void * get_opt_pars_ud;
+    struct ggml_tensor * adamw_params;
+};
+
+struct ggml_opt_result {
+    int64_t              ndata    = 0;
+    std::vector<float>   loss;
+    std::vector<int32_t> pred;
+    int64_t              ncorrect = 0;
+
+    bool loss_per_datapoint = false;
+    int64_t opt_period = -1;
+};
+
+// ====== Dataset ======
+
+ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
+    GGML_ASSERT(ne_datapoint >  0);
+    GGML_ASSERT(ne_label     >= 0);
+    GGML_ASSERT(ndata        >  0);
+    GGML_ASSERT(ndata_shard  >  0);
+
+    ggml_opt_dataset_t result = new ggml_opt_dataset;
+    result->ndata       = ndata;
+    result->ndata_shard = ndata_shard;
+
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        result->ctx = ggml_init(params);
+    }
+
+    result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
+    result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
+
+    if (ne_label > 0) {
+        result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
+        result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
+    } else {
+        result->labels = nullptr;
+        result->nbs_labels = 0;
+    }
+
+    result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type());
+
+    const int64_t nshards = ndata/ndata_shard;
+    result->permutation.resize(nshards);
+    for (int64_t i = 0; i < nshards; ++i) {
+        result->permutation[i] = i;
+    }
+    return result;
+}
+
+void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
+    ggml_backend_buffer_free(dataset->buf);
+    ggml_free(dataset->ctx);
+    delete dataset;
+}
+
+struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
+    return dataset->data;
+}
+
+struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) {
+    return dataset->labels;
+}
+
+void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) {
+    GGML_ASSERT(idata <= dataset->ndata);
+
+    if (idata < 0) {
+        std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng);
+        return;
+    }
+
+    GGML_ASSERT(idata % dataset->ndata_shard == 0);
+    const int64_t ishard_max = idata / dataset->ndata_shard;
+    std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng);
+}
+
+void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) {
+    GGML_ASSERT(   data_batch && ggml_is_contiguous(data_batch));
+    GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
+    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+
+    const size_t nb_data_batch = ggml_nbytes(data_batch);
+    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
+    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
+
+    if (labels_batch) {
+        const size_t nb_labels_batch = ggml_nbytes(labels_batch);
+        GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);
+    }
+
+    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
+
+    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
+        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
+
+        const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;
+        ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);
+
+        if (!labels_batch) {
+            continue;
+        }
+
+        const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;
+        ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);
+    }
+}
+
+// ====== Model / Context ======
+
+struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
+    GGML_UNUSED(userdata);
+
+    ggml_opt_optimizer_params result;
+
+    result.adamw.alpha = 0.001f;
+    result.adamw.beta1 = 0.9f;
+    result.adamw.beta2 = 0.999f;
+    result.adamw.eps   = 1e-8f;
+    result.adamw.wd    = 0.0f;
+
+    return result;
+}
+
+struct ggml_opt_params ggml_opt_default_params(
+        ggml_backend_sched_t backend_sched,
+        struct ggml_context * ctx_compute,
+        struct ggml_tensor * inputs,
+        struct ggml_tensor * outputs,
+        enum ggml_opt_loss_type loss_type) {
+    return {
+        /*backend_sched   =*/ backend_sched,
+        /*ctx_compute     =*/ ctx_compute,
+        /*inputs          =*/ inputs,
+        /*logits          =*/ outputs,
+        /*loss_type       =*/ loss_type,
+        /*build_type      =*/ GGML_OPT_BUILD_TYPE_OPT,
+        /*opt_period      =*/ 1,
+        /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
+        /*get_opt_pars_ud =*/ nullptr,
+    };
+}
+
+static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_map, ggml_context * ctx, ggml_tensor * tensor) {
+    if (!tensor) {
+        return nullptr;
+    }
+
+    if (tensor_map.find(tensor) != tensor_map.end()) {
+        return tensor_map[tensor];
+    }
+
+    ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor);
+    tensor_map[tensor] = new_tensor;
+
+    new_tensor->op = tensor->op;
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        new_tensor->nb[i] = tensor->nb[i];
+    }
+    new_tensor->flags = tensor->flags;
+    memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params));
+    strcpy(new_tensor->name, tensor->name);
+    new_tensor->data = tensor->data;
+    new_tensor->buffer = tensor->buffer;
+    new_tensor->extra = tensor->extra;
+    new_tensor->view_offs = tensor->view_offs;
+    new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src);
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]);
+    }
+
+    return new_tensor;
+}
+
+static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) {
+    std::map<ggml_tensor *, ggml_tensor *> tensor_map;
+
+    ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
+
+    for (int i = 0; i < graph->n_leafs; i++) {
+        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i]));
+    }
+    for (int i = 0; i < graph->n_nodes; i++) {
+        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i]));
+    }
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]);
+        const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]);
+        graph->grads[igrad_dst]     = new_graph->grads[igrad_src];
+        graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src];
+    }
+
+    return new_graph;
+}
+
+static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
+    GGML_ASSERT(graph);
+    if (opt_ctx->allocated_graph == graph) {
+        return;
+    }
+
+    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
+
+    {
+        ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ggml_free(opt_ctx->ctx_copy);
+        opt_ctx->ctx_copy = ggml_init(params);
+    }
+
+    opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
+
+    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+    opt_ctx->allocated_graph = graph;
+}
+
+ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
+    ggml_opt_context_t result = new struct ggml_opt_context;
+    result->backend_sched        = params.backend_sched;
+    result->allocated_graph      = nullptr;
+    result->allocated_graph_copy = nullptr;
+    result->ctx_compute          = params.ctx_compute;
+    result->ctx_copy             = nullptr;
+    result->inputs               = params.inputs;
+    result->outputs              = params.outputs;
+    result->iter                 = 1;
+    result->opt_period           = params.opt_period;
+    result->opt_i                = 0;
+    result->get_opt_pars         = params.get_opt_pars;
+    result->get_opt_pars_ud      = params.get_opt_pars_ud;
+
+    GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
+    GGML_ASSERT(result->opt_period >= 1);
+
+    const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
+        (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
+
+    ggml_set_input(result->inputs);
+    ggml_set_output(result->outputs);
+
+    result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
+    ggml_build_forward_expand(result->gf, result->outputs);
+
+    int n_param = 0;
+    for (int i = 0; i < result->gf->n_nodes; ++i) {
+        if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
+            n_param++;
+        }
+    }
+
+    {
+        // The static context is used for:
+        //   - gradients (1 tensor per param if using gradient accumulation)
+        //   - optimizer momenta (2 tensors per param)
+        //   - labels
+        //   - loss + its gradient (up to 5 tensors)
+        //   - pred
+        //   - ncorrect (2 tensors).
+        const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+        const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ size_meta,
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        result->ctx_static = ggml_init(params);
+    }
+    {
+        // The static cpu context is used for:
+        //   - optimizer parameters (1 for the entire context)
+        const size_t size_meta = 1 * ggml_tensor_overhead();
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ size_meta,
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        result->ctx_static_cpu = ggml_init(params);
+    }
+
+
+    switch (params.loss_type) {
+        case GGML_OPT_LOSS_TYPE_MEAN: {
+            result->labels = nullptr;
+            result->loss = ggml_sum(result->ctx_static, result->outputs);
+            ggml_set_name(result->loss, "loss_sum");
+            const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
+            result->loss = ggml_scale(result->ctx_static, result->loss, scale);
+            ggml_set_name(result->loss, "loss_mean");
+            result->loss_per_datapoint = true;
+            break;
+        }
+        case GGML_OPT_LOSS_TYPE_SUM: {
+            result->labels = nullptr;
+            result->loss = ggml_sum(result->ctx_static, result->outputs);
+            ggml_set_name(result->loss, "loss_sum");
+            result->loss_per_datapoint = false;
+            break;
+        }
+        case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
+            result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
+            ggml_set_input(result->labels);
+            ggml_set_name(result->labels, "labels");
+            result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
+            ggml_set_name(result->loss, "loss_cross_entropy");
+            if (result->opt_period > 1) {
+                result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
+                ggml_set_name(result->loss, "loss_cross_entropy_scaled");
+            }
+            result->loss_per_datapoint = true;
+            break;
+        }
+        case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
+            result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
+            ggml_set_input(result->labels);
+            ggml_set_name(result->labels, "labels");
+            result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
+            ggml_set_name(result->loss, "loss_error");
+            result->loss = ggml_sqr(result->ctx_static, result->loss);
+            ggml_set_name(result->loss, "loss_squared_error");
+            result->loss = ggml_sum(result->ctx_static, result->loss);
+            ggml_set_name(result->loss, "loss_sum_squared_error");
+            const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
+            result->loss = ggml_scale(result->ctx_static, result->loss, scale);
+            ggml_set_name(result->loss, "loss_mean_squared_error");
+            result->loss_per_datapoint = true;
+            break;
+        }
+    }
+    ggml_set_output(result->loss);
+    ggml_set_loss(result->loss);
+    ggml_build_forward_expand(result->gf, result->loss);
+
+    result->pred = ggml_argmax(result->ctx_static, result->outputs);
+    ggml_set_name(result->pred, "pred");
+    ggml_set_output(result->pred);
+    ggml_build_forward_expand(result->gf, result->pred);
+
+    if (result->labels) {
+        result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
+        ggml_set_name(result->ncorrect, "ncorrect");
+        ggml_set_output(result->ncorrect);
+        ggml_build_forward_expand(result->gf, result->ncorrect);
+    } else {
+        result->ncorrect = nullptr;
+    }
+
+    if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
+        result->gb_grad = nullptr;
+        result->gb_opt  = nullptr;
+
+        result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
+        result->buf_static_cpu = nullptr;
+
+        ggml_opt_alloc_graph(result, result->gf);
+
+        return result;
+    }
+
+    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
+    result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
+    ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
+
+    if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
+        result->gb_opt  = nullptr;
+
+        result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
+        result->buf_static_cpu = nullptr;
+
+        ggml_opt_alloc_graph(result, result->gb_grad);
+        ggml_graph_reset(result->gb_grad);
+
+        return result;
+    }
+
+    GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
+
+    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
+    result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
+
+    result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
+    ggml_set_input(result->adamw_params);
+    ggml_set_name(result->adamw_params, "adamw_params");
+
+    for (int i = result->gf->n_nodes-1; i >= 0; --i) {
+        struct ggml_tensor * node = result->gb_opt->nodes[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
+
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+            struct ggml_tensor * m        = ggml_dup_tensor(result->ctx_static, node);
+            struct ggml_tensor * v        = ggml_dup_tensor(result->ctx_static, node);
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
+            ggml_build_forward_expand(result->gb_opt, opt_step);
+        }
+    }
+
+    result->buf_static = ggml_backend_alloc_ctx_tensors(
+        result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
+
+    result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
+
+    ggml_opt_alloc_graph(result, result->gb_opt);
+    ggml_graph_reset(result->gb_opt);
+
+    return result;
+}
+
+void ggml_opt_free(ggml_opt_context_t opt_ctx) {
+    if (opt_ctx == nullptr) {
+        return;
+    }
+    ggml_backend_buffer_free(opt_ctx->buf_static);
+    ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
+    ggml_free(opt_ctx->ctx_static);
+    ggml_free(opt_ctx->ctx_static_cpu);
+    delete opt_ctx;
+}
+
+void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
+    if (optimizer) {
+        ggml_graph_reset(opt_ctx->gb_opt);
+        opt_ctx->iter = 1;
+    } else {
+        ggml_graph_reset(opt_ctx->gb_grad);
+    }
+}
+
+struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
+    return opt_ctx->inputs;
+}
+
+struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) {
+    return opt_ctx->outputs;
+}
+
+struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) {
+    return opt_ctx->labels;
+}
+
+struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) {
+    return opt_ctx->loss;
+}
+
+struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) {
+    return opt_ctx->pred;
+}
+
+struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) {
+    return opt_ctx->ncorrect;
+}
+
+struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) {
+    return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node);
+}
+
+// ====== Optimization Result ======
+
+ggml_opt_result_t ggml_opt_result_init() {
+    return new ggml_opt_result;
+}
+
+void ggml_opt_result_free(ggml_opt_result_t result) {
+    delete result;
+}
+
+void ggml_opt_result_reset(ggml_opt_result_t result) {
+    result->ndata = 0;
+    result->loss.clear();
+    result->pred.clear();
+    result->ncorrect = 0;
+}
+
+void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) {
+    *ndata = result->ndata;
+}
+
+void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) {
+    const int64_t nbatches = result->loss.size(); // Number of physical batches.
+
+    if (nbatches == 0) {
+        *loss = 0.0;
+        *unc  = NAN;
+        return;
+    }
+
+    double sum         = 0.0;
+    double sum_squared = 0.0;
+
+    for (const float & loss : result->loss) {
+        // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.
+        const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss;
+        sum         += loss_scaled;
+        sum_squared += loss_scaled*loss_scaled;
+    }
+
+    const double mean = sum/nbatches;
+    *loss = result->loss_per_datapoint ? mean : sum;
+
+    if (!unc) {
+        return;
+    }
+
+    if (nbatches < 2) {
+        *unc = NAN;
+        return;
+    }
+
+    const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1)
+    *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));
+}
+
+void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) {
+    for (size_t i = 0; i < result->pred.size(); ++i) {
+        pred[i] = result->pred[i];
+    }
+}
+
+void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) {
+    *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN;
+
+    if (!unc) {
+        return;
+    }
+
+    *unc = result->ncorrect >= 0 && result->ndata >= 2 ?
+        sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN;
+}
+
+// ====== Computation ======
+
+static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
+    if (graph != opt_ctx->gf) {
+        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+        GGML_ASSERT(opt_pars.adamw.alpha >  0.0f);
+        GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+        GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+        GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+        GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+        GGML_ASSERT(opt_pars.adamw.eps   >= 0.0f);
+        GGML_ASSERT(opt_pars.adamw.wd    >= 0.0f);
+        GGML_ASSERT(opt_pars.adamw.wd    <= 1.0f);
+
+        // beta1, beta2 after applying warmup
+        const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+        const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+        float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
+        adamw_par_data[0] = opt_pars.adamw.alpha;
+        adamw_par_data[1] = opt_pars.adamw.beta1;
+        adamw_par_data[2] = opt_pars.adamw.beta2;
+        adamw_par_data[3] = opt_pars.adamw.eps;
+        adamw_par_data[4] = opt_pars.adamw.wd;
+        adamw_par_data[5] = beta1h;
+        adamw_par_data[6] = beta2h;
+    }
+
+    ggml_opt_alloc_graph(opt_ctx, graph);
+    ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+    opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
+
+    if (!result) {
+        return;
+    }
+
+    if (result->ndata == 0) {
+        result->loss_per_datapoint = opt_ctx->loss_per_datapoint;
+        result->opt_period         = opt_ctx->opt_period;
+    } else {
+        GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);
+        GGML_ASSERT(result->opt_period         == opt_ctx->opt_period);
+    }
+
+    const int64_t ndata = opt_ctx->outputs->ne[1];
+    GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported");
+    result->ndata += ndata;
+
+    GGML_ASSERT(ggml_is_scalar(opt_ctx->loss));
+    GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32);
+    float loss;
+    ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
+    result->loss.push_back(loss);
+
+    GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
+    std::vector<int32_t> pred(ndata);
+    ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
+    result->pred.insert(result->pred.end(), pred.begin(), pred.end());
+
+    if (!opt_ctx->labels || result->ncorrect < 0) {
+        result->ncorrect = -1;
+        return;
+    }
+
+    GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect));
+    GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64);
+    int64_t ncorrect;
+    ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect));
+    result->ncorrect += ncorrect;
+}
+
+void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
+    ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
+}
+
+void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
+    if (opt_ctx->opt_period == 1) {
+        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
+        return;
+    }
+
+    const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+    if (opt_i_next == 0) {
+        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
+        ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
+    } else {
+        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
+    }
+    opt_ctx->opt_i = opt_i_next;
+}
+
+// ====== High-Level Functions ======
+
+void ggml_opt_epoch(
+        ggml_opt_context_t      opt_ctx,
+        ggml_opt_dataset_t      dataset,
+        ggml_opt_result_t       result_train,
+        ggml_opt_result_t       result_eval,
+        int64_t                 idata_split,
+        ggml_opt_epoch_callback callback_train,
+        ggml_opt_epoch_callback callback_eval) {
+    struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
+    struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
+    struct ggml_tensor * data   = ggml_opt_dataset_data(dataset);
+    GGML_ASSERT(data->ne[0] == inputs->ne[0]);
+
+    const int64_t ndata       =   data->ne[1];
+    const int64_t ndata_batch = inputs->ne[1];
+
+    GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);
+    const int64_t nbatches = ndata/ndata_batch;
+
+    idata_split = idata_split < 0 ? ndata : idata_split;
+    GGML_ASSERT(idata_split % ndata_batch == 0);
+    const int64_t ibatch_split = idata_split / ndata_batch;
+
+    int64_t ibatch = 0;
+    int64_t t_loop_start = ggml_time_us();
+    for (; ibatch < ibatch_split; ++ibatch) {
+        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
+        ggml_opt_forward_backward(opt_ctx, result_train);
+        if (callback_train) {
+            callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
+        }
+    }
+    t_loop_start = ggml_time_us();
+    for (; ibatch < nbatches; ++ibatch) {
+        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
+        ggml_opt_forward(opt_ctx, result_eval);
+        if (callback_eval) {
+            callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
+        }
+    }
+}
+
+void ggml_opt_epoch_callback_progress_bar(
+        bool               train,
+        ggml_opt_context_t opt_ctx,
+        ggml_opt_dataset_t dataset,
+        ggml_opt_result_t  result,
+        int64_t            ibatch,
+        int64_t            ibatch_max,
+        int64_t            t_start_us) {
+    fprintf(stderr, "%s[", train ? "train: " : "val:   ");
+
+    constexpr int64_t bar_length = 25;
+    for (int64_t j = 0; j < bar_length; ++j) {
+        const int64_t ibatch_j = ibatch_max * j/bar_length;
+        if (ibatch_j < ibatch) {
+            fprintf(stderr, "=");
+        } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
+            fprintf(stderr, ">");
+        } else {
+            fprintf(stderr, " ");
+        }
+    }
+
+    const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1];
+    const int64_t idata      = ibatch*batch_size;
+    const int64_t idata_max  = ibatch_max*batch_size;
+
+    double loss;
+    double loss_unc;
+    ggml_opt_result_loss(result, &loss, &loss_unc);
+
+    double accuracy;
+    double accuracy_unc;
+    ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);
+
+    const int64_t t_ibatch_us = ggml_time_us() - t_start_us;
+    int64_t t_ibatch_s = t_ibatch_us / 1000000;
+    const int64_t t_ibatch_h = t_ibatch_s / 3600;
+    t_ibatch_s -= t_ibatch_h * 3600;
+    const int64_t t_ibatch_m = t_ibatch_s / 60;
+    t_ibatch_s -= t_ibatch_m * 60;
+
+    const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;
+    int64_t t_eta_s = t_eta_us / 1000000;
+    const int64_t t_eta_h = t_eta_s / 3600;
+    t_eta_s -= t_eta_h * 3600;
+    const int64_t t_eta_m = t_eta_s / 60;
+    t_eta_s -= t_eta_m * 60;
+
+    fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
+            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
+            idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
+            t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
+    if (ibatch == ibatch_max) {
+        fprintf(stderr, "\n");
+    }
+    fflush(stderr);
+
+    GGML_UNUSED(dataset);
+}
+
+void ggml_opt_fit(
+        ggml_backend_sched_t            backend_sched,
+        ggml_context                  * ctx_compute,
+        ggml_tensor                   * inputs,
+        ggml_tensor                   * outputs,
+        ggml_opt_dataset_t              dataset,
+        enum ggml_opt_loss_type         loss_type,
+        ggml_opt_get_optimizer_params   get_opt_pars,
+        int64_t                         nepoch,
+        int64_t                         nbatch_logical,
+        float                           val_split,
+        bool                            silent) {
+    ggml_time_init();
+    const int64_t t_start_us = ggml_time_us();
+
+    const int64_t ndata           = ggml_opt_dataset_data(dataset)->ne[1];
+    const int64_t nbatch_physical = inputs->ne[1];
+    GGML_ASSERT(ndata          % nbatch_logical  == 0);
+    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
+
+    const int64_t opt_period       = nbatch_logical / nbatch_physical;
+    const int64_t nbatches_logical = ndata / nbatch_logical;
+
+    GGML_ASSERT(val_split >= 0.0f);
+    GGML_ASSERT(val_split <  1.0f);
+    const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical)
+    const int64_t idata_split  = ibatch_split * nbatch_physical;
+
+    int64_t epoch = 1;
+
+    ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
+    params.opt_period      = opt_period;
+    params.get_opt_pars    = get_opt_pars;
+    params.get_opt_pars_ud = &epoch;
+    ggml_opt_context_t opt_ctx = ggml_opt_init(params);
+
+    // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
+    if (nbatch_logical < ndata) {
+        ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).
+    }
+
+    ggml_opt_result_t result_train = ggml_opt_result_init();
+    ggml_opt_result_t result_val   = ggml_opt_result_init();
+
+    ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar;
+
+    for (; epoch <= nepoch; ++epoch) {
+        if (nbatch_logical < idata_split) {
+            ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
+        }
+
+        ggml_opt_result_reset(result_train);
+        ggml_opt_result_reset(result_val);
+
+        if (!silent) {
+            fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch);
+        }
+        ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);
+        if (!silent) {
+            fprintf(stderr, "\n");
+        }
+    }
+
+    if (!silent) {
+        int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000;
+        const int64_t t_total_h = t_total_s / 3600;
+        t_total_s -= t_total_h * 3600;
+        const int64_t t_total_m = t_total_s / 60;
+        t_total_s -= t_total_m * 60;
+        fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
+    }
+
+    ggml_opt_free(opt_ctx);
+    ggml_opt_result_free(result_train);
+    ggml_opt_result_free(result_val);
+}
index 5cdf59f2560a15dfb91f88a1b8dfe48fc14ba02e..4a478fcaa2064153e3d3b605bd55f8cac536fe3b 100644 (file)
@@ -1592,14 +1592,13 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.op           =*/ GGML_OP_NONE,
         /*.op_params    =*/ { 0 },
         /*.flags        =*/ 0,
-        /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
         /*.view_src     =*/ view_src,
         /*.view_offs    =*/ view_offs,
         /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
-        ///*.padding      =*/ { 0 },
+        /*.padding      =*/ { 0 },
     };
 
 #ifdef __clang__
@@ -4194,8 +4193,6 @@ struct ggml_tensor * ggml_flash_attn_ext(
         GGML_ASSERT(mask);
     }
 
-    bool is_node = false;
-
     // permute(0, 2, 1, 3)
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
@@ -4203,8 +4200,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));
 
-    result->op   = GGML_OP_FLASH_ATTN_EXT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_FLASH_ATTN_EXT;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
@@ -4272,14 +4268,6 @@ struct ggml_tensor * ggml_flash_attn_back(
 
     GGML_ASSERT(ne2 % kvne2 == 0);
 
-    bool is_node = false;
-
-    if (q->grad || k->grad || v->grad) {
-        // when using this operation (in backwards pass) these grads are set.
-        // we don't want to create (big) grad of our result, so is_node is false.
-        is_node = false;
-    }
-
     // store gradients of q, k and v as continuous tensors concatenated in result.
     // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
     const int64_t elem_q = ggml_nelements(q);
@@ -4302,8 +4290,7 @@ struct ggml_tensor * ggml_flash_attn_back(
     int32_t masked_i = masked ? 1 : 0;
     ggml_set_op_params(result, &masked_i, sizeof(masked_i));
 
-    result->op   = GGML_OP_FLASH_ATTN_BACK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->op     = GGML_OP_FLASH_ATTN_BACK;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
@@ -4945,34 +4932,24 @@ struct ggml_tensor * ggml_opt_step_adamw(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * grad,
-        float                 alpha,
-        float                 beta1,
-        float                 beta2,
-        float                 eps,
-        float                 wd) {
+        struct ggml_tensor  * m,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * adamw_params) {
     GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
     GGML_ASSERT(ggml_are_same_shape(a, grad));
-    GGML_ASSERT(alpha >  0.0f);
-    GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
-    GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
-    GGML_ASSERT(eps   >= 0.0f);
-    GGML_ASSERT(wd    >= 0.0f && wd    <= 1.0f);
+    GGML_ASSERT(ggml_are_same_shape(a, m));
+    GGML_ASSERT(ggml_are_same_shape(a, v));
+    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
 
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
-    const int64_t iter = 1;
-    memcpy(&result->op_params[0], &iter, sizeof(int64_t));
-    ggml_set_op_params_f32(result, 2, alpha);
-    ggml_set_op_params_f32(result, 3, beta1);
-    ggml_set_op_params_f32(result, 4, beta2);
-    ggml_set_op_params_f32(result, 5, eps);
-    ggml_set_op_params_f32(result, 6, wd);
-
     result->op     = GGML_OP_OPT_STEP_ADAMW;
     result->src[0] = a;
     result->src[1] = grad;
-    result->src[2] = ggml_dup_tensor(ctx, grad);
-    result->src[3] = ggml_dup_tensor(ctx, grad);
+    result->src[2] = m;
+    result->src[3] = v;
+    result->src[4] = adamw_params;
 
     return result;
 }
@@ -5041,1112 +5018,514 @@ static void ggml_hash_map_free(struct hash_map * map) {
     GGML_FREE(map);
 }
 
-// gradient checkpointing
-
-static struct ggml_tensor * ggml_recompute_graph_node(
-        struct ggml_context * ctx,
-        struct ggml_cgraph  * graph,
-        struct hash_map     * replacements,
-        struct ggml_tensor  * node) {
-
-    if (node == NULL) {
-        return NULL;
-    }
-
-    if (node->flags & GGML_TENSOR_FLAG_PARAM) {
-        return node;
-    }
-
-    if (!ggml_hash_contains(&graph->visited_hash_set, node)) {
-        return node;
-    }
-
-    int count_children = 0;
-    for (int k = 0; k < GGML_MAX_SRC; ++k) {
-        if (node->src[k]) {
-            ++count_children;
-        }
-    }
-
-    if (count_children == 0) {
-        return node;
-    }
-
-    size_t i = ggml_hash_find(&replacements->set, node);
-    GGML_ASSERT(i != GGML_HASHSET_FULL); // assert that not full
-    if (replacements->set.keys[i] == node) {
-        return replacements->vals[i];
-    }
-
-    struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, GGML_MAX_DIMS, node->ne);
-
-    // insert clone into replacements
-    GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
-    replacements->set.keys[i] = node;
-    replacements->vals[i] = clone;
-
-    clone->op       = node->op;
-    clone->grad     = node->grad;
-    clone->flags    = node->flags;
-    clone->extra    = node->extra;
-    for (int k = 0; k < GGML_MAX_DIMS; ++k) {
-        clone->nb[k] = node->nb[k];
-    }
-    for (int k = 0; k < GGML_MAX_SRC; ++k) {
-        clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
-    }
-    if (node->view_src != NULL) {
-        clone->data = (node->view_src->data == NULL)
-                        ? NULL // view_src not yet allocated
-                        : (char *) node->view_src->data // view_src already allocated
-                                 + node->view_offs;
-        clone->view_src  = node->view_src;
-        clone->view_offs = node->view_offs;
-    }
-
-    GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
-    GGML_ASSERT(sizeof(node->name)      == GGML_MAX_NAME);
-    memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
-    ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
-
-    return clone;
-}
-
-void ggml_build_backward_gradient_checkpointing(
-        struct ggml_context   * ctx,
-        struct ggml_cgraph    * gf,
-        struct ggml_cgraph    * gb,
-        struct ggml_cgraph    * gb_tmp,
-        struct ggml_tensor  * * checkpoints,
-        int                     n_checkpoints) {
-    ggml_graph_cpy(gf, gb_tmp);
-    ggml_build_backward_expand(ctx, gf, gb_tmp, false);
-
-    if (n_checkpoints <= 0) {
-        ggml_graph_cpy(gb_tmp, gb);
-        return;
-    }
-
-    struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints);
-
-    // insert checkpoints in replacements
-    for (int i = 0; i < n_checkpoints; ++i) {
-        size_t k = ggml_hash_find(&replacements->set, checkpoints[i]);
-        GGML_ASSERT(k != GGML_HASHSET_FULL); // assert that not full
-        GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite
-        replacements->set.keys[k] = checkpoints[i];
-        replacements->vals[k]     = checkpoints[i];
-    }
-
-    ggml_graph_cpy(gf, gb);
-    // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
-    // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
-    // by recomputing them from checkpoints
-    for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
-        struct ggml_tensor * node = gb_tmp->nodes[i];
-        for (int k = 0; k < GGML_MAX_SRC; ++k) {
-            // insert new tensors recomputing src, reusing already made replacements,
-            // remember replacements: remember new tensors with mapping from corresponding gf nodes
-            // recurse for input tensors,
-            // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
-            node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
-        }
-        // insert rewritten backward node with replacements made into resulting backward graph gb
-        ggml_build_forward_expand(gb, node);
-    }
-
-    ggml_hash_map_free(replacements);
-}
-
 // utility functions to change gradients
 // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
 // else if a is in zero_table, replace a
 // else, just add/subtract/etc. the gradients
 
-static struct ggml_tensor * ggml_add_or_set(
-        struct ggml_context  * ctx,
-        struct ggml_tensor   * a,
-        struct ggml_tensor   * b,
-        struct ggml_hash_set * zero_table,
-        struct ggml_hash_set * acc_table) {
-    if (ggml_hash_contains(acc_table, a)) {
-        struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
-        const size_t insert_result = ggml_hash_insert(acc_table, ret);
-        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
-        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
-        return ret;
-    }
-    if (ggml_hash_contains(zero_table, a)) {
-        return b;
+static void ggml_add_or_set(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * cgraph,
+        size_t                isrc,
+        struct ggml_tensor  * tensor) {
+    if (cgraph->grads[isrc]) {
+        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+    } else {
+        cgraph->grads[isrc] = tensor;
     }
-    return ggml_add_impl(ctx, a, b, false);
+    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
-static struct ggml_tensor * ggml_acc_or_set(
-        struct ggml_context  * ctx,
-        struct ggml_tensor   * a,
-        struct ggml_tensor   * b,
-        const  size_t          nb1,
-        const  size_t          nb2,
-        const  size_t          nb3,
-        const  size_t          offset,
-        struct ggml_hash_set * zero_table,
-        struct ggml_hash_set * acc_table) {
-    if (ggml_hash_contains(acc_table, a)) {
-        struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
-        const size_t insert_result = ggml_hash_insert(acc_table, ret);
-        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
-        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
-        return ret;
-    }
-    if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
-        return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
+static void ggml_acc_or_set(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * cgraph,
+        size_t                isrc,
+        struct ggml_tensor  * src,
+        struct ggml_tensor  * tensor,
+        const  size_t         nb1,
+        const  size_t         nb2,
+        const  size_t         nb3,
+        const  size_t         offset) {
+    if (cgraph->grads[isrc]) {
+        cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
+    } else {
+        struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
+        cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
     }
-    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
-static struct ggml_tensor * ggml_add1_or_set(
-        struct ggml_context  * ctx,
-        struct ggml_tensor   * a,
-        struct ggml_tensor   * b,
-        struct ggml_hash_set * zero_table,
-        struct ggml_hash_set * acc_table) {
-    if (ggml_hash_contains(acc_table, a)) {
-        struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
-        const size_t insert_result = ggml_hash_insert(acc_table, ret);
-        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
-        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
-        return ret;
-    }
-    if (ggml_hash_contains(zero_table, a)) {
-        return ggml_repeat(ctx, b, a);
+static void ggml_add1_or_set(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * cgraph,
+        size_t                isrc,
+        struct ggml_tensor  * src,
+        struct ggml_tensor  * tensor) {
+    if (cgraph->grads[isrc]) {
+        cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+    } else {
+        cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
     }
-    return ggml_add1_impl(ctx, a, b, false);
+    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
-static struct ggml_tensor * ggml_sub_or_set(
-        struct ggml_context  * ctx,
-        struct ggml_tensor   * a,
-        struct ggml_tensor   * b,
-        struct ggml_hash_set * zero_table,
-        struct ggml_hash_set * acc_table) {
-    if (ggml_hash_contains(acc_table, a)) {
-        struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
-        const size_t insert_result = ggml_hash_insert(acc_table, ret);
-        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
-        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
-        return ret;
-    }
-    if (ggml_hash_contains(zero_table, a)) {
-        return ggml_neg(ctx, b);
+static void ggml_sub_or_set(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * cgraph,
+        size_t                isrc,
+        struct ggml_tensor  * tensor) {
+    if (cgraph->grads[isrc]) {
+        cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+    } else {
+        cgraph->grads[isrc] = ggml_neg(ctx, tensor);
     }
-    return ggml_sub_impl(ctx, a, b, false);
+    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
+static void ggml_compute_backward(
+        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
+    struct ggml_tensor * tensor = cgraph->nodes[i];
+    struct ggml_tensor * grad   = ggml_graph_get_grad(cgraph, tensor);
+
+    if (!grad) {
+        return;
+    }
+
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
+    struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
+    const size_t isrc0 = ggml_hash_find(hash_set, src0);
+    const size_t isrc1 = ggml_hash_find(hash_set, src1);
+    const size_t isrc2 = ggml_hash_find(hash_set, src2);
+    const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
+    const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
+    const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
 
     switch (tensor->op) {
-        case GGML_OP_DUP:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_ADD:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    if (ggml_are_same_shape(src0, src1)) {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad,                       tensor->grad,        zero_table, acc_table);
-                    } else {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
-                    }
-                }
-            } break;
-        case GGML_OP_ADD1:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    src1->grad = ggml_add_or_set(ctx,
-                        src1->grad,
-                        ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
-                        zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_ACC:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    const size_t nb1     = ((int32_t *) tensor->op_params)[0];
-                    const size_t nb2     = ((int32_t *) tensor->op_params)[1];
-                    const size_t nb3     = ((int32_t *) tensor->op_params)[2];
-                    const size_t offset  = ((int32_t *) tensor->op_params)[3];
-
-                    struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
-                        tensor->grad,
-                        src1->grad->ne[0],
-                        src1->grad->ne[1],
-                        src1->grad->ne[2],
-                        src1->grad->ne[3],
-                        nb1, nb2, nb3, offset);
-
-                    src1->grad =
-                        ggml_add_or_set(ctx,
-                            src1->grad,
-                            ggml_reshape(ctx,
-                                ggml_cont(ctx, tensor_grad_view),
-                                src1->grad),
-                            zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_SUB:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_MUL:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_mul(ctx, src1, tensor->grad),
-                                zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    src1->grad =
-                        ggml_add_or_set(ctx,
-                                src1->grad,
-                                ggml_mul(ctx, src0, tensor->grad),
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_DIV:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_div(ctx, tensor->grad, src1),
-                                zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    src1->grad =
-                        ggml_sub_or_set(ctx,
-                                src1->grad,
-                                ggml_mul(ctx,
-                                    tensor->grad,
-                                    ggml_div(ctx, tensor, src1)),
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_SQR:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_scale(ctx,
-                                    ggml_mul(ctx, src0, tensor->grad),
-                                    2.0f),
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_SQRT:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_scale(ctx,
-                                    ggml_div(ctx,
-                                        tensor->grad,
-                                        tensor),
-                                    0.5f),
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_LOG:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_div(ctx,
-                                    tensor->grad,
-                                    src0),
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_SIN:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_mul(ctx,
-                                    tensor->grad,
-                                    ggml_cos(ctx, src0)),
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_COS:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_sub_or_set(ctx,
-                                src0->grad,
-                                ggml_mul(ctx,
-                                    tensor->grad,
-                                    ggml_sin(ctx, src0)),
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_SUM:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add1_or_set(ctx,
-                                src0->grad,
-                                tensor->grad,
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_SUM_ROWS:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_repeat(ctx,
-                                    tensor->grad,
-                                    src0->grad),
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_MEAN:
-        case GGML_OP_ARGMAX:
-        case GGML_OP_COUNT_EQUAL:
-            {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        case GGML_OP_REPEAT:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx,
-                            src0->grad,
-                            ggml_repeat_back(ctx, tensor->grad, src0->grad),
-                            zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_REPEAT_BACK:
-            {
-                if (src0->grad) {
-                    // TODO: test this
-                    src0->grad = ggml_add_or_set(ctx,
-                            src0->grad,
-                            ggml_repeat(ctx, tensor->grad, src0->grad),
-                            zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_CONCAT:
-            {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        case GGML_OP_SILU_BACK:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        case GGML_OP_DUP: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
             }
-        case GGML_OP_NORM:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_ADD: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
             }
-        case GGML_OP_RMS_NORM:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    float eps;
-                    memcpy(&eps, tensor->op_params, sizeof(float));
-
-                    src0->grad = ggml_add_or_set(ctx,
-                            src0->grad,
-                            ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
-                            zero_table, acc_table);
+            if (src1_needs_grads) {
+                struct ggml_tensor * tmp = grad;
+                if (!ggml_are_same_shape(src0, src1)) {
+                    tmp = ggml_repeat_back(ctx, tmp, src1);
                 }
-            } break;
-        case GGML_OP_RMS_NORM_BACK:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+                ggml_add_or_set(ctx, cgraph, isrc1, tmp);
             }
-        case GGML_OP_GROUP_NORM:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_ADD1: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
             }
-        case GGML_OP_MUL_MAT:
-            {
-                // https://cs231n.github.io/optimization-2/#staged
-                // # forward pass
-                // s0 = np.random.randn(5, 10)
-                // s1 = np.random.randn(10, 3)
-                // t = s0.dot(s1)
-
-                // # now suppose we had the gradient on t from above in the circuit
-                // dt = np.random.randn(*t.shape) # same shape as t
-                // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
-                // ds1 = t.T.dot(dt)
-
-                // tensor.shape [m,p,qq,rr]
-                // src0.shape   [n,m,q1,r1]
-                // src1.shape   [n,p,qq,rr]
-
-                // necessary for llama
-                if (src0->grad) {
-                    struct ggml_tensor * s1_tg =
-                        ggml_out_prod(ctx, // [n,m,qq,rr]
-                            src1,          // [n,p,qq,rr]
-                            tensor->grad); // [m,p,qq,rr]
-                    const int64_t qq = s1_tg->ne[2];
-                    const int64_t rr = s1_tg->ne[3];
-                    const int64_t q1 = src0->ne[2];
-                    const int64_t r1 = src0->ne[3];
-                    const bool ne2_broadcasted = qq > q1;
-                    const bool ne3_broadcasted = rr > r1;
-                    if (ne2_broadcasted || ne3_broadcasted) {
-                        // sum broadcast repetitions of s1_tg into shape of src0
-                        s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
-                    }
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                                src0->grad, // [n,m,q1,r1]
-                                s1_tg,      // [n,m,q1,r1]
-                                zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    src1->grad =
-                        ggml_add_or_set(ctx,
-                                src1->grad,                            // [n,p,qq,rr]
-                                // ggml_mul_mat(ctx,                   // [n,p,qq,rr]
-                                //     ggml_cont(ctx,                  // [m,n,q1,r1]
-                                //         ggml_transpose(ctx, src0)), // [m,n,q1,r1]
-                                //     tensor->grad),                  // [m,p,qq,rr]
-
-                                // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
-                                // // avoid transpose of src0, rather transpose smaller tensor->grad
-                                // // and then use ggml_out_prod
-                                ggml_out_prod(ctx,                  // [n,p,qq,rr]
-                                    src0,                           // [n,m,q1,r1]
-                                    ggml_transpose(ctx,             // [p,m,qq,rr]
-                                        tensor->grad)),             // [m,p,qq,rr]
-                                zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_MUL_MAT_ID:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+            if (src1_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
             }
-        case GGML_OP_OUT_PROD:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_ACC: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
             }
-        case GGML_OP_SCALE:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    float s;
-                    memcpy(&s, tensor->op_params, sizeof(float));
-
-                    src0->grad =
-                        ggml_add_or_set(ctx,
-                            src0->grad,
-                            ggml_scale_impl(ctx, tensor->grad, s, false),
-                            zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_SET:
-            {
-                const size_t nb1     = ((int32_t *) tensor->op_params)[0];
-                const size_t nb2     = ((int32_t *) tensor->op_params)[1];
-                const size_t nb3     = ((int32_t *) tensor->op_params)[2];
-                const size_t offset  = ((int32_t *) tensor->op_params)[3];
-
-                struct ggml_tensor * tensor_grad_view = NULL;
-
-                if (src0->grad || src1->grad) {
-                    GGML_ASSERT(src0->type == tensor->type);
-                    GGML_ASSERT(tensor->grad->type == tensor->type);
-                    GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
-
-                    tensor_grad_view = ggml_view_4d(ctx,
-                        tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
-                        nb1, nb2, nb3, offset);
-                }
+            if (src1_needs_grads) {
+                const size_t nb1    = ((int32_t *) tensor->op_params)[0];
+                const size_t nb2    = ((int32_t *) tensor->op_params)[1];
+                const size_t nb3    = ((int32_t *) tensor->op_params)[2];
+                const size_t offset = ((int32_t *) tensor->op_params)[3];
 
-                if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx,
-                        src0->grad,
-                        ggml_acc_impl(ctx,
-                            tensor->grad,
-                            ggml_neg(ctx, tensor_grad_view),
-                            nb1, nb2, nb3, offset, false),
-                        zero_table, acc_table);
-                }
+                struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
+                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                    nb1, nb2, nb3, offset);
 
-                if (src1->grad) {
-                    src1->grad =
-                        ggml_add_or_set(ctx,
-                            src1->grad,
-                            ggml_reshape(ctx,
-                                ggml_cont(ctx, tensor_grad_view),
-                                src1->grad),
-                            zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_CPY:
-            {
-                // necessary for llama
-                // cpy overwrites value of src1 by src0 and returns view(src1)
-                // the overwriting is mathematically equivalent to:
-                // tensor = src0 * 1 + src1 * 0
-                if (src0->grad) {
-                    // dsrc0 = dtensor * 1
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    // dsrc1 = dtensor * 0 -> noop
-                }
-            } break;
-        case GGML_OP_CONT:
-            {
-                // same as cpy
-                if (src0->grad) {
-                    GGML_ASSERT(ggml_is_contiguous(src0->grad));
-                    GGML_ASSERT(ggml_is_contiguous(tensor->grad));
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_RESHAPE:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx, src0->grad,
-                            ggml_reshape(ctx,
-                                ggml_is_contiguous(tensor->grad)
-                                    ? tensor->grad
-                                    : ggml_cont(ctx, tensor->grad),
-                                src0->grad),
-                        zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_VIEW:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    size_t offset;
-
-                    memcpy(&offset, tensor->op_params, sizeof(offset));
-
-                    size_t nb1 = tensor->nb[1];
-                    size_t nb2 = tensor->nb[2];
-                    size_t nb3 = tensor->nb[3];
-
-                    if (src0->type != src0->grad->type) {
-                        // gradient is typically F32, but src0 could be other type
-                        size_t ng = ggml_element_size(src0->grad);
-                        size_t n0 = ggml_element_size(src0);
-                        GGML_ASSERT(offset % n0 == 0);
-                        GGML_ASSERT(nb1 % n0 == 0);
-                        GGML_ASSERT(nb2 % n0 == 0);
-                        GGML_ASSERT(nb3 % n0 == 0);
-                        offset = (offset / n0) * ng;
-                        nb1 = (nb1 / n0) * ng;
-                        nb2 = (nb2 / n0) * ng;
-                        nb3 = (nb3 / n0) * ng;
-                    }
-
-                    src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_PERMUTE:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    int32_t * axes = (int32_t *) tensor->op_params;
-                    int axis0 = axes[0] & 0x3;
-                    int axis1 = axes[1] & 0x3;
-                    int axis2 = axes[2] & 0x3;
-                    int axis3 = axes[3] & 0x3;
-                    int axes_backward[4] = {0,0,0,0};
-                    axes_backward[axis0] = 0;
-                    axes_backward[axis1] = 1;
-                    axes_backward[axis2] = 2;
-                    axes_backward[axis3] = 3;
-                    src0->grad =
-                        ggml_add_or_set(ctx, src0->grad,
-                            ggml_permute(ctx,
-                                tensor->grad,
-                                axes_backward[0],
-                                axes_backward[1],
-                                axes_backward[2],
-                                axes_backward[3]),
-                            zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_TRANSPOSE:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx, src0->grad,
-                            ggml_transpose(ctx, tensor->grad),
-                        zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_GET_ROWS:
-            {
-                // necessary for llama (only for tokenizer)
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx, src0->grad,
-                            // last ggml_get_rows_back argument src0->grad is only
-                            // necessary to setup correct output shape
-                            ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
-                        zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    // noop
-                }
-            } break;
-        case GGML_OP_GET_ROWS_BACK:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
             }
-        case GGML_OP_DIAG:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_SUB: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
             }
-        case GGML_OP_DIAG_MASK_INF:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    const int n_past = ((int32_t *) tensor->op_params)[0];
-                    src0->grad =
-                        ggml_add_or_set(ctx, src0->grad,
-                            /* ggml_diag_mask_inf_impl() shouldn't be here */
-                            /* ref:  https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
-                            ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                        zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_DIAG_MASK_ZERO:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    const int n_past = ((int32_t *) tensor->op_params)[0];
-                    src0->grad =
-                        ggml_add_or_set(ctx, src0->grad,
-                            ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                        zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_SOFT_MAX:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_or_set(ctx, src0->grad,
-                            ggml_soft_max_back(ctx, tensor->grad, tensor),
-                        zero_table, acc_table);
-                }
-                GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented");
-            } break;
-        case GGML_OP_SOFT_MAX_BACK:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+            if (src1_needs_grads) {
+                ggml_sub_or_set(ctx, cgraph, isrc1, grad);
             }
-        case GGML_OP_ROPE:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    //const int n_past = ((int32_t *) tensor->op_params)[0];
-                    const int n_dims     = ((int32_t *) tensor->op_params)[1];
-                    const int mode       = ((int32_t *) tensor->op_params)[2];
-                    //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
-                    const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
-                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-
-                    memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
-                    memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
-                    memcpy(&ext_factor,  (int32_t *) tensor->op_params +  7, sizeof(float));
-                    memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
-                    memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
-                    memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
-
-                    src0->grad = ggml_add_or_set(ctx,
-                            src0->grad,
-                            ggml_rope_back(ctx,
-                                tensor->grad,
-                                src1,
-                                src2,
-                                n_dims,
-                                mode,
-                                n_ctx_orig,
-                                freq_base,
-                                freq_scale,
-                                ext_factor,
-                                attn_factor,
-                                beta_fast,
-                                beta_slow),
-                            zero_table, acc_table);
-                }
-                GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented");
-            } break;
-        case GGML_OP_ROPE_BACK:
-            {
-                if (src0->grad) {
-                    //const int n_past = ((int32_t *) tensor->op_params)[0];
-                    const int n_dims     = ((int32_t *) tensor->op_params)[1];
-                    const int mode       = ((int32_t *) tensor->op_params)[2];
-                    //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
-                    const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
-                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-
-                    memcpy(&freq_base,   (int32_t *) tensor->op_params +  5, sizeof(float));
-                    memcpy(&freq_scale,  (int32_t *) tensor->op_params +  6, sizeof(float));
-                    memcpy(&ext_factor,  (int32_t *) tensor->op_params +  7, sizeof(float));
-                    memcpy(&attn_factor, (int32_t *) tensor->op_params +  8, sizeof(float));
-                    memcpy(&beta_fast,   (int32_t *) tensor->op_params +  9, sizeof(float));
-                    memcpy(&beta_slow,   (int32_t *) tensor->op_params + 10, sizeof(float));
-
-                    src0->grad = ggml_add_or_set(ctx,
-                            src0->grad,
-                            ggml_rope_impl(ctx,
-                                tensor->grad,
-                                src1,
-                                src2,
-                                n_dims,
-                                mode,
-                                n_ctx_orig,
-                                freq_base,
-                                freq_scale,
-                                ext_factor,
-                                attn_factor,
-                                beta_fast,
-                                beta_slow,
-                                false),
-                            zero_table, acc_table);
+        } break;
+        case GGML_OP_MUL: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+            }
+            if (src1_needs_grads) {
+                struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
+                if (!ggml_are_same_shape(src0, src1)) {
+                    tmp = ggml_repeat_back(ctx, tmp, src1);
                 }
-            } break;
-        case GGML_OP_CLAMP:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+                ggml_add_or_set(ctx, cgraph, isrc1, tmp);
             }
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_DIV: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1));
             }
-        case GGML_OP_IM2COL:
-            {
-                if (src1->grad) {
-                    const int32_t s0    = ggml_get_op_params_i32(tensor, 0);
-                    const int32_t s1    = ggml_get_op_params_i32(tensor, 1);
-                    const int32_t p0    = ggml_get_op_params_i32(tensor, 2);
-                    const int32_t p1    = ggml_get_op_params_i32(tensor, 3);
-                    const int32_t d0    = ggml_get_op_params_i32(tensor, 4);
-                    const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
-                    const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
-
-                    src1->grad = ggml_add_or_set(ctx,
-                            src1->grad,
-                            ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
-                            zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_IM2COL_BACK:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+            if (src1_needs_grads) {
+                ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1)));
             }
-        case GGML_OP_CONV_TRANSPOSE_2D:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_SQR: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f));
             }
-        case GGML_OP_POOL_1D:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_SQRT: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f));
             }
-        case GGML_OP_POOL_2D:
-            {
-                if (src0->grad) {
-                    const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
-                    const      int32_t      k0 = ggml_get_op_params_i32(tensor, 1);
-                    const      int32_t      k1 = ggml_get_op_params_i32(tensor, 2);
-                    const      int32_t      s0 = ggml_get_op_params_i32(tensor, 3);
-                    const      int32_t      s1 = ggml_get_op_params_i32(tensor, 4);
-                    const      int32_t      p0 = ggml_get_op_params_i32(tensor, 5);
-                    const      int32_t      p1 = ggml_get_op_params_i32(tensor, 6);
-
-                    src0->grad = ggml_add_or_set(ctx,
-                            src0->grad,
-                            ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
-                            zero_table, acc_table);
-                }
-            } break;
-        case GGML_OP_POOL_2D_BACK:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_LOG: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0));
             }
-        case GGML_OP_UPSCALE:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_SIN: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0)));
             }
-        case GGML_OP_PAD:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_COS: {
+            if (src0_needs_grads) {
+                ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0)));
             }
-        case GGML_OP_ARANGE:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_SUM: {
+            if (src0_needs_grads) {
+                ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
             }
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_SUM_ROWS: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
             }
-        case GGML_OP_ARGSORT:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_MEAN: {
+            if (src0_needs_grads) {
+                ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
             }
-        case GGML_OP_LEAKY_RELU:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_REPEAT: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
             }
-        case GGML_OP_FLASH_ATTN_EXT:
-            {
-                GGML_ABORT("FA backward pass not adapted after rework");
-                struct ggml_tensor * flash_grad = NULL;
-                if (src0->grad || src1->grad || tensor->src[2]->grad) {
-                    int32_t t = ggml_get_op_params_i32(tensor, 0);
-                    GGML_ASSERT(t == 0 || t == 1);
-                    bool masked = t != 0;
-                    flash_grad =
-                        ggml_flash_attn_back(ctx,
-                            src0,
-                            src1,
-                            tensor->src[2],
-                            tensor->grad,
-                            masked);
+        } break;
+        case GGML_OP_REPEAT_BACK: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
+            }
+        } break;
+        case GGML_OP_RMS_NORM: {
+            if (src0_needs_grads) {
+                float eps;
+                memcpy(&eps, tensor->op_params, sizeof(float));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
+            }
+        } break;
+        case GGML_OP_MUL_MAT: {
+            // https://cs231n.github.io/optimization-2/#staged
+            // # forward pass
+            // s0 = np.random.randn(5, 10)
+            // s1 = np.random.randn(10, 3)
+            // t = s0.dot(s1)
+
+            // # now suppose we had the gradient on t from above in the circuit
+            // dt = np.random.randn(*t.shape) # same shape as t
+            // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
+            // ds1 = t.T.dot(dt)
+
+            // tensor.shape [m,p,qq,rr]
+            // src0.shape   [n,m,q1,r1]
+            // src1.shape   [n,p,qq,rr]
+
+            if (src0_needs_grads) {
+                struct ggml_tensor * s1_tg =
+                    ggml_out_prod(ctx, // [n,m,qq,rr]
+                        src1,          // [n,p,qq,rr]
+                        grad);         // [m,p,qq,rr]
+                const int64_t qq = s1_tg->ne[2];
+                const int64_t rr = s1_tg->ne[3];
+                const int64_t q1 = src0->ne[2];
+                const int64_t r1 = src0->ne[3];
+                const bool ne2_broadcasted = qq > q1;
+                const bool ne3_broadcasted = rr > r1;
+                if (ne2_broadcasted || ne3_broadcasted) {
+                    // sum broadcast repetitions of s1_tg into shape of src0
+                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
                 }
+                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+            }
+            if (src1_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc1,
+                        // ggml_mul_mat(ctx,                   // [n,p,qq,rr]
+                        //     ggml_cont(ctx,                  // [m,n,q1,r1]
+                        //         ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+                        //     grad),                          // [m,p,qq,rr]
+
+                        // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                        // avoid transpose of src0, rather transpose smaller tensor->grad
+                        // and then use ggml_out_prod
+                        ggml_out_prod(ctx,      // [n,p,qq,rr]
+                            src0,               // [n,m,q1,r1]
+                            ggml_transpose(ctx, // [p,m,qq,rr]
+                                grad)));        // [m,p,qq,rr]
+            }
+        } break;
+        case GGML_OP_SCALE: {
+            if (src0_needs_grads) {
+                float s;
+                memcpy(&s, tensor->op_params, sizeof(float));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
+            }
+        } break;
+        case GGML_OP_SET: {
+            const size_t nb1    = ((const int32_t *) tensor->op_params)[0];
+            const size_t nb2    = ((const int32_t *) tensor->op_params)[1];
+            const size_t nb3    = ((const int32_t *) tensor->op_params)[2];
+            const size_t offset = ((const int32_t *) tensor->op_params)[3];
+
+            struct ggml_tensor * tensor_grad_view = NULL;
+
+            if (src0_needs_grads || src1_needs_grads) {
+                GGML_ASSERT(src0->type == tensor->type);
+                GGML_ASSERT(!cgraph->grads[isrc0] ||                      cgraph->grads[isrc0]->type == grad->type);
+                GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);
+
+                tensor_grad_view = ggml_view_4d(ctx,
+                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                    nb1, nb2, nb3, offset);
+            }
 
-                const int64_t elem_q = ggml_nelements(src0);
-                const int64_t elem_k = ggml_nelements(src1);
-                const int64_t elem_v = ggml_nelements(src2);
-
-                enum ggml_type result_type = flash_grad->type;
-                GGML_ASSERT(ggml_blck_size(result_type) == 1);
-                const size_t tsize = ggml_type_size(result_type);
-
-                const size_t offs_q = 0;
-                const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
-                const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
-
-                if (src0->grad) {
-                    struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
-                    struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
-                    src0->grad = ggml_add_or_set(ctx,
-                            src0->grad,
-                            grad_q,
-                            zero_table, acc_table);
-                }
-                if (src1->grad) {
-                    struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
-                    struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
-                    src1->grad = ggml_add_or_set(ctx,
-                            src1->grad,
-                            grad_k,
-                            zero_table, acc_table);
-                }
-                if (src2->grad) {
-                    struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
-                    struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
-                    src2->grad = ggml_add_or_set(ctx,
-                            src2->grad,
-                            grad_v,
-                            zero_table, acc_table);
+            if (src0_needs_grads) {
+                struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));
+            }
+
+            if (src1_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
+            }
+        } break;
+        case GGML_OP_CPY: {
+            // cpy overwrites value of src1 by src0 and returns view(src1)
+            // the overwriting is mathematically equivalent to:
+            // tensor = src0 * 1 + src1 * 0
+            if (src0_needs_grads) {
+                // dsrc0 = dtensor * 1
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+            if (src1_needs_grads) {
+                // dsrc1 = dtensor * 0 -> noop
+            }
+        } break;
+        case GGML_OP_CONT: {
+            // same as cpy
+            if (src0_needs_grads) {
+                GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
+                GGML_ASSERT(ggml_is_contiguous(grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+        } break;
+        case GGML_OP_RESHAPE: {
+            if (src0_needs_grads) {
+                struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0));
+            }
+        } break;
+        case GGML_OP_VIEW: {
+            if (src0_needs_grads) {
+                size_t offset;
+
+                memcpy(&offset, tensor->op_params, sizeof(offset));
+
+                size_t nb1 = tensor->nb[1];
+                size_t nb2 = tensor->nb[2];
+                size_t nb3 = tensor->nb[3];
+
+                if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {
+                    // gradient is typically F32, but src0 could be other type
+                    size_t ng = ggml_element_size(cgraph->grads[isrc0]);
+                    size_t n0 = ggml_element_size(src0);
+                    GGML_ASSERT(offset % n0 == 0);
+                    GGML_ASSERT(nb1 % n0 == 0);
+                    GGML_ASSERT(nb2 % n0 == 0);
+                    GGML_ASSERT(nb3 % n0 == 0);
+                    offset = (offset / n0) * ng;
+                    nb1 = (nb1 / n0) * ng;
+                    nb2 = (nb2 / n0) * ng;
+                    nb3 = (nb3 / n0) * ng;
                 }
-            } break;
-        case GGML_OP_FLASH_ATTN_BACK:
-            {
-                GGML_ABORT("fatal error"); // not supported
+
+                ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
             }
-        case GGML_OP_SSM_CONV:
-        case GGML_OP_SSM_SCAN:
-            {
-                GGML_ABORT("fatal error"); // TODO: not implemented
+        } break;
+        case GGML_OP_PERMUTE: {
+            if (src0_needs_grads) {
+                const int32_t * axes = (const int32_t *) tensor->op_params;
+                const int axis0 = axes[0] & 0x3;
+                const int axis1 = axes[1] & 0x3;
+                const int axis2 = axes[2] & 0x3;
+                const int axis3 = axes[3] & 0x3;
+                int axb[4] = {0,0,0,0}; // axes backward
+                axb[axis0] = 0;
+                axb[axis1] = 1;
+                axb[axis2] = 2;
+                axb[axis3] = 3;
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));
             }
+        } break;
+        case GGML_OP_TRANSPOSE: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad));
+            }
+        } break;
+        case GGML_OP_GET_ROWS: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
+            }
+            if (src1_needs_grads) {
+                // noop
+            }
+        } break;
+        case GGML_OP_DIAG_MASK_INF: {
+            if (src0_needs_grads) {
+                /* ggml_diag_mask_inf_impl() shouldn't be here */
+                /* ref:  https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
+                const int n_past = ((const int32_t *) tensor->op_params)[0];
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
+            }
+        } break;
+        case GGML_OP_DIAG_MASK_ZERO: {
+            if (src0_needs_grads) {
+                const int n_past = ((const int32_t *) tensor->op_params)[0];
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
+            }
+        } break;
+        case GGML_OP_SOFT_MAX: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
+            }
+            GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
+        } break;
+        case GGML_OP_ROPE: {
+            if (src0_needs_grads) {
+                //const int n_past = ((int32_t *) tensor->op_params)[0];
+                const int n_dims     = ((const int32_t *) tensor->op_params)[1];
+                const int mode       = ((const int32_t *) tensor->op_params)[2];
+                //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
+                float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+                memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));
+                memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));
+                memcpy(&ext_factor,  (const float *) tensor->op_params +  7, sizeof(float));
+                memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));
+                memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));
+                memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));
+
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
+                        freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
+            }
+            GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
+        } break;
+        case GGML_OP_IM2COL: {
+            if (src1_needs_grads) {
+                const int32_t s0    = ggml_get_op_params_i32(tensor, 0);
+                const int32_t s1    = ggml_get_op_params_i32(tensor, 1);
+                const int32_t p0    = ggml_get_op_params_i32(tensor, 2);
+                const int32_t p1    = ggml_get_op_params_i32(tensor, 3);
+                const int32_t d0    = ggml_get_op_params_i32(tensor, 4);
+                const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
+                const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
+
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
+            }
+        } break;
+        case GGML_OP_POOL_2D: {
+            if (src0_needs_grads) {
+                const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
+                const      int32_t      k0 = ggml_get_op_params_i32(tensor, 1);
+                const      int32_t      k1 = ggml_get_op_params_i32(tensor, 2);
+                const      int32_t      s0 = ggml_get_op_params_i32(tensor, 3);
+                const      int32_t      s1 = ggml_get_op_params_i32(tensor, 4);
+                const      int32_t      p0 = ggml_get_op_params_i32(tensor, 5);
+                const      int32_t      p1 = ggml_get_op_params_i32(tensor, 6);
+
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));
+            }
+        } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
-        case GGML_OP_UNARY:
-            {
-                switch (ggml_get_unary_op(tensor)) {
-                    case GGML_UNARY_OP_ABS:
-                        {
-                            if (src0->grad) {
-                                src0->grad =
-                                    ggml_add_or_set(ctx,
-                                            src0->grad,
-                                            ggml_mul(ctx,
-                                                ggml_sgn(ctx, src0),
-                                                tensor->grad),
-                                            zero_table, acc_table);
-                            }
-                        } break;
-                    case GGML_UNARY_OP_SGN:
-                        {
-                            if (src0->grad) {
-                                // noop
-                            }
-                        } break;
-                    case GGML_UNARY_OP_NEG:
-                        {
-                            if (src0->grad) {
-                                src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
-                            }
-                        } break;
-                    case GGML_UNARY_OP_STEP:
-                        {
-                            if (src0->grad) {
-                                // noop
-                            }
-                        } break;
-                    case GGML_UNARY_OP_TANH:
-                        {
-                            GGML_ABORT("fatal error"); // TODO: not implemented
-                        }
-                    case GGML_UNARY_OP_ELU:
-                        {
-                            GGML_ABORT("fatal error"); // TODO: not implemented
-                        }
-                    case GGML_UNARY_OP_RELU:
-                        {
-                            if (src0->grad) {
-                                src0->grad = ggml_add_or_set(ctx,
-                                        src0->grad,
-                                        ggml_mul(ctx,
-                                            ggml_step(ctx, src0),
-                                            tensor->grad),
-                                        zero_table, acc_table);
-                            }
-                        } break;
-                    case GGML_UNARY_OP_SIGMOID:
-                        {
-                            GGML_ABORT("fatal error"); // TODO: not implemented
-                        }
-                    case GGML_UNARY_OP_GELU:
-                        {
-                            GGML_ABORT("fatal error"); // TODO: not implemented
-                        }
-                    case GGML_UNARY_OP_GELU_QUICK:
-                        {
-                            GGML_ABORT("fatal error"); // TODO: not implemented
-                        }
-                    case GGML_UNARY_OP_SILU:
-                        {
-                            // necessary for llama
-                            if (src0->grad) {
-                                src0->grad = ggml_add_or_set(ctx,
-                                        src0->grad,
-                                        ggml_silu_back(ctx, src0, tensor->grad),
-                                        zero_table, acc_table);
-                            }
-                        } break;
-                    case GGML_UNARY_OP_EXP:
-                        {
-                            if (src0->grad) {
-                                src0->grad = ggml_add_or_set(ctx,
-                                        src0->grad,
-                                        ggml_mul(ctx, tensor, tensor->grad),
-                                        zero_table, acc_table);
-                            }
-                        } break;
-                    default:
-                        GGML_ABORT("fatal error");
-                }
-            } break;
-        case GGML_OP_GET_REL_POS:
-        case GGML_OP_ADD_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_MAP_UNARY:
-        case GGML_OP_MAP_BINARY:
-        case GGML_OP_MAP_CUSTOM1_F32:
-        case GGML_OP_MAP_CUSTOM2_F32:
-        case GGML_OP_MAP_CUSTOM3_F32:
-        case GGML_OP_MAP_CUSTOM1:
-        case GGML_OP_MAP_CUSTOM2:
-        case GGML_OP_MAP_CUSTOM3:
-            {
-                GGML_ABORT("fatal error"); // not supported
-            }
-        case GGML_OP_CROSS_ENTROPY_LOSS:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx,
-                                src0->grad,
-                                ggml_cross_entropy_loss_back(ctx,
-                                    src0,
-                                    src1,
-                                    tensor->grad),
-                                zero_table, acc_table);
-                }
-                GGML_ASSERT(!src1->grad && "backward pass for labels not implemented");
-            } break;
-        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
-            {
-                GGML_ABORT("fatal error"); // not supported
+        case GGML_OP_UNARY: {
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_ABS: {
+                    if (src0_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad));
+                    }
+                } break;
+                case GGML_UNARY_OP_SGN: {
+                    // noop
+                } break;
+                case GGML_UNARY_OP_NEG: {
+                    if (src0_needs_grads) {
+                        ggml_sub_or_set(ctx, cgraph, isrc0, grad);
+                    }
+                } break;
+                case GGML_UNARY_OP_STEP: {
+                    // noop
+                } break;
+                case GGML_UNARY_OP_RELU: {
+                    if (src0_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad));
+                    }
+                } break;
+                case GGML_UNARY_OP_SILU: {
+                    if (src0_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
+                    }
+                } break;
+                case GGML_UNARY_OP_EXP: {
+                    if (src0_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
+                    }
+                } break;
+                default: {
+                    fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
+                        __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
+                    GGML_ABORT("fatal error");
+                } break;
             }
-        case GGML_OP_OPT_STEP_ADAMW:
-            {
-                GGML_ABORT("fatal error"); // not supported
+        } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
             }
-        case GGML_OP_NONE:
-            {
-                // nop
-            } break;
+            GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
+        } break;
+        case GGML_OP_NONE: {
+            // noop
+        } break;
         case GGML_OP_COUNT:
-            {
-                GGML_ABORT("fatal error");
-            }
+        default: {
+            fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
+            GGML_ABORT("fatal error");
+        } break;
     }
 
-    for (int i = 0; i < GGML_MAX_SRC; ++i) {
-        if (tensor->src[i] && tensor->src[i]->grad) {
-            GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
-        }
-    }
+    GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));
+    GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));
+    GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
-    if (node->grad == NULL) {
-        // this usually happens when we generate intermediate nodes from constants in the backward pass
-        // it can also happen during forward pass, if the user performs computations with constants
-        if (node->op != GGML_OP_NONE) {
-            //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
-        }
-    }
-
     // check if already visited
     if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
         return;
@@ -6207,18 +5586,42 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
     ggml_build_forward_impl(cgraph, tensor, true);
 }
 
-void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate) {
-    GGML_ASSERT(gf->n_nodes > 0);
-    GGML_ASSERT(gf->grads);
+void ggml_build_backward_expand(
+        struct ggml_context * ctx_static,
+        struct ggml_context * ctx_compute,
+        struct ggml_cgraph  * cgraph,
+        bool                  accumulate) {
+    GGML_ASSERT(cgraph->n_nodes > 0);
+    GGML_ASSERT(cgraph->grads);
+    GGML_ASSERT(cgraph->grad_accs);
+
+    const int n_nodes_f = cgraph->n_nodes;
 
-    for (int i = 0; i < gf->n_nodes; ++i) {
-        struct ggml_tensor * node = gf->nodes[i];
+    const size_t hash_size = ggml_hash_size(2*cgraph->size);
+    memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));
+    memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
+    bool * grads_needed = calloc(hash_size, sizeof(bool));
+
+    {
+        bool any_params = false;
+        bool any_loss   = false;
+        for (int i = 0; i < n_nodes_f; ++i) {
+            struct ggml_tensor * node = cgraph->nodes[i];
+            any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM);
+            any_loss   = any_loss   || (node->flags & GGML_TENSOR_FLAG_LOSS);
+        }
+        GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?");
+        GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?");
+    }
+
+    for (int i = 0; i < n_nodes_f; ++i) {
+        struct ggml_tensor * node = cgraph->nodes[i];
 
         if (node->type == GGML_TYPE_I32) {
             continue;
         }
 
-        bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
+        bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
         bool ignore_src[GGML_MAX_SRC] = {false};
         switch (node->op) {
             // gradients in node->src[0] for one reason or another have no effect on output gradients
@@ -6246,14 +5649,14 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
                 break;
         }
         for (int j = 0; j < GGML_MAX_SRC; ++j) {
-            if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) {
+            if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
                 continue;
             }
             GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
-            needs_grad = true;
+            node_needs_grad = true;
             break;
         }
-        if (!needs_grad) {
+        if (!node_needs_grad) {
             continue;
         }
 
@@ -6261,73 +5664,21 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
-        // create a new tensor with the same type and shape as the node and set it as grad
-        node->grad = ggml_dup_tensor(ctx, node);
-    }
-
-    // keep tables of original gradients for replacement/accumulation logic
-    struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
-    struct ggml_hash_set acc_table  = ggml_hash_set_new(gf->size);
-    for (int i = 0; i < gf->n_nodes; i++) {
-        struct ggml_tensor * node = gf->nodes[i];
-
-        if (node->grad) {
-            {
-                const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
-                GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
-                GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
-            }
-
-            // only gradients of trainable parameters should be accumulated
-            if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
-                const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
-                GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
-                GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
-            }
+        const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+        if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
+            cgraph->grads[igrad]     = ggml_dup_tensor(ctx_static, node);
+            cgraph->grad_accs[igrad] = cgraph->grads[igrad];
         }
+        grads_needed[igrad] = true;
     }
 
-    for (int i = gf->n_nodes - 1; i >= 0; i--) {
-        struct ggml_tensor * node = gf->nodes[i];
-
+    for (int i = n_nodes_f - 1; i >= 0; --i) {
         // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
-        if (node->grad) {
-            ggml_compute_backward(ctx, node, &zero_table, &acc_table);
-        }
+        ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
     }
 
-    for (int i = 0; i < gf->n_nodes; i++) {
-        struct ggml_tensor * node = gf->nodes[i];
-
-        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
-            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            ggml_build_forward_expand(gb, node->grad);
-        }
-    }
-
-    ggml_hash_set_free(&zero_table);
-    ggml_hash_set_free(&acc_table);
-}
-
-void ggml_build_opt_adamw(
-        struct ggml_context * ctx,
-        struct ggml_cgraph  * gf,
-        struct ggml_cgraph  * gb,
-        float                 alpha,
-        float                 beta1,
-        float                 beta2,
-        float                 eps,
-        float                 wd) {
-    for (int i = 0; i < gf->n_nodes; i++) {
-        struct ggml_tensor * node = gf->nodes[i];
-
-        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
-            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
-            ggml_build_forward_expand(gb, opt_step);
-        }
-    }
+    free(grads_needed);
 }
 
 static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
@@ -6345,7 +5696,8 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
     incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
     if (grads) {
-        incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
+        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
+        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
     }
     incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
 
@@ -6371,10 +5723,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
 
     void * p = cgraph + 1;
 
-    struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** nodes_ptr     =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** leafs_ptr     =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** hash_keys_ptr =         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** grads_ptr     = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+
     ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
 
     // check that we allocated the correct amount of memory
@@ -6386,12 +5740,17 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
         /*.n_leafs      =*/ 0,
         /*.nodes        =*/ nodes_ptr,
         /*.grads        =*/ grads_ptr,
+        /*.grad_accs    =*/ grad_accs_ptr,
         /*.leafs        =*/ leafs_ptr,
         /*.hash_table   =*/ { hash_size, hash_used, hash_keys_ptr },
         /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
     };
 
     ggml_hash_set_reset(&cgraph->visited_hash_set);
+    if (grads) {
+        memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));
+        memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
+    }
 
     return cgraph;
 }
@@ -6407,6 +5766,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
         /*.n_leafs      =*/ 0,
         /*.nodes        =*/ cgraph0->nodes + i0,
         /*.grads        =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
+        /*.grad_accs    =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL,
         /*.leafs        =*/ NULL,
         /*.hash_table   =*/ { 0, NULL, NULL },
         /*.order        =*/ cgraph0->order,
@@ -6432,19 +5792,23 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
         dst->nodes[i] = src->nodes[i];
     }
 
-    if (src->grads) {
-        GGML_ASSERT(dst->grads != NULL);
-        for (int i = 0; i < src->n_nodes; ++i) {
-            dst->grads[i] = src->grads[i];
-        }
-    }
-
     for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
         // copy all hashset keys (tensors) that are in use
         if (ggml_bitset_get(src->visited_hash_set.used, i)) {
             ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
         }
     }
+
+    if (src->grads) {
+        GGML_ASSERT(dst->grads     != NULL);
+        GGML_ASSERT(dst->grad_accs != NULL);
+        for (int i = 0; i < src->n_nodes; ++i) {
+            const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
+            const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+            dst->grads[igrad_dst]     = src->grads[igrad_src];
+            dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
+        }
+    }
 }
 
 struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
@@ -6470,29 +5834,36 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     GGML_ASSERT(cgraph->grads != NULL);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * node = cgraph->nodes[i];
+        struct ggml_tensor * node     = cgraph->nodes[i];
+        struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);
+
+        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
+            // clear momenta
+            if (node->src[2]->data) {
+                ggml_set_zero(node->src[2]);
+            }
+            if (node->src[3]->data) {
+                ggml_set_zero(node->src[3]);
+            }
+        }
 
         // initial gradients of loss should be 1, 0 otherwise
-        if (node->grad) {
+        if (grad_acc) {
             if (node->flags & GGML_TENSOR_FLAG_LOSS) {
-                GGML_ASSERT(node->grad->buffer);
-                GGML_ASSERT(node->type == GGML_TYPE_F32);
-                GGML_ASSERT(ggml_is_scalar(node));
+                GGML_ASSERT(grad_acc->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_is_scalar(grad_acc));
 
                 const float onef = 1.0f;
-                ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
+                if (grad_acc->buffer) {
+                    ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));
+                } else {
+                    GGML_ASSERT(grad_acc->data);
+                    *((float *) grad_acc->data) = onef;
+                }
             } else {
-                ggml_set_zero(node->grad);
+                ggml_set_zero(grad_acc);
             }
         }
-
-        GGML_ASSERT(node);
-        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
-            // set iteration to 1 and clear momenta
-            ggml_set_op_params_i32(node, 0, 1);
-            ggml_set_zero(node->src[2]);
-            ggml_set_zero(node->src[3]);
-        }
     }
 }
 
@@ -6530,7 +5901,7 @@ void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tenso
     cgraph->n_nodes++;
 }
 
-struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
+struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) {
     for (int i = 0; i < cgraph->n_leafs; i++) {
         struct ggml_tensor * leaf = cgraph->leafs[i];
 
@@ -6550,6 +5921,16 @@ struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const ch
     return NULL;
 }
 
+struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL;
+}
+
+struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL;
+}
+
 void ggml_graph_print(const struct ggml_cgraph * cgraph) {
     GGML_LOG_INFO("=== GRAPH ===\n");
 
@@ -6560,7 +5941,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
-                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ");
+                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" :
+                      ggml_graph_get_grad(cgraph, node) ? "g" : " ");
     }
 
     GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
@@ -6595,8 +5977,9 @@ static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml
 static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * parent = cgraph->nodes[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent);
 
-        if (parent->grad == node) {
+        if (grad == node) {
             return parent;
         }
     }
@@ -6636,6 +6019,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
 
     for (int i = 0; i < gb->n_nodes; i++) {
         struct ggml_tensor * node = gb->nodes[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(gb, node);
 
         if (ggml_graph_get_parent(gb, node) != NULL) {
             continue;
@@ -6643,7 +6027,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
 
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             snprintf(color, sizeof(color), "yellow");
-        } else if (node->grad) {
+        } else if (grad) {
             if (ggml_graph_find(gf, node)) {
                 snprintf(color, sizeof(color), "green");
             } else {
@@ -6670,8 +6054,8 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
             fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
         }
 
-        if (node->grad) {
-            fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(node->grad->op));
+        if (grad) {
+            fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(grad->op));
         } else {
             fprintf(fp, "\"; ]\n");
         }
index f0a3386a050f54f561d42dcc02962652961cfdbf..e1ee785fa46ffce6019c9db9a2eb3fa931db0c85 100644 (file)
@@ -176,23 +176,14 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm")
 endif()
 
 #
-# test-grad0
+# test-opt
 
-set(TEST_TARGET test-grad0)
+set(TEST_TARGET test-opt)
 add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
 target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
 
-#
-# test-opt
-
-# set(TEST_TARGET test-opt)
-# add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
-# target_link_libraries(${TEST_TARGET} PRIVATE ggml)
-# add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
-# set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
-
 #
 # test-quantize-fns
 
@@ -268,36 +259,6 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml)
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
 
-#
-# test1
-
-set(TEST_TARGET test1)
-add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
-target_link_libraries(${TEST_TARGET} PRIVATE ggml)
-if (MSVC)
-    target_link_options(${TEST_TARGET} PRIVATE "/STACK: 8388608") # 8MB
-endif()
-add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
-set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
-
-#
-# test2
-
-# set(TEST_TARGET test2)
-# add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
-# target_link_libraries(${TEST_TARGET} PRIVATE ggml)
-# add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
-# set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
-
-#
-# test3
-
-# set(TEST_TARGET test3)
-# add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
-# target_link_libraries(${TEST_TARGET} PRIVATE ggml)
-# add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
-# set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
-
 #
 # test-pool
 
index 6618d03d150a00316b6aad419c48cc4df1d2e587..f8a59b6df175fc2073fe92b7f90f28eb7b7fd3dd 100644 (file)
@@ -811,11 +811,11 @@ struct test_case {
 
         ggml_build_forward_expand(gf, out);
         ggml_graph_cpy(gf, gb);
-        ggml_build_backward_expand(ctx, gf, gb, false);
+        ggml_build_backward_expand(ctx, ctx, gb, false);
         if (expect.size() != 1 || expect[0] != 0.0f) {
             GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
             for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || t->grad->op != GGML_OP_NONE);
+                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
             }
         }
 
@@ -862,7 +862,13 @@ struct test_case {
             const char * bn = ggml_backend_name(backend);
             const int64_t ne = ggml_nelements(t);
 
-            std::vector<float> ga = tensor_to_float(t->grad);
+            std::vector<float> ga;
+            struct ggml_tensor * grad = ggml_graph_get_grad(gb, t);
+            if (grad) {
+                ga = tensor_to_float(grad);
+            } else {
+                ga.resize(ne); // default value is 0.0f
+            }
 
             for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
                 // check for nans
@@ -2500,6 +2506,35 @@ struct test_sum_rows : public test_case {
     }
 };
 
+// GGML_OP_MEAN
+struct test_mean : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_mean(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_mean(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
+    }
+};
+
 // GGML_OP_UPSCALE
 struct test_upscale : public test_case {
     const ggml_type type;
@@ -2834,24 +2869,14 @@ struct test_cross_entropy_loss : public test_case {
 struct test_opt_step_adamw : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    const float alpha;
-    const float beta1;
-    const float beta2;
-    const float eps;
-    const float wd;
 
     std::string vars() override {
-        return VARS_TO_STR7(type, ne, alpha, beta1, beta2, eps, wd);
+        return VARS_TO_STR2(type, ne);
     }
 
     test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 5, 4, 3},
-            float alpha = 1e-3f,
-            float beta1 = 0.9f,
-            float beta2 = 0.999f,
-            float eps = 1e-8f,
-            float wd = 0.0f)
-        : type(type), ne(ne), alpha(alpha), beta1(beta1), beta2(beta2), eps(eps), wd(wd) {}
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
@@ -2861,7 +2886,16 @@ struct test_opt_step_adamw : public test_case {
         ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
         ggml_set_name(grad, "grad");
 
-        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd);
+        ggml_tensor * grad_m = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad_m, "grad_m");
+
+        ggml_tensor * grad_v = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad_v, "grad_v");
+
+        ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7);
+        ggml_set_name(adamw_params, "adamw_params");
+
+        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, grad_m, grad_v, adamw_params);
         ggml_set_name(out, "out");
 
         return out;
@@ -2869,7 +2903,7 @@ struct test_opt_step_adamw : public test_case {
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            init_tensor_uniform(t, 0.0f, 1.0f); // grad_v needs non-negative values.
+            init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values.
         }
     }
 
@@ -3735,6 +3769,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_mean());
     test_cases.emplace_back(new test_upscale());
     test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
     test_cases.emplace_back(new test_upscale_ext());
@@ -3766,9 +3801,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 
     test_cases.emplace_back(new test_cross_entropy_loss());
-    for (float wd : {0.0f, 1e-2f}) {
-        test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}, 1.0f, 1e-3f, 0.9f, 0.999f, wd));
-    }
+    test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
 
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
@@ -3938,6 +3971,8 @@ int main(int argc, char ** argv) {
         ggml_backend_free(backend);
     }
 
+    ggml_quantize_free();
+
     printf("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());
 
     if (n_ok != ggml_backend_dev_count()) {
@@ -3945,8 +3980,6 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    ggml_quantize_free();
-
     printf("\033[1;32mOK\033[0m\n");
     return 0;
 }
diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp
deleted file mode 100644 (file)
index c712dba..0000000
+++ /dev/null
@@ -1,1684 +0,0 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
-#include "ggml.h"
-#include "ggml-cpu.h"
-
-#include <cfloat>
-#include <cmath>
-#include <cstdint>
-#include <cstdio>
-#include <cstdlib>
-#include <cassert>
-#include <initializer_list>
-#include <vector>
-
-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wdouble-promotion"
-#endif
-
-#define MAX_NARGS 3
-
-#undef MIN
-#undef MAX
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
-#define GGML_SILU_FP16
-
-//
-// logging
-//
-
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#if (GGML_DEBUG >= 5)
-#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_5(...)
-#endif
-
-#if (GGML_DEBUG >= 10)
-#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_10(...)
-#endif
-
-#define GGML_PRINT(...) printf(__VA_ARGS__)
-
-static float frand(void) {
-    return (float)rand()/(float)RAND_MAX;
-}
-
-static int irand(int n) {
-    if (n == 0) return 0;
-    return rand()%n;
-}
-
-static void get_random_dims(int64_t * dims, int ndims) {
-    dims[0] = dims[1] = dims[2] = dims[3] = 1;
-
-    for (int i = 0; i < ndims; i++) {
-        dims[i] = 1 + irand(4);
-    }
-}
-
-static struct ggml_tensor * get_random_tensor_f32(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        float fmin,
-        float fmax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    }
-
-    return result;
-}
-
-static struct ggml_tensor * get_random_tensor_f16(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        float fmin,
-        float fmax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    }
-
-    return result;
-}
-
-static struct ggml_tensor * get_random_tensor_i32(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        int32_t imin,
-        int32_t imax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((int32_t *)result->data)[i0] = irand(imax - imin) + imin;
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((int32_t *)result->data)[i1*ne[0] + i0] = irand(imax - imin) + imin;
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((int32_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((int32_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    }
-
-    return result;
-}
-
-static bool check_gradient(
-        const char * op_name,
-        struct ggml_context * ctx0,
-        struct ggml_tensor * x[],
-        struct ggml_tensor * f,
-        int ndims,
-        int nargs,
-        float eps,
-        float max_error_abs,
-        float max_error_rel,
-        std::vector<double> expected_vals) {
-
-    static int n_threads = -1;
-    if (n_threads < 0) {
-        n_threads = GGML_DEFAULT_N_THREADS;
-
-        const char *env = getenv("GGML_N_THREADS");
-        if (env) {
-            n_threads = atoi(env);
-        }
-
-        printf("GGML_N_THREADS = %d\n", n_threads);
-    }
-
-    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-    struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-    ggml_build_forward_expand(gf, f);
-    ggml_graph_cpy(gf, gb);
-    ggml_build_backward_expand(ctx0, gf, gb, false);
-
-    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
-
-    ggml_graph_reset(gb);
-    if (f->grad) {
-        ggml_set_f32(f->grad, 1.0f);
-    }
-
-    ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-    // ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
-    // ggml_graph_dump_dot(gb, gf,  "test-grad0-backward.dot");
-
-    for (int i = 0; i < nargs; ++i) {
-        bool all_g0_bad = true;
-        const int nelements = ggml_nelements(x[i]);
-        for (int k = 0; k < nelements; ++k) {
-            // Calculate gradient numerically:
-            const float x0 = ggml_get_f32_1d(x[i], k);
-            const float xm = x0 - eps;
-            const float xp = x0 + eps;
-            ggml_set_f32_1d(x[i], k, xp);
-
-            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
-
-            const double f0 = ggml_get_f32_1d(f, 0);
-
-            ggml_set_f32_1d(x[i], k, xm);
-
-            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
-
-            const double f1 = ggml_get_f32_1d(f, 0);
-            const double g0 = (f0 - f1)/(2.0*(double) eps);
-
-            // The numerical calculation of the gradient fails around noncontinuities (e.g. 0 for ReLU).
-            // In such cases, provide a vector of expected values and skip the comparison for failed calculations.
-            if (!expected_vals.empty()) {
-                bool matches_any = false;
-                for (const double & ev : expected_vals) {
-                    const double error_abs = std::fabs(g0 - ev);
-                    if (error_abs > max_error_abs) {
-                        continue;
-                    }
-                    const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0;
-                    if (error_rel > max_error_rel) {
-                        continue;
-                    }
-                    matches_any = true;
-                    break;
-                }
-                if (!matches_any) {
-                    continue;
-                }
-            }
-            all_g0_bad = false;
-
-            ggml_set_f32_1d(x[i], k, x0);
-
-            // compute gradient using backward graph
-            ggml_graph_reset(gb);
-            if (f->grad) {
-                ggml_set_f32(f->grad, 1.0f);
-            }
-
-            ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-            const double g1 = ggml_get_f32_1d(x[i]->grad, k);
-
-            const double error_abs = fabs(g0 - g1);
-            const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0;
-
-            if (error_abs > max_error_abs || error_rel > max_error_rel) {
-                printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
-                            op_name, ndims, i, k, x0, xm, xp, f0, f1, g0, g1, eps, error_abs, error_rel);
-                //assert(false);
-                return false;
-            }
-        }
-        if (all_g0_bad) {
-            printf("%s: numerical calculation of the gradient failed for all values\n", op_name);
-            return false;
-        }
-    }
-
-    return true;
-}
-
-// TODO: clean-up this ..
-static bool check_mat_mul(
-        const struct ggml_tensor * y,
-        const struct ggml_tensor * x0,
-        const struct ggml_tensor * x1) {
-    float * dst  = (float *) y->data;
-    float * src0 = (float *) x0->data;
-    float * src1 = (float *) x1->data;
-
-    const int nc = x0->ne[1];
-    const int nr = x1->ne[1];
-    const int nk = x0->ne[0];
-
-    GGML_PRINT_DEBUG("check_mat_mul: nc=%d, nr=%d, nk=%d\n", nc, nr, nk);
-
-    GGML_PRINT_DEBUG("x0:\n");
-    for (int j = 0; j < x0->ne[1]; ++j) {
-        for (int i = 0; i < x0->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", src0[j*nk + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-    GGML_PRINT_DEBUG("\n");
-
-    GGML_PRINT_DEBUG("x1:\n");
-    for (int j = 0; j < x1->ne[1]; ++j) {
-        for (int i = 0; i < x1->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", src1[j*nk + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-    GGML_PRINT_DEBUG("\n");
-
-    GGML_PRINT_DEBUG("y: n_dims = %d, (%lld, %lld)\n", y->n_dims, y->ne[0], y->ne[1]);
-    for (int j = 0; j < y->ne[1]; ++j) {
-        for (int i = 0; i < y->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", dst[j*nr + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-
-    for (int i = 0; i < nr; ++i) {
-        for (int j = 0; j < nc; ++j) {
-            float sum = 0.0f;
-
-            for (int k = 0; k < nk; ++k) {
-                sum += src0[j*nk + k]*src1[i*nk + k];
-            }
-
-            if (fabsf(dst[i*nc + j] - sum) > 1e-5f) {
-                fprintf(stderr, "check_mat_mul: dst[%d] = %f, sum = %f\n", i*nc + j, dst[i*nc + j], sum);
-                assert(false);
-                return false;
-            }
-        }
-    }
-
-    return true;
-}
-
-#define NUM_PERMUTATIONS (4*3*2*1)
-
-int main(int argc, const char ** argv) {
-    struct ggml_init_params params = {
-        /* .mem_size   = */ 256*1024*1024,
-        /* .mem_buffer = */ NULL,
-        /* .no_alloc   = */ false,
-    };
-
-    int64_t ne[4];
-
-    int all_permutations[4 * NUM_PERMUTATIONS];
-    {
-        int count = 0;
-        for (int ax0=0; ax0<4; ++ax0) {
-            for (int ax1=0; ax1<4; ++ax1) {
-                if (ax1 == ax0) continue;
-                for (int ax2=0; ax2<4; ++ax2) {
-                    if (ax2 == ax0) continue;
-                    if (ax2 == ax1) continue;
-                    for (int ax3=0; ax3<4; ++ax3) {
-                        if (ax3 == ax0) continue;
-                        if (ax3 == ax1) continue;
-                        if (ax3 == ax2) continue;
-                        assert(count < NUM_PERMUTATIONS);
-                        all_permutations[count*4+0] = ax0;
-                        all_permutations[count*4+1] = ax1;
-                        all_permutations[count*4+2] = ax2;
-                        all_permutations[count*4+3] = ax3;
-                        ++count;
-                    }
-                }
-            }
-        }
-    }
-
-    unsigned seed_iter = 1;
-
-    // original loop: 1000
-    int niter = 4;
-    const char *env = getenv("GGML_NLOOP");
-    if (env != NULL) {
-        niter = atoi(env);
-    }
-    if (argc > 1) {
-        niter = atoi(argv[1]);
-    }
-    for (int iter = 0; iter < niter; ++iter) {
-        srand(seed_iter);
-        seed_iter = rand();
-        unsigned seed = rand();
-
-        printf("test-grad0: iter:%d/%d\n", (iter+1), niter);
-        struct ggml_context * ctx0 = ggml_init(params);
-
-        get_random_dims(ne, 4);
-
-        struct ggml_tensor * x[MAX_NARGS];
-
-        // add f32
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
-
-                check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {});
-            }
-        }
-
-        // add f16
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
-
-                check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {});
-            }
-        }
-
-        // sub
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
-
-                check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // mul
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
-
-                check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // div
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
-
-                check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {});
-            }
-        }
-
-        // sqr
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
-
-                check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // sqrt
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
-
-                check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {});
-            }
-        }
-
-        // log
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0]));
-
-                check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {});
-            }
-        }
-
-        // sum
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
-
-                check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-
-        // sum_rows
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0])));
-
-                check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
-            }
-        }
-
-        // mean, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
-
-                check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // argmax
-        if (0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
-
-                check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // repeat
-        {
-            srand(seed);
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            ne2[0] = ne[0] * ne2[0];
-            ne2[1] = ne[1] * ne2[1];
-            ne2[2] = 1;
-            ne2[3] = 1;
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
-
-                check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
-            }
-        }
-
-        // repeat back
-        {
-            srand(seed);
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            ne2[0] = ne[0] * ne2[0];
-            ne2[1] = ne[1] * ne2[1];
-            ne2[2] = 1;
-            ne2[3] = 1;
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
-
-                check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
-            }
-        }
-
-        // abs
-        {
-           const int nargs = 1;
-
-           for (int ndims = 1; ndims <= 4; ++ndims) {
-               for (int i = 0; i < nargs; ++i) {
-                   x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                   ggml_set_param(ctx0, x[i]);
-               }
-
-               struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
-
-               check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0});
-           }
-        }
-
-        // sgn
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
-
-                check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
-            }
-        }
-
-        // neg
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
-
-                check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // step
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
-
-                check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
-            }
-        }
-
-        // tanh, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
-
-                check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // mul_mat
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 2; ndims <= 4; ++ndims) {
-                int max_nrep = (ndims >= 3) ? 2 : 1;
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) {
-                    for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) {
-                        {
-                            int64_t ne2[4];
-                            get_random_dims(ne2, 4);
-                            ne2[0] = ne[0];
-                            ne2[2] = nrep2 * ne[2];
-                            ne2[3] = nrep3 * ne[3];
-                            x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                        }
-
-                        ggml_set_param(ctx0, x[0]);
-                        ggml_set_param(ctx0, x[1]);
-
-                        struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
-                        struct ggml_tensor * f = ggml_sum(ctx0, m);
-
-                        GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
-
-                        check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-                        if (ndims == 2) {
-                            // check_mat_mul does not support ndims > 2
-                            check_mat_mul(m, x[1], x[0]);
-                        }
-                    }
-                }
-            }
-        }
-
-        // elu, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
-
-                check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // relu
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
-
-                check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0});
-            }
-        }
-
-        // gelu, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
-
-                check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // silu
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_silu(ctx0, x[0]));
-
-#ifdef GGML_SILU_FP16
-                // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
-                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {});
-#else
-                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-#endif
-            }
-        }
-
-        // rms_norm
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
-
-                check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {});
-            }
-        }
-
-        // scale
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                const float s = -1.0f + 2.0f*frand();
-
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s));
-
-                check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // cpy f32
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
-
-                check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // cpy f16
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
-
-                check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
-            }
-        }
-
-        // reshape (1d->nd)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                int64_t ne2[4];
-                ne2[0] = 1;
-                ne2[1] = 1;
-                ne2[2] = 1;
-                ne2[3] = 1;
-                for (int i = 0; i < ndims; ++i) {
-                    ne2[0] *= ne[i];
-                }
-                x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // reshape (nd->1d)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                int64_t ne2[4];
-                ne2[0] = 1;
-                ne2[1] = 1;
-                ne2[2] = 1;
-                ne2[3] = 1;
-                for (int i = 0; i < ndims; ++i) {
-                    ne2[0] *= ne[i];
-                }
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // acc 1d
-        {
-            srand(seed);
-            int64_t ne2[4] = { 1, 1, 1, 1 };
-
-            const int nargs = 2;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 1);
-                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 1);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
-                const int offset = irand(max_offset) * ggml_element_size(x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // acc 2d
-        {
-            srand(seed);
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 2; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 2);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 2);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                const int offset = offsets[0] + offsets[1];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // acc 3d
-        {
-            srand(seed);
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 3);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 3);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
-                const int offset = offsets[0] + offsets[1] + offsets[2];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // acc 4d
-        {
-            srand(seed);
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 4; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 4);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[3] > ne[3]) || (ne2[0]*ne2[1]*ne2[2]*ne2[3] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 4);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
-                max_offsets[3] = MAX(0, x[0]->ne[3] - x[1]->ne[3]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
-                offsets[3] = irand(max_offsets[3]) * x[0]->nb[3];
-                const int offset = offsets[0] + offsets[1] + offsets[2] + offsets[3];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // set_1d
-        {
-            srand(seed);
-            int64_t ne2[4];
-
-            const int nargs = 2;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 1);
-                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 1);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
-                const int offset = irand(max_offset) * ggml_element_size(x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset));
-
-                check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // set_2d
-        {
-            srand(seed);
-            int64_t ne2[4];
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 1;
-            for (int ndims = 2; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 2);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 2);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                const int offset = offsets[0] + offsets[1];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset));
-
-                check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // view_1d
-        {
-            srand(seed);
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int k0 = irand(ggml_nelements(x[0]));
-                const int k1 = irand(ggml_nelements(x[0]));
-                const int i0 = MIN(k0, k1);
-                const int i1 = MAX(k0, k1);
-
-                const int offset = i0 * sizeof(float);
-                const int nelem  = i1 - i0;
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset));
-
-                check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // view_2d
-        {
-            srand(seed);
-            int64_t ne2[4];
-            int64_t nb2[4];
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                get_random_dims(ne2, 2);
-                while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
-                    get_random_dims(ne2, 2);
-                }
-                const int count = ne2[0]*ne2[1];
-
-                nb2[0] = sizeof(float);
-                nb2[1] = nb2[0]*ne2[0];
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int max_offset = ggml_nelements(x[0]) - count;
-                const int offset = irand(max_offset+1) * sizeof(float);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset));
-
-                check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // view_3d
-        {
-            srand(seed);
-            int64_t ne2[4] = {1,1,1,1};
-            int64_t nb2[4] = {0,0,0,0};
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                get_random_dims(ne2, 3);
-                while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
-                    get_random_dims(ne2, 3);
-                }
-                const int count = ne2[0]*ne2[1]*ne2[2];
-
-                nb2[0] = sizeof(float);
-                nb2[1] = nb2[0]*ne2[0];
-                nb2[2] = nb2[1]*ne2[1];
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int max_offset = ggml_nelements(x[0]) - count;
-                const int offset = irand(max_offset+1) * sizeof(float);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset));
-
-                check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // permute
-        {
-            srand(seed);
-            int64_t ne2[4];
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims)
-            {
-                // ggml_permute will set axes of dimensions below n_dims to 1.
-                // to make ggml_permute work correctly on all axes,
-                // the input tensor needs maximal n_dim of 4.
-                for (int i=0; i<ndims; ++i) {
-                    ne2[i] = ne[i];
-                }
-                for (int i=ndims; i<4; ++i) {
-                    ne2[i] = 1;
-                }
-                x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int p = irand(NUM_PERMUTATIONS);
-                const int ax0 = all_permutations[p*4+0];
-                const int ax1 = all_permutations[p*4+1];
-                const int ax2 = all_permutations[p*4+2];
-                const int ax3 = all_permutations[p*4+3];
-
-                // sum requires contiguous tensor rows
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3)));
-
-                check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // transpose
-        {
-            srand(seed);
-            int64_t ne2[4];
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims)
-            {
-                // ggml_transpose will set axes of dimensions below n_dims to 1.
-                // to make ggml_transpose work correctly on all axes,
-                // the input tensor needs maximal n_dim of 4.
-                for (int i=0; i<ndims; ++i) {
-                    ne2[i] = ne[i];
-                }
-                for (int i=ndims; i<4; ++i) {
-                    ne2[i] = 1;
-                }
-                x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                // sum requires contiguous tensor rows
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0])));
-
-                check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // get_rows
-        {
-            srand(seed);
-            int64_t ne2[4] = {ne[0], ne[1], 1, 1};
-            int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
-            const int nargs = 1;
-            const int ndims = 2;
-            x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-            x[1] = get_random_tensor_i32(ctx0, 1, ne3, 0, ne2[1]);
-
-            ggml_set_param(ctx0, x[0]);
-
-            struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
-
-            check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-        }
-
-        // diag_mask_inf
-        {
-            srand(seed);
-            const int nargs = 1;
-            const int ndims = 2;
-
-            x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-            ggml_set_param(ctx0, x[0]);
-
-            int n_past = irand(ne[0]);
-
-            struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past));
-
-            check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-        }
-
-        // diag_mask_zero
-        {
-            srand(seed);
-            const int nargs = 1;
-            const int ndims = 2;
-
-            x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-            ggml_set_param(ctx0, x[0]);
-
-            int n_past = irand(ne[0]);
-
-            struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past));
-
-            check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-        }
-
-        // softmax
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            for (int ndims = 1; ndims <= 3; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                float eps = 1e-6f;
-                // dont use only sum as aggregation, because sum of softmax is always 1 -> finite differences should not work
-                // instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0)
-                struct ggml_tensor * f = ggml_sum(ctx0,
-                                            ggml_log(ctx0,
-                                                ggml_add1(ctx0,
-                                                    ggml_scale(ctx0,
-                                                        ggml_soft_max(ctx0, x[0]),
-                                                        1.0f - eps),
-                                                    ggml_new_f32(ctx0, eps))));
-
-                check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {});
-                // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
-                // this may result in different gradients too finite differences.
-                // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
-                // if only the table lookup causes gradients to differ this is acceptable.
-            }
-        }
-
-        // cross_entropy_loss
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
-                // the second argument to cross_entropy_loss must sum up to 1 for each row
-                int nr = ggml_nrows(x[1]);
-                int nc = ggml_nelements(x[1]) / nr;
-                for (int ir = 0; ir < nr; ++ir) {
-                    float sum = 0;
-                    for (int ic = 0; ic < nc; ++ic) {
-                        sum += ((float *) x[1]->data)[ic + ir*nc];
-                    }
-                    for (int ic = 0; ic < nc; ++ic) {
-                        ((float *) x[1]->data)[ic + ir*nc] /= sum;
-                    }
-                }
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]);
-
-                check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // rope f32
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-            ne2[0] += ne2[0] % 2;
-            int n_rot = ne2[0];
-
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-                for (int mode = 0; mode < 4; ++mode) {
-                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
-                        x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-
-                        struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
-                        for (int i = 0; i < ne2[2]; ++i) {
-                            ((int32_t *) p->data)[i] = n_past + i;
-                        }
-
-                        ggml_set_param(ctx0, x[0]);
-
-                        const bool skip_past = (mode & 1);
-                        if (skip_past) {
-                            // we have no past, so this would have to work on uninitialized memory.
-                            // we only test the gradients here;
-                            // skip_past should have no influence on gradient computation.
-                            // so when other modes work, we assume that this does as well.
-                            continue;
-                        }
-
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
-
-                        GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
-                    }
-                }
-            }
-        }
-
-        // rope f16
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-            ne2[0] += ne2[0] % 2;
-            int n_rot = ne2[0];
-
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-                for (int mode = 0; mode < 4; ++mode) {
-                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
-                        x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
-
-                        struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
-                        for (int i = 0; i < ne2[2]; ++i) {
-                            ((int32_t *) p->data)[i] = n_past + i;
-                        }
-
-                        ggml_set_param(ctx0, x[0]);
-
-                        const bool skip_past = (mode & 1);
-                        if (skip_past) {
-                            // we have no past, so this would have to work on uninitialized memory.
-                            // we only test the gradients here;
-                            // skip_past should have no influence on gradient computation.
-                            // so when other modes work, we assume that this does as well.
-                            continue;
-                        }
-
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
-
-                        GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
-                    }
-                }
-            }
-        }
-
-        // im2col f32
-        {
-            srand(seed);
-            const int nargs = 1;
-            const int ndims = 4;
-
-            for (const bool is_2D : {false, true}) {
-                int64_t ne0[ndims];
-                int64_t ne1[ndims];
-                get_random_dims(ne0, ndims);
-                get_random_dims(ne1, ndims);
-
-                // // Ensure that the output is not zero-sized:
-                ne1[0] += 8;
-                ne1[1] += 8;
-
-                if (is_2D) {
-                    ne1[2] = ne0[2];
-                } else {
-                    ne1[1] = ne0[1];
-                    ne0[3] = 1;
-                    ne1[3] = 1;
-                }
-
-                // The order of arguments is swapped because the first tensor is only used for its shape.
-                x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f);
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int s0 =         1 + irand(2);
-                const int s1 = is_2D ? 1 + irand(2) : 0;
-                const int p0 =         0 + irand(2);
-                const int p1 = is_2D ? 0 + irand(2) : 0;
-                const int d0 =         1 + irand(2);
-                const int d1 = is_2D ? 1 + irand(2) : 0;
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32));
-
-                GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1);
-                check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // pool_2d f32
-        {
-            srand(seed);
-            const int nargs = 1;
-            const int ndims = 4;
-
-            for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
-                int64_t ne0[ndims];
-                get_random_dims(ne0, ndims);
-
-                ne0[0] += 8;
-                ne0[1] += 8;
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int k0 = 2 + irand(2);
-                const int k1 = 2 + irand(2);
-                const int s0 = 2 + irand(2);
-                const int s1 = 2 + irand(2);
-                const int p0 = 0 + irand(2);
-                const int p1 = 0 + irand(2);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1));
-
-                GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n",
-                                 op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1);
-                std::vector<double> expected_vals;
-                if (op == GGML_OP_POOL_MAX) {
-                    expected_vals.push_back(0.0);
-                    expected_vals.push_back(1.0);
-                }
-                check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals);
-            }
-        }
-
-        // flash_attn f32
-        // TODO: adapt to ggml_flash_attn_ext() changes
-        //{
-        //    srand(seed);
-        //    const int nargs = 3;
-
-        //    int64_t ne2[4];
-
-        //    get_random_dims(ne2, 4);
-        //    int64_t D = ne2[0];
-        //    int64_t N = ne2[1];
-        //    int64_t M = ne2[2] + N;
-        //    int64_t B = ne2[3];
-
-        //    for (int masked = 0; masked <= 1; ++masked) {
-        //        for (int ndims = 2; ndims <= 4; ++ndims) {
-        //            int max_nrep = (ndims >= 3) ? 2 : 1;
-        //            for (int nrep = 1; nrep < max_nrep; ++nrep) {
-        //                int64_t neq[4] = { D, N, B*nrep, ne[3] };
-        //                int64_t nek[4] = { D, M, B, ne[3] };
-        //                int64_t nev[4] = { M, D, B, ne[3] };
-        //                if (ndims == 2) {
-        //                    neq[2] = 1; neq[3] = 1;
-        //                    nek[2] = 1; nek[3] = 1;
-        //                    nev[2] = 1; nev[3] = 1;
-        //                } else if (ndims == 3) {
-        //                    neq[3] = 1;
-        //                    nek[3] = 1;
-        //                    nev[3] = 1;
-        //                }
-        //                x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
-        //                x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
-        //                x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
-        //                ggml_set_param(ctx0, x[0]);
-        //                ggml_set_param(ctx0, x[1]);
-        //                ggml_set_param(ctx0, x[2]);
-
-        //                struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
-
-        //                check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {});
-        //            }
-        //        }
-        //    }
-        //}
-
-        ggml_free(ctx0);
-    }
-
-    return 0;
-}
index 5a6453c606ddcc50bdc373a98709c93b1fac6244..6e561efe3073adabaf29844d19e83c8f5bfcbb8d 100644 (file)
@@ -97,15 +97,15 @@ bool check_gradient(
         float max_error_abs,
         float max_error_rel) {
     const int n_threads = 1;
+    ggml_set_loss(f);
 
     struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
     ggml_build_forward_expand(gf, f);
     struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-    ggml_build_backward_expand(ctx0, gf, gb, false);
+    ggml_build_backward_expand(ctx0, ctx0, gb, false);
 
     ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
-    ggml_graph_reset  (gf);
-    ggml_set_f32      (f->grad, 1.0f);
+    ggml_graph_reset(gb);
     ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
     ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
@@ -132,11 +132,10 @@ bool check_gradient(
             set_element(x[i], k, x0);
 
             // compute gradient using backward graph
-            ggml_graph_reset  (gf);
-            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_reset(gb);
             ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
 
-            const float g1 = get_element(x[i]->grad, k);
+            const float g1 = get_element(ggml_graph_get_grad(gb, x[i]), k);
 
             const float error_abs = fabsf(g0 - g1);
             const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
index 546ca230ba4170c4d8a9d734f445b96db53b664d..4abe85c74a4f6285456cb4269efa7fb47374a31b 100644 (file)
 #include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "ggml-cpu.h"
+#include "ggml-opt.h"
 
 #include <cmath>
-#include <cstdio>
-#include <cstdlib>
-#include <cassert>
-
-#define MAX_NARGS 2
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wdouble-promotion"
-#endif
-
-//
-// logging
-//
-#define GGML_DEBUG 0
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#if (GGML_DEBUG >= 5)
-#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_5(...)
-#endif
-
-#if (GGML_DEBUG >= 10)
-#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_10(...)
-#endif
-
-#define GGML_PRINT(...) printf(__VA_ARGS__)
-
-
-static float frand(void) {
-    return (float)rand()/(float)RAND_MAX;
-}
-
-static struct ggml_tensor * get_random_tensor(
-    struct ggml_context * ctx0, int ndims, int64_t ne[], float fmin, float fmax
-) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+#include <inttypes.h>
+#include <random>
+#include <string>
+#include <thread>
+#include <vector>
+
+static bool almost_equal(const double a, const double b, const double atol) {
+    return fabs(a - b) < atol;
+}
+
+constexpr int64_t ne_datapoint = 2;
+constexpr int64_t ne_label     = 1;
+constexpr int64_t ndata        = 6;
+
+struct helper_ctx_data {
+    std::vector<ggml_opt_dataset_t>   datasets_supervised;
+    std::vector<struct ggml_tensor *> data_batch;
+    std::vector<struct ggml_tensor *> labels_batch;
+
+    ggml_opt_dataset_t       dataset_unsupervised;
+    struct ggml_context    * ctx_static;
+    struct ggml_context    * ctx_compute;
+    struct ggml_opt_params   opt_params;
+    ggml_opt_context_t       opt_ctx;
+    struct ggml_tensor     * inputs;
+    struct ggml_tensor     * weights;
+    struct ggml_tensor     * outputs;
+    ggml_backend_buffer_t    buf;
+    ggml_opt_result_t        result;
+    ggml_opt_result_t        result2;
+};
+
+// These default values make it easier to check optimization results vs. expected values.
+static ggml_opt_optimizer_params helper_get_test_opt_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
+    result.adamw.alpha = 1.0f;
+    result.adamw.beta1 = 0.0f;
+    result.adamw.beta2 = 0.0f;
+    result.adamw.eps   = 0.0f;
+    return result;
+}
+
+static helper_ctx_data helper_get_ctx_data(
+        ggml_backend_sched_t    backend_sched,
+        ggml_backend_t          backend,
+        const bool              init_opt_ctx       = true,
+        const bool              optimizer_defaults = true,
+        int64_t                 nbatch_logical     = 1,
+        int64_t                 nbatch_physical    = 1,
+        enum ggml_opt_loss_type loss_type          = GGML_OPT_LOSS_TYPE_SUM) {
+    std::vector<ggml_opt_dataset_t> datasets(ndata);
+    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
+        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard);
+
+        float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
+        float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
+
+        for (int64_t idata = 0; idata < ndata; ++idata) {
+            for (int64_t id = 0; id < ne_datapoint; ++id) {
+                data[  idata*ne_datapoint + id] =     16*idata + id;
             }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                }
+            for (int64_t il = 0; il < ne_label;     ++il) {
+                labels[idata*ne_label     + il] = 16*(16*idata + il);
             }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+        }
+
+        datasets[ndata_shard-1] = dataset;
+    }
+
+    ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1);
+
+    float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised));
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        data[idata] = idata;
+    }
+
+    struct ggml_context * ctx_static;
+    struct ggml_context * ctx_compute;
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ (2*ndata + 2)*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_static = ggml_init(params);
+    }
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_compute = ggml_init(params);
+    }
+
+    std::vector<struct ggml_tensor *>   data_batch(ndata);
+    std::vector<struct ggml_tensor *> labels_batch(ndata);
+    for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) {
+        data_batch[ndata_batch-1]   = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_datapoint);
+        labels_batch[ndata_batch-1] = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_label);
+    }
+
+    struct ggml_tensor * inputs = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, nbatch_physical);
+    ggml_set_name(inputs, "inputs");
+
+    struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(weights, "weights");
+    ggml_set_param(ctx_static, weights);
+
+    struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights);
+
+    struct ggml_tensor * outputs = ggml_scale(ctx_compute, intermediary, 1.0f);
+    ggml_set_name(outputs, "outputs");
+
+    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
+    const float w0 = float(ndata)/2;
+    ggml_backend_tensor_set(weights, &w0, 0, sizeof(float));
+
+    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
+    const int32_t opt_period = nbatch_logical / nbatch_physical;
+
+    struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
+    opt_params.opt_period = opt_period;
+    if (!optimizer_defaults) {
+        opt_params.get_opt_pars = helper_get_test_opt_pars;
+    }
+    ggml_opt_context_t opt_ctx = init_opt_ctx ? ggml_opt_init(opt_params) : nullptr;
+
+    ggml_opt_result_t result  = ggml_opt_result_init();
+    ggml_opt_result_t result2 = ggml_opt_result_init();
+
+    return {datasets, data_batch, labels_batch, dataset_unsupervised, ctx_static, ctx_compute, opt_params, opt_ctx, inputs, weights, outputs, buf, result, result2};
+}
+
+static void helper_free_ctx_data(struct helper_ctx_data ctx_data) {
+    ggml_opt_result_free(ctx_data.result);
+    ggml_opt_result_free(ctx_data.result2);
+    ggml_opt_free(ctx_data.opt_ctx);
+    ggml_backend_buffer_free(ctx_data.buf);
+    ggml_free(ctx_data.ctx_static);
+    ggml_free(ctx_data.ctx_compute);
+    for (ggml_opt_dataset_t dataset : ctx_data.datasets_supervised) {
+        ggml_opt_dataset_free(dataset);
+    }
+    ggml_opt_dataset_free(ctx_data.dataset_unsupervised);
+}
+
+static void helper_after_test(
+        const char * func, const bool high_level, const std::string options,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    printf("  %s(high_level=%s%s, subtest=%s): ",
+           func, high_level ? "yes" : "no", options.c_str(), subtest.c_str());
+    if (subtest_ok) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+}
+
+static std::pair<int, int> test_dataset(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend);
+
+    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
+        ggml_opt_dataset_t dataset = cd.datasets_supervised[ndata_shard-1];
+
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+
+        for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) {
+            if (ndata_batch % ndata_shard != 0) {
+                continue;
+            }
+            bool subtest_ok = true;
+
+            struct ggml_tensor *   data_batch =   cd.data_batch[ndata_batch-1];
+            struct ggml_tensor * labels_batch = cd.labels_batch[ndata_batch-1];
+
+            std::vector<float>   data(ggml_nelements(  data_batch));
+            std::vector<float> labels(ggml_nelements(labels_batch));
+
+            std::vector<int64_t> idata_shuffled;
+            const int64_t nbatches = ndata / ndata_batch;
+            for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) {
+                ggml_opt_dataset_get_batch(dataset, data_batch, labels_batch, ibatch);
+
+                ggml_backend_tensor_get(  data_batch,   data.data(), 0, ggml_nbytes(  data_batch));
+                ggml_backend_tensor_get(labels_batch, labels.data(), 0, ggml_nbytes(labels_batch));
+
+                for (int64_t idata_batch = 0; idata_batch < ndata_batch; ++idata_batch) {
+                    const int64_t idata = ibatch*ndata_batch + idata_batch;
+                    const int64_t idata_found = data[idata_batch*ne_datapoint] / 16;
+                    subtest_ok = subtest_ok && (shuffle || idata_found == idata);
+                    idata_shuffled.push_back(idata_found);
+
+                    for (int64_t id = 0; id < ne_datapoint; ++id) {
+                        if (data[  idata_batch*ne_datapoint + id] != 16*idata_found + id) {
+                            subtest_ok = false;
+                        }
+                    }
+                    for (int64_t il = 0; il < ne_label;     ++il) {
+                        if (labels[idata_batch*ne_label     + il] != 16*(16*idata_found + il)) {
+                            subtest_ok = false;
+                        }
                     }
                 }
             }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                        }
+
+            if (!shuffle || ndata % ndata_batch == 0) {
+                const int ndata_max = (ndata / ndata_batch) * ndata_batch;
+
+                for (int64_t idata = 0; subtest_ok && idata < ndata_max; ++idata) {
+                    int ninstances = 0;
+                    for (int64_t id : idata_shuffled) {
+                        ninstances += id == idata;
                     }
+                    if (ninstances != 1) {
+                        subtest_ok = false;
+                    }
+                }
+            }
+
+            printf("  %s(shuffle=%s, ndata_shard=%" PRId64 ", ndata_batch=%" PRId64 "): ",
+                   __func__, shuffle ? "yes" : "no", ndata_shard, ndata_batch);
+            if (subtest_ok) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false,
+    /*nbatch_logical =*/ 999999, /*nbatch_physical =*/ 1);
+
+    std::vector<float> grad_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        grad_history[idata] = NAN;
+    }
+
+    for (int idata = 0; idata < ndata; ++idata) {
+        const float idataf = idata;
+        ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+        ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+        ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
+    }
+
+    {
+        bool subtest_ok = true;
+        for (int idata = 0; idata < ndata; ++idata) {
+            if (grad_history[idata] != idata + 1) {
+                subtest_ok = false;
+            }
+        }
+        printf("  %s(): ", __func__);
+        if (subtest_ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_forward_backward(
+        const char * func, const bool high_level, const bool shuffle,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", shuffle=";
+    options += shuffle ? "yes" : "no";
+    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair<int, int> test_forward_backward(
+        ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level, const bool shuffle) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
+    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
+
+    std::vector<float> loss_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
+    }
+
+    if (high_level) {
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+        ggml_opt_epoch(cd.opt_ctx, dataset, nullptr, cd.result, 0, nullptr, nullptr);
+    } else {
+        for (int idata = 0; idata < ndata; ++idata) {
+            const float idataf = idata;
+            ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+            ggml_opt_forward(cd.opt_ctx, cd.result);
+            ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+        }
+    }
+
+    {
+        float weights;
+        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+        const bool subtest_ok = weights == ndata/2;
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
+    }
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        bool subtest_ok = ndata == 6;
+
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        subtest_ok = subtest_ok && loss == 33.0 && almost_equal(loss_unc, sqrt(3.5), 1e-10);
+
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass);
+    }
+
+    float w0;
+    ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
+    for (int i = 0; i < 10; ++i) {
+        ggml_opt_forward_backward(cd.opt_ctx, nullptr);
+    }
+    ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
+
+    ggml_opt_reset(cd.opt_ctx, /*optimizer =*/ false);
+    ggml_opt_result_reset(cd.result);
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    if (high_level) {
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+        ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);
+    } else {
+        for (int idata = 0; idata < ndata; ++idata) {
+            const float idataf = idata;
+            ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+            ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+            ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+        }
+    }
+
+    {
+        float weights;
+        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+        const bool subtest_ok = weights == -ndata/2;
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
+    }
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        bool subtest_ok = ndata == 6;
+
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0);
+
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass);
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair<int, int> test_epoch_vs_fit(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
+
+    float weights_epoch;
+    float weights_fit;
+
+    {
+        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true);
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+
+        ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);
+
+        ggml_backend_tensor_get(cd.weights, &weights_epoch, 0, ggml_nbytes(cd.weights));
+        helper_free_ctx_data(cd);
+    }
+    {
+        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ false);
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+
+        ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset,
+            GGML_OPT_LOSS_TYPE_SUM, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true);
+
+        ggml_backend_tensor_get(cd.weights, &weights_fit, 0, ggml_nbytes(cd.weights));
+        helper_free_ctx_data(cd);
+    }
+
+    const bool subtest_ok = weights_epoch == weights_fit;
+
+    printf("  %s(): ", __func__);
+    if (subtest_ok) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_idata_split(
+        const char * func, const bool high_level, const int epoch,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", epoch=";
+    options += std::to_string(epoch);
+    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
+    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
+    const int idata_split = ndata * 2/3;
+
+    std::vector<float> loss_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    for (int epoch = 1; epoch <= 4; ++epoch) {
+        if (high_level) {
+            ggml_opt_epoch(cd.opt_ctx, cd.dataset_unsupervised, cd.result, cd.result2, idata_split, nullptr, nullptr);
+        } else {
+            int idata = 0;
+            for (; idata < idata_split; ++idata) {
+                const float idataf = idata;
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+            }
+            for (; idata < ndata; ++idata) {
+                const float idataf = idata;
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+                ggml_opt_forward(cd.opt_ctx, cd.result2);
+                ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+            }
+        }
+
+        {
+            float weights;
+            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+            const bool subtest_ok = weights == ndata/2 - epoch*idata_split;
+            helper_after_test_idata_split(__func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result, &ndata_result);
+            bool subtest_ok = ndata_result == idata_split;
+
+            double loss;
+            double loss_unc;
+            ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+            subtest_ok = subtest_ok && loss == 28.0 - epoch*16.0 && loss_unc == 0.0;
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_idata_split(__func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result2, &ndata_result);
+            bool subtest_ok = ndata_result == ndata - idata_split;
+
+            double loss;
+            double loss_unc;
+            ggml_opt_result_loss(cd.result2, &loss, &loss_unc);
+            subtest_ok = subtest_ok && loss == 15.0 - epoch*8 && almost_equal(loss_unc, sqrt(0.5), 1e-10);
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result2, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_idata_split(__func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass);
+        }
+
+        ggml_opt_result_reset(cd.result);
+        ggml_opt_result_reset(cd.result2);
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_gradient_accumulation(
+        const char * func, const int nbatch_physical, const enum ggml_opt_loss_type loss_type, const int epoch,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", nbatch_physical=";
+    options += std::to_string(nbatch_physical);
+    options += ", loss_type=";
+    options += loss_type == GGML_OPT_LOSS_TYPE_MEAN ? "mean" : "sum";
+    options += ", epoch=";
+    options += std::to_string(epoch);
+    helper_after_test(func, false, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair<int, int> test_gradient_accumulation(
+        ggml_backend_sched_t backend_sched, ggml_backend_t backend, const int32_t nbatch_physical, const enum ggml_opt_loss_type loss_type) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(
+        backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);
+    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
+
+    std::vector<float> grad_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        grad_history[idata] = NAN;
+    }
+
+    for (int epoch = 1; epoch <= 4; ++epoch) {
+        if (nbatch_physical == 1) {
+            for (int idata = 0; idata < ndata; ++idata) {
+                const float idataf = idata;
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float));
+                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float));
+            }
+        } else if (nbatch_physical == 2) {
+            for (int idata = 0; idata < ndata; idata += 2) {
+                const float idataf[2] = {float(idata + 0), float(idata + 1)};
+                ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float));
+                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+
+                grad_history[idata + 0] = 0.0f;
+                ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float));
+            }
+        } else {
+            GGML_ASSERT(false);
+        }
+
+        {
+            GGML_ASSERT(ndata == 6);
+            constexpr double atol = 1e-6;
+            bool subtest_ok = true;
+            if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {
+                if (nbatch_physical == 1) {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0, atol);
+                } else {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0, atol);
                 }
+                subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol);
+            } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
+                if (nbatch_physical == 1) {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0/ndata, atol);
+                } else {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0/ndata, atol);
+                }
+                subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol);
+            } else {
+                GGML_ASSERT(false);
+            }
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass);
+        }
+        {
+            float weights;
+            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+            const bool subtest_ok = weights == (ndata/2) - epoch;
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result, &ndata_result);
+            bool subtest_ok = ndata_result == ndata/nbatch_physical;
+
+            double loss;
+            ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr);
+            if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {
+                subtest_ok = subtest_ok && loss == (39.0 - epoch*6.0);
+            } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
+                subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, 1e-6);
+            } else {
+                GGML_ASSERT(false);
             }
-            break;
-        default:
-            assert(false);
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass);
+        }
+
+        ggml_opt_result_reset(cd.result);
     }
 
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static ggml_opt_optimizer_params helper_get_regression_opt_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
+    result.adamw.alpha = 0.1f;
     return result;
 }
 
-int main(void) {
-    struct ggml_init_params params = {
-        /* .mem_size   = */ 1024*1024*1024,
-        /* .mem_buffer = */ NULL,
-        /* .no_alloc   = */ false,
-    };
+static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
 
-    struct ggml_context * ctx = ggml_init(params);
+    // Test for simple regression with f(x) = a*x + b
 
-    int64_t ne1[4] = {4, 128, 1, 1};
-    int64_t ne2[4] = {4, 256, 1, 1};
-    int64_t ne3[4] = {128, 256, 1, 1};
+    constexpr int64_t ndata_regression = 201;
+    constexpr float a_true = 1.2f;
+    constexpr float b_true = 3.4f;
 
-    struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
-    struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
-    ggml_set_param(ctx, a);
-    ggml_set_param(ctx, b);
+    std::mt19937 gen(12345);
+    std::normal_distribution<float> nd{0.0f, 0.1f};
 
-    struct ggml_tensor * c = get_random_tensor(ctx, 2, ne3, -1, +1);
+    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression);
 
-    struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b);
-    struct ggml_tensor * d  = ggml_sub(ctx, c, ab);
-    struct ggml_tensor * e  = ggml_sum(ctx, ggml_sqr(ctx, d));
+    float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
+    float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
 
-    struct ggml_cgraph * ge = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
-    ggml_build_forward_expand(ge, e);
-    ggml_graph_reset(ge);
+    constexpr float x_min = -100.0f;
+    constexpr float x_max =  100.0f;
 
-    ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
+    for (int64_t idata = 0; idata < ndata_regression; ++idata) {
+        const float x = x_min + (x_max - x_min) * idata/(ndata_regression-1);
+        const float y = a_true*x + b_true + nd(gen);
 
-    const float fe = ggml_get_f32_1d(e, 0);
-    printf("%s: e = %.4f\n", __func__, fe);
+        data[idata]   = x;
+        labels[idata] = y;
+    }
 
-    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
+    struct ggml_context * ctx_static;
+    struct ggml_context * ctx_compute;
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ 3*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_static = ggml_init(params);
+    }
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_compute = ggml_init(params);
+    }
 
-    ggml_opt(ctx, opt_params, e);
+    // The first dimension is the dimension of the datapoints, the second dimension is the number of datapoints.
+    struct ggml_tensor * x = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, 1, ndata_regression);
+    ggml_set_name(x, "x");
+
+    struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(a, "a");
+    ggml_set_param(ctx_static, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(b, "b");
+    ggml_set_param(ctx_static, b);
+
+    struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b);
+    ggml_set_name(f, "f");
+    ggml_set_param(ctx_static, f);
+
+    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
+    const float a0 = 1.0f;
+    const float b0 = 3.0f;
+    ggml_backend_tensor_set(a, &a0, 0, sizeof(float));
+    ggml_backend_tensor_set(b, &b0, 0, sizeof(float));
+
+    ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
+        helper_get_regression_opt_pars, 100, ndata_regression, 0.0f, true);
+
+    {
+        float a_fit;
+        ggml_backend_tensor_get(a, &a_fit, 0, sizeof(float));
+        float b_fit;
+        ggml_backend_tensor_get(b, &b_fit, 0, sizeof(float));
+        const bool subtest_ok = almost_equal(a_fit, a_true, 1e-2) && almost_equal(b_fit, b_true, 1e-2);
+        printf("  %s(subtest=weights): ", __func__);
+        if (subtest_ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+    }
 
-    ggml_graph_reset(ge);
+    ggml_backend_buffer_free(buf);
+    ggml_free(ctx_static);
+    ggml_opt_dataset_free(dataset);
 
-    ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
+    return std::make_pair(npass, ntest);
+}
 
-    const float fe_opt = ggml_get_f32_1d(e, 0);
-    printf("%s: original  e = %.4f\n", __func__, fe);
-    printf("%s: optimized e = %.4f\n", __func__, fe_opt);
+static std::pair<int, int> test_backend(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int npass = 0;
+    int ntest = 0;
 
-    const bool success = (fe_opt <= fe);
-    assert(success);
+    for (bool shuffle : {false, true}) {
+        std::pair<int, int> partial = test_dataset(backend_sched, backend, shuffle);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    {
+        std::pair<int, int> partial = test_grad(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (bool high_level : {false, true}){
+        for (bool shuffle : {false, true}) {
+            if (!high_level && shuffle) {
+                continue;
+            }
 
-    ggml_free(ctx);
-    return success ? 0 : -1;
+            std::pair<int, int> partial = test_forward_backward(backend_sched, backend, high_level, shuffle);
+            npass += partial.first;
+            ntest += partial.second;
+        }
+    }
+    {
+        std::pair<int, int> partial = test_epoch_vs_fit(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (bool high_level : {false, true}){
+        std::pair<int, int> partial = test_idata_split(backend_sched, backend, high_level);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (int32_t nbatch_physical : {2, 1}) {
+        for (enum ggml_opt_loss_type loss_type : {GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN}) {
+            std::pair<int, int> partial = test_gradient_accumulation(backend_sched, backend, nbatch_physical, loss_type);
+            npass += partial.first;
+            ntest += partial.second;
+        }
+    }
+    {
+        std::pair<int, int> partial = test_regression(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+
+    return std::make_pair(npass, ntest);
 }
-// int64_t ne1[4] = {4, 128, 1, 1};
-// int64_t ne2[4] = {4, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 25890.9375
-// main: optimized e = 10094.7031
 
-// int64_t ne1[4] = {8, 128, 1, 1};
-// int64_t ne2[4] = {8, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 39429.5078
-// main: optimized e = 9275.8936
+int main(void) {
+    const size_t dev_count = ggml_backend_dev_count();
+    printf("Testing %zu devices\n\n", dev_count);
+    size_t n_ok = 0;
+
+    std::vector<ggml_backend_dev_t> devs;
+    std::vector<ggml_backend_t>     backends;
 
-// int64_t ne1[4] = {16, 128, 1, 1};
-// int64_t ne2[4] = {16, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 68371.1328
-// main: optimized e = 7854.4502
+    for (size_t i = 0; i < dev_count; ++i) {
+        devs.push_back(ggml_backend_dev_get(i));
 
+        ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL);
+        GGML_ASSERT(backend != NULL);
 
-// int64_t ne1[4] = {32, 128, 1, 1};
-// int64_t ne2[4] = {32, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 126061.1953
-// main: optimized e = 5451.0166
+        if (ggml_backend_is_cpu(backend)) {
+            ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
+        }
+
+        backends.push_back(backend);
+    }
 
-// int64_t ne1[4] = {4, 1024, 1, 1};
-// int64_t ne2[4] = {4, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 1620817.8750
-// main: optimized e = 698387.6875
+    for (size_t i = 0; i < dev_count; ++i) {
+        // Put the backend to be tested in front so that it's prioritized:
+        std::vector<ggml_backend_t> backends_modded = {backends[i]};
+        backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
 
-// another run on M1
-// int64_t ne1[4] = {4, 1024, 1, 1};
-// int64_t ne2[4] = {4, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 1629595.6250
-// main: optimized e = 698169.1250
+        ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
+            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false);
 
-// int64_t ne1[4] = {32, 1024, 1, 1};
-// int64_t ne2[4] = {32, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 8146770.5000
-// main: optimized e = 651119.1250
+        printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
+        printf("  Device description: %s\n", ggml_backend_dev_description(devs[i]));
+        size_t free, total; // NOLINT
+        ggml_backend_dev_memory(devs[i], &free, &total);
+        printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
+        printf("\n");
+
+        std::pair<int, int> result = test_backend(backend_sched, backends[i]);
+
+        printf("  %d/%d tests passed\n", result.first, result.second);
+        printf("  Backend %s: ", ggml_backend_name(backends[i]));
+        if (result.first == result.second) {
+            printf("\033[1;32mOK\033[0m\n");
+            n_ok++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+
+        printf("\n");
+
+        ggml_backend_sched_free(backend_sched);
+    }
+
+    for (ggml_backend_t backend : backends) {
+        ggml_backend_free(backend);
+    }
+
+    printf("%zu/%zu backends passed\n", n_ok, dev_count);
+    if (n_ok != dev_count) {
+        printf("\033[1;31mFAIL\033[0m\n");
+        return 1;
+    }
+    printf("\033[1;32mOK\033[0m\n");
+    return 0;
+}
diff --git a/tests/test1.c b/tests/test1.c
deleted file mode 100644 (file)
index 1a2db19..0000000
+++ /dev/null
@@ -1,459 +0,0 @@
-#include "ggml.h"
-#include "ggml-cpu.h"
-
-#include <stdio.h>
-#include <stdlib.h>
-
-int main(int argc, const char ** argv) {
-    const int n_threads = 2;
-
-    struct ggml_init_params params = {
-        .mem_size   = 128*1024*1024,
-        .mem_buffer = NULL,
-        .no_alloc   = false,
-    };
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    {
-        struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-
-        ggml_set_param(ctx0, x);
-
-        struct ggml_tensor * a = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        struct ggml_tensor * b = ggml_mul(ctx0, x, x);
-        struct ggml_tensor * f = ggml_mul(ctx0, b, a);
-
-        // a*x^2
-        // 2*a*x
-
-        ggml_print_objects(ctx0);
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-        ggml_build_forward_expand(gf, f);
-        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
-
-        ggml_set_f32(x, 2.0f);
-        ggml_set_f32(a, 3.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(f->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("f     = %f\n", ggml_get_f32_1d(f, 0));
-        printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
-
-        GGML_ASSERT(ggml_get_f32_1d(f, 0)       == 12.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x->grad, 0) == 12.0f);
-
-        ggml_set_f32(x, 3.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(f->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("f     = %f\n", ggml_get_f32_1d(f, 0));
-        printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
-
-        GGML_ASSERT(ggml_get_f32_1d(f, 0)       == 27.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x->grad, 0) == 18.0f);
-
-        ggml_graph_dump_dot(gf, NULL, "test1-1-forward.dot");
-        ggml_graph_dump_dot(gb, gf,   "test1-1-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        struct ggml_tensor * x3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-
-        ggml_set_f32(x1, 3.0f);
-        ggml_set_f32(x2, 1.0f);
-        ggml_set_f32(x3, 0.0f);
-
-        ggml_set_param(ctx0, x1);
-        ggml_set_param(ctx0, x2);
-
-        struct ggml_tensor * y = ggml_add(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x1, x2));
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-        ggml_build_forward_expand(gf, y);
-        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(y->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
-        printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
-        printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0));
-
-        GGML_ASSERT(ggml_get_f32_1d(y, 0)        == 12.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 7.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 3.0f);
-
-        struct ggml_tensor * g1 = x1->grad;
-        struct ggml_tensor * g2 = x2->grad;
-
-        struct ggml_cgraph * gbb = ggml_graph_dup(ctx0, gb);
-
-        ggml_build_backward_expand(ctx0, gb, gbb, false);
-
-        ggml_graph_reset(gb);
-        ggml_set_f32(g1->grad, 1.0f);
-        ggml_set_f32(g2->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gbb, n_threads);
-
-        printf("H * [1, 1] = [ %f %f ]\n", ggml_get_f32_1d(x1->grad, 0), ggml_get_f32_1d(x2->grad, 0));
-
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 3.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 1.0f);
-
-        ggml_graph_dump_dot(gf, NULL, "test1-2-forward.dot");
-        ggml_graph_dump_dot(gb, gf,   "test1-2-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-
-        ggml_set_param(ctx0, x1);
-        ggml_set_param(ctx0, x2);
-
-        struct ggml_tensor * y = ggml_mul(ctx0, ggml_add(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x1, x2)), x1);
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-        ggml_build_forward_expand(gf, y);
-        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
-
-        ggml_set_f32(x1, 3.0f);
-        ggml_set_f32(x2, 4.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(y->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
-        printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
-        printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0));
-
-        GGML_ASSERT(ggml_get_f32_1d(y, 0)        == 63.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 51.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 9.0f);
-
-        ggml_graph_dump_dot(gf, NULL, "test1-3-forward.dot");
-        ggml_graph_dump_dot(gb, gf,   "test1-3-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        struct ggml_tensor * x3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-
-        ggml_set_param(ctx0, x1);
-        ggml_set_param(ctx0, x2);
-        ggml_set_param(ctx0, x3);
-
-        struct ggml_tensor * y = ggml_mul(ctx0, ggml_mul(ctx0, ggml_mul(ctx0, x1, x1), ggml_mul(ctx0, x2, x2)), x3);
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-        ggml_build_forward_expand(gf, y);
-        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
-
-        ggml_set_f32(x1, 1.0f);
-        ggml_set_f32(x2, 2.0f);
-        ggml_set_f32(x3, 3.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(y->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
-        printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
-        printf("df/dx2 = %f\n", ggml_get_f32_1d(x2->grad, 0));
-        printf("df/dx3 = %f\n", ggml_get_f32_1d(x3->grad, 0));
-
-        GGML_ASSERT(ggml_get_f32_1d(y, 0)        == 12.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 24.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 12.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x3->grad, 0) == 4.0f);
-
-        struct ggml_tensor * g1 = x1->grad;
-        struct ggml_tensor * g2 = x2->grad;
-        struct ggml_tensor * g3 = x3->grad;
-
-        struct ggml_cgraph * gbb = ggml_graph_dup(ctx0, gb);
-
-        ggml_build_backward_expand(ctx0, gb, gbb, false);
-
-        ggml_graph_reset(gb);
-        ggml_set_f32(g1->grad, 1.0f);
-        ggml_set_f32(g2->grad, 1.0f);
-        ggml_set_f32(g3->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gbb, n_threads);
-
-        printf("H * [1, 1, 1] = [ %f %f %f ]\n",
-                ggml_get_f32_1d(x1->grad, 0),
-                ggml_get_f32_1d(x2->grad, 0),
-                ggml_get_f32_1d(x3->grad, 0));
-
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 56.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 34.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x3->grad, 0) == 12.0f);
-
-        ggml_graph_dump_dot(gf, NULL, "test1-4-forward.dot");
-        ggml_graph_dump_dot(gb, gf,   "test1-4-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
-        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
-
-        ggml_set_param(ctx0, x1);
-        ggml_set_param(ctx0, x2);
-
-        struct ggml_tensor * y = ggml_sum(ctx0, ggml_mul(ctx0, x1, x2));
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-        ggml_build_forward_expand(gf, y);
-        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
-
-        ggml_set_f32(x1, 3.0f);
-        ggml_set_f32(x2, 5.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(y->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
-        printf("df/dx1 = %f %f %f\n",
-                ggml_get_f32_1d(x1->grad, 0),
-                ggml_get_f32_1d(x1->grad, 1),
-                ggml_get_f32_1d(x1->grad, 2));
-        printf("df/dx2 = %f %f %f\n",
-                ggml_get_f32_1d(x2->grad, 0),
-                ggml_get_f32_1d(x2->grad, 1),
-                ggml_get_f32_1d(x2->grad, 2));
-
-        GGML_ASSERT(ggml_get_f32_1d(y, 0)        == 45.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 5.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == 3.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == 5.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) == 3.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == 5.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == 3.0f);
-
-        ggml_graph_dump_dot(gf, NULL, "test1-5-forward.dot");
-        ggml_graph_dump_dot(gb, gf,   "test1-5-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
-        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
-
-        ggml_set_param(ctx0, x1);
-        ggml_set_param(ctx0, x2);
-
-        struct ggml_tensor * y =
-            ggml_sum(ctx0,
-                    ggml_add(ctx0,
-                        ggml_mul(ctx0, x1, x2),
-                        ggml_mul(ctx0,
-                            ggml_repeat(ctx0, ggml_new_f32(ctx0, -2.0f), x1),
-                            ggml_mul(ctx0, x1, x1)
-                            )
-                        )
-                    );
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-        ggml_build_forward_expand(gf, y);
-        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
-
-        ggml_set_f32(x1, 3.0f);
-        ggml_set_f32(x2, 5.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(y->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
-        printf("df/dx1 = %f %f %f\n",
-                ggml_get_f32_1d(x1->grad, 0),
-                ggml_get_f32_1d(x1->grad, 1),
-                ggml_get_f32_1d(x1->grad, 2));
-        printf("df/dx2 = %f %f %f\n",
-                ggml_get_f32_1d(x2->grad, 0),
-                ggml_get_f32_1d(x2->grad, 1),
-                ggml_get_f32_1d(x2->grad, 2));
-
-        GGML_ASSERT(ggml_get_f32_1d(y, 0)              == -9.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == -7.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == -7.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == -7.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) ==  3.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) ==  3.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) ==  3.0f);
-
-        ggml_graph_dump_dot(gf, NULL, "test1-6-forward.dot");
-        ggml_graph_dump_dot(gb, gf,   "test1-6-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
-        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
-
-        ggml_set_param(ctx0, x1);
-        ggml_set_param(ctx0, x2);
-
-        struct ggml_tensor * y =
-            ggml_sum(ctx0,
-                    ggml_sub(ctx0,
-                        ggml_mul(ctx0, x1, x2),
-                        ggml_mul(ctx0,
-                            ggml_mul(ctx0, x1, x1),
-                            ggml_repeat(ctx0, ggml_new_f32(ctx0, -2.0f), x1)
-                            )
-                        )
-                    );
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-        ggml_build_forward_expand(gf, y);
-        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
-
-        ggml_set_f32(x1, 3.0f);
-        ggml_set_f32(x2, 5.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(y->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
-        printf("df/dx1 = %f %f %f\n",
-                ggml_get_f32_1d(x1->grad, 0),
-                ggml_get_f32_1d(x1->grad, 1),
-                ggml_get_f32_1d(x1->grad, 2));
-        printf("df/dx2 = %f %f %f\n",
-                ggml_get_f32_1d(x2->grad, 0),
-                ggml_get_f32_1d(x2->grad, 1),
-                ggml_get_f32_1d(x2->grad, 2));
-
-        GGML_ASSERT(ggml_get_f32_1d(y, 0)        == 99.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == 17.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == 17.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == 17.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) ==  3.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) ==  3.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) ==  3.0f);
-
-        ggml_graph_dump_dot(gf, NULL, "test1-7-forward.dot");
-        ggml_graph_dump_dot(gb, gf,   "test1-7-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        struct ggml_tensor * x1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
-        struct ggml_tensor * x2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 3);
-
-        ggml_set_param(ctx0, x1);
-        ggml_set_param(ctx0, x2);
-
-        struct ggml_tensor * y =
-            ggml_abs(ctx0,
-                    ggml_sub(ctx0, x1, x2)
-                    );
-
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-        ggml_build_forward_expand(gf, y);
-        struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
-
-        ggml_set_f32(x1, 3.0f);
-        ggml_set_f32(x2, 5.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(y->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
-        printf("df/dx1 = %f %f %f\n",
-                ggml_get_f32_1d(x1->grad, 0),
-                ggml_get_f32_1d(x1->grad, 1),
-                ggml_get_f32_1d(x1->grad, 2));
-        printf("df/dx2 = %f %f %f\n",
-                ggml_get_f32_1d(x2->grad, 0),
-                ggml_get_f32_1d(x2->grad, 1),
-                ggml_get_f32_1d(x2->grad, 2));
-
-        GGML_ASSERT(ggml_get_f32_1d(y, 0)        ==  2.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) == -1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) == -1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) == -1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) ==  1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) ==  1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) ==  1.0f);
-
-        ggml_set_f32(x1, 7.0f);
-        ggml_set_f32(x2, 5.0f);
-
-        ggml_graph_reset(gf);
-        ggml_set_f32(y->grad, 1.0f);
-
-        ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-        printf("y      = %f\n", ggml_get_f32_1d(y, 0));
-        printf("df/dx1 = %f %f %f\n",
-                ggml_get_f32_1d(x1->grad, 0),
-                ggml_get_f32_1d(x1->grad, 1),
-                ggml_get_f32_1d(x1->grad, 2));
-        printf("df/dx2 = %f %f %f\n",
-                ggml_get_f32_1d(x2->grad, 0),
-                ggml_get_f32_1d(x2->grad, 1),
-                ggml_get_f32_1d(x2->grad, 2));
-
-        GGML_ASSERT(ggml_get_f32_1d(y, 0)        ==  2.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 0) ==  1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 1) ==  1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x1->grad, 2) ==  1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 0) == -1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 1) == -1.0f);
-        GGML_ASSERT(ggml_get_f32_1d(x2->grad, 2) == -1.0f);
-
-        ggml_graph_dump_dot(gf, NULL, "test1-8-forward.dot");
-        ggml_graph_dump_dot(gb, gf,   "test1-8-backward.dot");
-    }
-
-    ggml_free(ctx0);
-
-    return 0;
-}
diff --git a/tests/test1.zig b/tests/test1.zig
deleted file mode 100644 (file)
index 815d814..0000000
+++ /dev/null
@@ -1,450 +0,0 @@
-const std = @import("std");
-const c = @cImport({
-    @cInclude("ggml.h");
-});
-
-pub fn main() !void {
-    const n_threads = 2;
-
-    const params = .{
-        .mem_size = 128 * 1024 * 1024,
-        .mem_buffer = null,
-        .no_alloc = false,
-    };
-
-    const ctx0 = c.ggml_init(params);
-    defer c.ggml_free(ctx0);
-
-    {
-        const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-
-        c.ggml_set_param(ctx0, x);
-
-        const a = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-        const b = c.ggml_mul(ctx0, x, x);
-        const f = c.ggml_mul(ctx0, b, a);
-
-        // a*x^2
-        // 2*a*x
-
-        c.ggml_print_objects(ctx0);
-
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
-        c.ggml_build_forward_expand(gf, f);
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
-
-        _ = c.ggml_set_f32(x, 2.0);
-        _ = c.ggml_set_f32(a, 3.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(f.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("f     = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});
-        std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});
-
-        try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 12.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 12.0);
-
-        _ = c.ggml_set_f32(x, 3.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(f.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("f     = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});
-        std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});
-
-        try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 27.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 18.0);
-
-        c.ggml_graph_dump_dot(gf, null, "test1-1-forward.dot");
-        c.ggml_graph_dump_dot(gb, gf, "test1-1-backward.dot");
-    }
-
-    /////////////////////////////////////////////////////////////
-
-    {
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-        const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-
-        _ = c.ggml_set_f32(x1, 3.0);
-        _ = c.ggml_set_f32(x2, 1.0);
-        _ = c.ggml_set_f32(x3, 0.0);
-
-        c.ggml_set_param(ctx0, x1);
-        c.ggml_set_param(ctx0, x2);
-
-        const y = c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2));
-
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
-        c.ggml_build_forward_expand(gf, y);
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(y.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
-        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});
-        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});
-
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 7.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
-
-        const g1 = x1.*.grad;
-        const g2 = x2.*.grad;
-
-        const gbb = c.ggml_graph_dup(ctx0, @constCast(gb));
-
-        c.ggml_build_backward_expand(ctx0, @constCast(gb), @constCast(gbb), true);
-
-        c.ggml_graph_reset(@constCast(gb));
-        _ = c.ggml_set_f32(g1.*.grad, 1.0);
-        _ = c.ggml_set_f32(g2.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gbb), n_threads);
-
-        std.debug.print("H * [1, 1] = [ {d:.6} {d:.6} ]\n", .{ c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0) });
-
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 3.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0);
-
-        c.ggml_graph_dump_dot(gf, null, "test1-2-forward.dot");
-        c.ggml_graph_dump_dot(gb, gf, "test1-2-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-
-        c.ggml_set_param(ctx0, x1);
-        c.ggml_set_param(ctx0, x2);
-
-        const y = c.ggml_mul(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2)), x1);
-
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
-        c.ggml_build_forward_expand(gf, y);
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
-
-        _ = c.ggml_set_f32(x1, 3.0);
-        _ = c.ggml_set_f32(x2, 4.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(y.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
-        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});
-        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});
-
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 63.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 51.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 9.0);
-
-        c.ggml_graph_dump_dot(gf, null, "test1-3-forward.dot");
-        c.ggml_graph_dump_dot(gb, gf, "test1-3-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-        const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
-
-        c.ggml_set_param(ctx0, x1);
-        c.ggml_set_param(ctx0, x2);
-        c.ggml_set_param(ctx0, x3);
-
-        const y = c.ggml_mul(ctx0, c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x2, x2)), x3);
-
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
-        c.ggml_build_forward_expand(gf, y);
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
-
-        _ = c.ggml_set_f32(x1, 1.0);
-        _ = c.ggml_set_f32(x2, 2.0);
-        _ = c.ggml_set_f32(x3, 3.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(y.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
-        std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});
-        std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});
-        std.debug.print("df/dx3 = {d:.6}\n", .{c.ggml_get_f32_1d(x3.*.grad, 0)});
-
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 24.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 12.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 4.0);
-
-        const g1 = x1.*.grad;
-        const g2 = x2.*.grad;
-        const g3 = x3.*.grad;
-
-        const gbb = c.ggml_graph_dup(ctx0, @constCast(gb));
-
-        c.ggml_build_backward_expand(ctx0, @constCast(gb), @constCast(gbb), true);
-
-        c.ggml_graph_reset(@constCast(gb));
-        _ = c.ggml_set_f32(g1.*.grad, 1.0);
-        _ = c.ggml_set_f32(g2.*.grad, 1.0);
-        _ = c.ggml_set_f32(g3.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gbb), n_threads);
-
-        std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n", .{
-            c.ggml_get_f32_1d(x1.*.grad, 0),
-            c.ggml_get_f32_1d(x2.*.grad, 0),
-            c.ggml_get_f32_1d(x3.*.grad, 0),
-        });
-
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 56.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 34.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x3.*.grad, 0) == 12.0);
-
-        c.ggml_graph_dump_dot(gf, null, "test1-4-forward.dot");
-        c.ggml_graph_dump_dot(gb, gf, "test1-4-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
-
-        c.ggml_set_param(ctx0, x1);
-        c.ggml_set_param(ctx0, x2);
-
-        const y = c.ggml_sum(ctx0, c.ggml_mul(ctx0, x1, x2));
-
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
-        c.ggml_build_forward_expand(gf, y);
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
-
-        _ = c.ggml_set_f32(x1, 3.0);
-        _ = c.ggml_set_f32(x2, 5.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(y.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x1.*.grad, 0),
-            c.ggml_get_f32_1d(x1.*.grad, 1),
-            c.ggml_get_f32_1d(x1.*.grad, 2),
-        });
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x2.*.grad, 0),
-            c.ggml_get_f32_1d(x2.*.grad, 1),
-            c.ggml_get_f32_1d(x2.*.grad, 2),
-        });
-
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 45.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 5.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 5.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 5.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);
-
-        c.ggml_graph_dump_dot(gf, null, "test1-5-forward.dot");
-        c.ggml_graph_dump_dot(gb, gf, "test1-5-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
-
-        c.ggml_set_param(ctx0, x1);
-        c.ggml_set_param(ctx0, x2);
-
-        const y =
-            c.ggml_sum(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x2), c.ggml_mul(ctx0, c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1), c.ggml_mul(ctx0, x1, x1))));
-
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
-        c.ggml_build_forward_expand(gf, y);
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
-
-        _ = c.ggml_set_f32(x1, 3.0);
-        _ = c.ggml_set_f32(x2, 5.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(y.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x1.*.grad, 0),
-            c.ggml_get_f32_1d(x1.*.grad, 1),
-            c.ggml_get_f32_1d(x1.*.grad, 2),
-        });
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x2.*.grad, 0),
-            c.ggml_get_f32_1d(x2.*.grad, 1),
-            c.ggml_get_f32_1d(x2.*.grad, 2),
-        });
-
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == -9.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -7.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -7.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -7.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);
-
-        c.ggml_graph_dump_dot(gf, null, "test1-6-forward.dot");
-        c.ggml_graph_dump_dot(gb, gf, "test1-6-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
-
-        c.ggml_set_param(ctx0, x1);
-        c.ggml_set_param(ctx0, x2);
-
-        const y =
-            c.ggml_sum(ctx0, c.ggml_sub(ctx0, c.ggml_mul(ctx0, x1, x2), c.ggml_mul(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_repeat(ctx0, c.ggml_new_f32(ctx0, -2.0), x1))));
-
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
-        c.ggml_build_forward_expand(gf, y);
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
-
-        _ = c.ggml_set_f32(x1, 3.0);
-        _ = c.ggml_set_f32(x2, 5.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(y.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x1.*.grad, 0),
-            c.ggml_get_f32_1d(x1.*.grad, 1),
-            c.ggml_get_f32_1d(x1.*.grad, 2),
-        });
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x2.*.grad, 0),
-            c.ggml_get_f32_1d(x2.*.grad, 1),
-            c.ggml_get_f32_1d(x2.*.grad, 2),
-        });
-
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 99.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 17.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 17.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 17.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 3.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 3.0);
-
-        c.ggml_graph_dump_dot(gf, null, "test1-7-forward.dot");
-        c.ggml_graph_dump_dot(gb, gf, "test1-7-backward.dot");
-    }
-
-    ///////////////////////////////////////////////////////////////
-
-    {
-        const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
-        const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 3);
-
-        c.ggml_set_param(ctx0, x1);
-        c.ggml_set_param(ctx0, x2);
-
-        const y =
-            c.ggml_abs(ctx0, c.ggml_sub(ctx0, x1, x2));
-
-        const gf = c.ggml_new_graph_custom(ctx0, c.GGML_DEFAULT_GRAPH_SIZE, true);
-        c.ggml_build_forward_expand(gf, y);
-        const gb = c.ggml_graph_dup(ctx0, @constCast(gf));
-        c.ggml_build_backward_expand(ctx0, @constCast(gf), @constCast(gb), false);
-
-        _ = c.ggml_set_f32(x1, 3.0);
-        _ = c.ggml_set_f32(x2, 5.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(y.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x1.*.grad, 0),
-            c.ggml_get_f32_1d(x1.*.grad, 1),
-            c.ggml_get_f32_1d(x1.*.grad, 2),
-        });
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x2.*.grad, 0),
-            c.ggml_get_f32_1d(x2.*.grad, 1),
-            c.ggml_get_f32_1d(x2.*.grad, 2),
-        });
-
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == -1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == -1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == -1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == 1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == 1.0);
-
-        _ = c.ggml_set_f32(x1, 7.0);
-        _ = c.ggml_set_f32(x2, 5.0);
-
-        c.ggml_graph_reset(@constCast(gf));
-        _ = c.ggml_set_f32(y.*.grad, 1.0);
-
-        _ = c.ggml_graph_compute_with_ctx(ctx0, @constCast(gb), n_threads);
-
-        std.debug.print("y      = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
-        std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x1.*.grad, 0),
-            c.ggml_get_f32_1d(x1.*.grad, 1),
-            c.ggml_get_f32_1d(x1.*.grad, 2),
-        });
-        std.debug.print("df/dx2 = {d:.6} {d:.6} {d:.6}\n", .{
-            c.ggml_get_f32_1d(x2.*.grad, 0),
-            c.ggml_get_f32_1d(x2.*.grad, 1),
-            c.ggml_get_f32_1d(x2.*.grad, 2),
-        });
-
-        try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 2.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 1) == 1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 2) == 1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == -1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 1) == -1.0);
-        try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 2) == -1.0);
-
-        c.ggml_graph_dump_dot(gf, null, "test1-8-forward.dot");
-        c.ggml_graph_dump_dot(gb, gf, "test1-8-backward.dot");
-    }
-
-    _ = try std.io.getStdIn().reader().readByte();
-}
diff --git a/tests/test2.c b/tests/test2.c
deleted file mode 100644 (file)
index 76c044f..0000000
+++ /dev/null
@@ -1,182 +0,0 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
-#include "ggml.h"
-#include "ggml-cpu.h"
-
-#include <math.h>
-#include <stdio.h>
-#include <stdlib.h>
-
-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
-bool is_close(float a, float b, float epsilon) {
-    return fabs(a - b) < epsilon;
-}
-
-int main(int argc, const char ** argv) {
-    struct ggml_init_params params = {
-        .mem_size   = 128*1024*1024,
-        .mem_buffer = NULL,
-        .no_alloc   = false,
-    };
-
-    //struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
-    //opt_params.adam.alpha = 0.01f;
-
-    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_LBFGS);
-
-    // original threads: 8
-    int nthreads = 8;
-    const char *env = getenv("GGML_NTHREADS");
-    if (env != NULL) {
-        nthreads = atoi(env);
-    }
-    if (argc > 1) {
-        nthreads = atoi(argv[1]);
-    }
-    opt_params.n_threads = nthreads;
-    printf("test2: n_threads:%d\n", opt_params.n_threads);
-
-    const float xi[] = {  1.0f,  2.0f,  3.0f,  4.0f,  5.0f , 6.0f,  7.0f,  8.0f,  9.0f,  10.0f, };
-          float yi[] = { 15.0f, 25.0f, 35.0f, 45.0f, 55.0f, 65.0f, 75.0f, 85.0f, 95.0f, 105.0f, };
-
-    const int n = sizeof(xi)/sizeof(xi[0]);
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    struct ggml_tensor * x = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n);
-    struct ggml_tensor * y = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n);
-
-    for (int i = 0; i < n; i++) {
-        ((float *) x->data)[i] = xi[i];
-        ((float *) y->data)[i] = yi[i];
-    }
-
-    {
-        struct ggml_tensor * t0 = ggml_new_f32(ctx0, 0.0f);
-        struct ggml_tensor * t1 = ggml_new_f32(ctx0, 0.0f);
-
-        // initialize auto-diff parameters:
-        ggml_set_param(ctx0, t0);
-        ggml_set_param(ctx0, t1);
-
-        // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n)
-        struct ggml_tensor * f =
-            ggml_div(ctx0,
-                    ggml_sum(ctx0,
-                        ggml_sqr(ctx0,
-                            ggml_sub(ctx0,
-                                ggml_add(ctx0,
-                                    ggml_mul(ctx0, x, ggml_repeat(ctx0, t1, x)),
-                                    ggml_repeat(ctx0, t0, x)),
-                                y)
-                            )
-                        ),
-                    ggml_new_f32(ctx0, 2.0f*n));
-
-        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
-
-        printf("t0 = %f\n", ggml_get_f32_1d(t0, 0));
-        printf("t1 = %f\n", ggml_get_f32_1d(t1, 0));
-
-        GGML_ASSERT(res == GGML_OPT_RESULT_OK);
-
-        GGML_ASSERT(is_close(ggml_get_f32_1d(t0, 0),  5.0f, 1e-3f));
-        GGML_ASSERT(is_close(ggml_get_f32_1d(t1, 0), 10.0f, 1e-3f));
-    }
-
-    {
-        struct ggml_tensor * t0 = ggml_new_f32(ctx0, -1.0f);
-        struct ggml_tensor * t1 = ggml_new_f32(ctx0,  9.0f);
-
-        ggml_set_param(ctx0, t0);
-        ggml_set_param(ctx0, t1);
-
-        // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n
-        struct ggml_tensor * f =
-            ggml_mul(ctx0,
-                    ggml_new_f32(ctx0, 1.0/(2*n)),
-                    ggml_sum(ctx0,
-                        ggml_abs(ctx0,
-                            ggml_sub(ctx0,
-                                ggml_add(ctx0,
-                                    ggml_mul(ctx0, x, ggml_repeat(ctx0, t1, x)),
-                                    ggml_repeat(ctx0, t0, x)),
-                                y)
-                            )
-                        )
-                    );
-
-
-        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
-
-        GGML_ASSERT(res == GGML_OPT_RESULT_OK);
-        GGML_ASSERT(is_close(ggml_get_f32_1d(t0, 0),  5.0f, 1e-2f));
-        GGML_ASSERT(is_close(ggml_get_f32_1d(t1, 0), 10.0f, 1e-2f));
-    }
-
-    {
-        struct ggml_tensor * t0 = ggml_new_f32(ctx0,  5.0f);
-        struct ggml_tensor * t1 = ggml_new_f32(ctx0, -4.0f);
-
-        ggml_set_param(ctx0, t0);
-        ggml_set_param(ctx0, t1);
-
-        // f = t0^2 + t1^2
-        struct ggml_tensor * f =
-            ggml_add(ctx0,
-                    ggml_sqr(ctx0, t0),
-                    ggml_sqr(ctx0, t1)
-                    );
-
-        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
-
-        GGML_ASSERT(res == GGML_OPT_RESULT_OK);
-        GGML_ASSERT(is_close(ggml_get_f32_1d(f,  0), 0.0f, 1e-3f));
-        GGML_ASSERT(is_close(ggml_get_f32_1d(t0, 0), 0.0f, 1e-3f));
-        GGML_ASSERT(is_close(ggml_get_f32_1d(t1, 0), 0.0f, 1e-3f));
-    }
-
-    /////////////////////////////////////////
-
-    {
-        struct ggml_tensor * t0 = ggml_new_f32(ctx0, -7.0f);
-        struct ggml_tensor * t1 = ggml_new_f32(ctx0,  8.0f);
-
-        ggml_set_param(ctx0, t0);
-        ggml_set_param(ctx0, t1);
-
-        // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2
-        struct ggml_tensor * f =
-            ggml_add(ctx0,
-                    ggml_sqr(ctx0,
-                        ggml_sub(ctx0,
-                            ggml_add(ctx0,
-                                t0,
-                                ggml_mul(ctx0, t1, ggml_new_f32(ctx0, 2.0f))),
-                            ggml_new_f32(ctx0, 7.0f)
-                            )
-                        ),
-                    ggml_sqr(ctx0,
-                        ggml_sub(ctx0,
-                            ggml_add(ctx0,
-                                ggml_mul(ctx0, t0, ggml_new_f32(ctx0, 2.0f)),
-                                t1),
-                            ggml_new_f32(ctx0, 5.0f)
-                            )
-                        )
-                    );
-
-        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
-
-        GGML_ASSERT(res == GGML_OPT_RESULT_OK);
-        GGML_ASSERT(is_close(ggml_get_f32_1d(f,  0), 0.0f, 1e-3f));
-        GGML_ASSERT(is_close(ggml_get_f32_1d(t0, 0), 1.0f, 1e-3f));
-        GGML_ASSERT(is_close(ggml_get_f32_1d(t1, 0), 3.0f, 1e-3f));
-    }
-
-    ggml_free(ctx0);
-
-    return 0;
-}
diff --git a/tests/test2.zig b/tests/test2.zig
deleted file mode 100644 (file)
index 783eba6..0000000
+++ /dev/null
@@ -1,123 +0,0 @@
-const std = @import("std");
-const Thread = std.Thread;
-const c = @cImport({
-    @cInclude("ggml.h");
-});
-
-fn is_close(a: f32, b: f32, epsilon: f32) bool {
-    return @abs(a - b) < epsilon;
-}
-
-pub fn main() !void {
-    const params = .{
-        .mem_size = 128 * 1024 * 1024,
-        .mem_buffer = null,
-        .no_alloc = false,
-    };
-
-    var opt_params = c.ggml_opt_default_params(c.GGML_OPT_TYPE_LBFGS);
-
-    const nthreads = try Thread.getCpuCount();
-    opt_params.n_threads = @intCast(nthreads);
-    std.debug.print("test2: n_threads:{}\n", .{opt_params.n_threads});
-
-    const xi = [_]f32{ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 };
-    const yi = [_]f32{ 15.0, 25.0, 35.0, 45.0, 55.0, 65.0, 75.0, 85.0, 95.0, 105.0 };
-
-    const n = xi.len;
-
-    const ctx0 = c.ggml_init(params);
-    defer c.ggml_free(ctx0);
-
-    const x = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n);
-    const y = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, n);
-
-    for (0..n) |i| {
-        const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data));
-        x_data_pointer[i] = xi[i];
-        const y_data_pointer: [*]f32 = @ptrCast(@alignCast(y.*.data));
-        y_data_pointer[i] = yi[i];
-    }
-
-    {
-        const t0 = c.ggml_new_f32(ctx0, 0.0);
-        const t1 = c.ggml_new_f32(ctx0, 0.0);
-
-        // initialize auto-diff parameters:
-        _ = c.ggml_set_param(ctx0, t0);
-        _ = c.ggml_set_param(ctx0, t1);
-
-        // f = sum_i[(t0 + t1*x_i - y_i)^2]/(2n)
-        const f =
-            c.ggml_div(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), c.ggml_repeat(ctx0, t0, x)), y))), c.ggml_new_f32(ctx0, @as(f32, 2.0) * n));
-
-        const res = c.ggml_opt(null, opt_params, f);
-
-        std.debug.print("t0 = {d:.6}\n", .{c.ggml_get_f32_1d(t0, 0)});
-        std.debug.print("t1 = {d:.6}\n", .{c.ggml_get_f32_1d(t1, 0)});
-
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-3));
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-3));
-    }
-
-    {
-        const t0 = c.ggml_new_f32(ctx0, -1.0);
-        const t1 = c.ggml_new_f32(ctx0, 9.0);
-
-        _ = c.ggml_set_param(ctx0, t0);
-        _ = c.ggml_set_param(ctx0, t1);
-
-        // f = 0.5*sum_i[abs(t0 + t1*x_i - y_i)]/n
-        const f =
-            c.ggml_mul(ctx0, c.ggml_new_f32(ctx0, @as(f32, 1.0) / (2 * n)), c.ggml_sum(ctx0, c.ggml_abs(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, x, c.ggml_repeat(ctx0, t1, x)), c.ggml_repeat(ctx0, t0, x)), y))));
-
-        const res = c.ggml_opt(null, opt_params, f);
-
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 5.0, 1e-2));
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 10.0, 1e-2));
-    }
-
-    {
-        const t0 = c.ggml_new_f32(ctx0, 5.0);
-        const t1 = c.ggml_new_f32(ctx0, -4.0);
-
-        _ = c.ggml_set_param(ctx0, t0);
-        _ = c.ggml_set_param(ctx0, t1);
-
-        // f = t0^2 + t1^2
-        const f =
-            c.ggml_add(ctx0, c.ggml_sqr(ctx0, t0), c.ggml_sqr(ctx0, t1));
-
-        const res = c.ggml_opt(null, opt_params, f);
-
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3));
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 0.0, 1e-3));
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 0.0, 1e-3));
-    }
-
-    /////////////////////////////////////////
-
-    {
-        const t0 = c.ggml_new_f32(ctx0, -7.0);
-        const t1 = c.ggml_new_f32(ctx0, 8.0);
-
-        _ = c.ggml_set_param(ctx0, t0);
-        _ = c.ggml_set_param(ctx0, t1);
-
-        // f = (t0 + 2*t1 - 7)^2 + (2*t0 + t1 - 5)^2
-        const f =
-            c.ggml_add(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, t0, c.ggml_mul(ctx0, t1, c.ggml_new_f32(ctx0, 2.0))), c.ggml_new_f32(ctx0, 7.0))), c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_add(ctx0, c.ggml_mul(ctx0, t0, c.ggml_new_f32(ctx0, 2.0)), t1), c.ggml_new_f32(ctx0, 5.0))));
-
-        const res = c.ggml_opt(null, opt_params, f);
-
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(f, 0), 0.0, 1e-3));
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t0, 0), 1.0, 1e-3));
-        try std.testing.expect(is_close(c.ggml_get_f32_1d(t1, 0), 3.0, 1e-3));
-    }
-
-    _ = try std.io.getStdIn().reader().readByte();
-}
diff --git a/tests/test3.c b/tests/test3.c
deleted file mode 100644 (file)
index d1e9fcc..0000000
+++ /dev/null
@@ -1,95 +0,0 @@
-#include "ggml.h"
-
-#include <math.h>
-#include <stdio.h>
-#include <stdlib.h>
-
-bool is_close(float a, float b, float epsilon) {
-    return fabs(a - b) < epsilon;
-}
-
-int main(int argc, const char ** argv) {
-    struct ggml_init_params params = {
-        .mem_size   = 1024*1024*1024,
-        .mem_buffer = NULL,
-        .no_alloc   = false,
-    };
-
-    //struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
-    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_LBFGS);
-
-    opt_params.n_threads = (argc > 1) ? atoi(argv[1]) : 8;
-
-    const int NP = 1 << 12;
-    const int NF = 1 << 8;
-
-    struct ggml_context * ctx0 = ggml_init(params);
-
-    struct ggml_tensor * F = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, NF, NP);
-    struct ggml_tensor * l = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NP);
-
-    // regularization weight
-    struct ggml_tensor * lambda = ggml_new_f32(ctx0, 1e-5f);
-
-    srand(0);
-
-    for (int j = 0; j < NP; j++) {
-        const float ll = j < NP/2 ? 1.0f : -1.0f;
-        ((float *)l->data)[j] = ll;
-
-        for (int i = 0; i < NF; i++) {
-            ((float *)F->data)[j*NF + i] = ((ll > 0 && i < NF/2 ? 1.0f : ll < 0 && i >= NF/2 ? 1.0f : 0.0f) + ((float)rand()/(float)RAND_MAX - 0.5f)*0.1f)/(0.5f*NF);
-        }
-    }
-
-    {
-        // initial guess
-        struct ggml_tensor * x = ggml_set_f32(ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, NF), 0.0f);
-
-        ggml_set_param(ctx0, x);
-
-        // f = sum[(fj*x - l)^2]/n + lambda*|x^2|
-        struct ggml_tensor * f =
-            ggml_add(ctx0,
-                    ggml_div(ctx0,
-                        ggml_sum(ctx0,
-                            ggml_sqr(ctx0,
-                                ggml_sub(ctx0,
-                                    ggml_mul_mat(ctx0, F, x),
-                                    l)
-                                )
-                            ),
-                        ggml_new_f32(ctx0, (float)NP)
-                        ),
-                    ggml_mul(ctx0,
-                        ggml_sum(ctx0, ggml_sqr(ctx0, x)),
-                        lambda)
-                    );
-
-        enum ggml_opt_result res = ggml_opt(NULL, opt_params, f);
-
-        GGML_ASSERT(res == GGML_OPT_RESULT_OK);
-
-        // print results
-        for (int i = 0; i < 16; i++) {
-            printf("x[%3d] = %g\n", i, ((float *)x->data)[i]);
-        }
-        printf("...\n");
-        for (int i = NF - 16; i < NF; i++) {
-            printf("x[%3d] = %g\n", i, ((float *)x->data)[i]);
-        }
-        printf("\n");
-
-        for (int i = 0; i < NF; ++i) {
-            if (i < NF/2) {
-                GGML_ASSERT(is_close(((float *)x->data)[i],  1.0f, 1e-2f));
-            } else {
-                GGML_ASSERT(is_close(((float *)x->data)[i], -1.0f, 1e-2f));
-            }
-        }
-    }
-
-    ggml_free(ctx0);
-
-    return 0;
-}
diff --git a/tests/test3.zig b/tests/test3.zig
deleted file mode 100644 (file)
index ecaf1b0..0000000
+++ /dev/null
@@ -1,87 +0,0 @@
-const std = @import("std");
-const Thread = std.Thread;
-const c = @cImport({
-    @cInclude("stdlib.h");
-    @cInclude("ggml.h");
-});
-
-fn is_close(a: f32, b: f32, epsilon: f32) bool {
-    return @abs(a - b) < epsilon;
-}
-
-pub fn main() !void {
-    const params = .{
-        .mem_size = 128 * 1024 * 1024,
-        .mem_buffer = null,
-        .no_alloc = false,
-    };
-
-    var opt_params = c.ggml_opt_default_params(c.GGML_OPT_TYPE_LBFGS);
-
-    const nthreads = try Thread.getCpuCount();
-    opt_params.n_threads = @intCast(nthreads);
-
-    const NP = 1 << 12;
-    const NF = 1 << 8;
-
-    const ctx0 = c.ggml_init(params);
-    defer c.ggml_free(ctx0);
-
-    const F = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_F32, NF, NP);
-    const l = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NP);
-
-    // regularization weight
-    const lambda = c.ggml_new_f32(ctx0, 1e-5);
-
-    c.srand(0);
-
-    const l_data_pointer: [*]f32 = @ptrCast(@alignCast(l.*.data));
-    const f_data_pointer: [*]f32 = @ptrCast(@alignCast(F.*.data));
-    for (0..NP) |j| {
-        const ll = if (j < NP / 2) @as(f32, 1.0) else @as(f32, -1.0);
-        l_data_pointer[j] = ll;
-
-        for (0..NF) |i| {
-            const c_rand: f32 = @floatFromInt(c.rand());
-            f_data_pointer[j * NF + i] =
-                ((if (ll > 0 and i < NF / 2) @as(f32, 1.0) else if (ll < 0 and i >= NF / 2) @as(f32, 1.0) else @as(f32, 0.0)) +
-                (c_rand / c.RAND_MAX - 0.5) * 0.1) / (0.5 * NF);
-        }
-    }
-
-    {
-        // initial guess
-        const x = c.ggml_set_f32(c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, NF), 0.0);
-
-        c.ggml_set_param(ctx0, x);
-
-        // f = sum[(fj*x - l)^2]/n + lambda*|x^2|
-        const f =
-            c.ggml_add(ctx0, c.ggml_div(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, c.ggml_sub(ctx0, c.ggml_mul_mat(ctx0, F, x), l))), c.ggml_new_f32(ctx0, @as(f32, NP))), c.ggml_mul(ctx0, c.ggml_sum(ctx0, c.ggml_sqr(ctx0, x)), lambda));
-
-        const res = c.ggml_opt(null, opt_params, f);
-
-        try std.testing.expect(res == c.GGML_OPT_RESULT_OK);
-
-        const x_data_pointer: [*]f32 = @ptrCast(@alignCast(x.*.data));
-        // print results
-        for (0..16) |i| {
-            std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] });
-        }
-        std.debug.print("...\n", .{});
-        for (NF - 16..NF) |i| {
-            std.debug.print("x[{d:3}] = {d:.6}\n", .{ i, x_data_pointer[i] });
-        }
-        std.debug.print("\n", .{});
-
-        for (0..NF) |i| {
-            if (i < NF / 2) {
-                try std.testing.expect(is_close(x_data_pointer[i], 1.0, 1e-2));
-            } else {
-                try std.testing.expect(is_close(x_data_pointer[i], -1.0, 1e-2));
-            }
-        }
-    }
-
-    _ = try std.io.getStdIn().reader().readByte();
-}