]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml/examples: add backend support for numerical optimization (#949)
authorJohannes Gäßler <redacted>
Fri, 20 Sep 2024 12:36:38 +0000 (14:36 +0200)
committerGitHub <redacted>
Fri, 20 Sep 2024 12:36:38 +0000 (14:36 +0200)
* CUDA eval works

* stochastic gradient descent op

* Adam except decay

* CUDA CROSS_ENTROPY_LOSS_BACK

* CUDA mnist-fc training works

* backend CLI arg

* refactor gguf load

* remove sched from opt_step_adam

* implement l1 regularization (weight decay)

* extra call to add optimizer

* initialize gradients with ggml_graph_reset

* gradient accumulation

* increment iter per eval instead of epoch

* adjust backend interfaces

* fix ggml_graph_reset without backend

* fix ggml graph export/import

* fixup

* rename

* revert ggml_opt changes

* more general CUDA repeat_back

* update documentation, fix CNN

* validation split

* add clarifying comment

* optimize PyTorch training

* adjust buffer size, thread count

* fix 0.0f validation split

* Update examples/mnist/mnist-common.cpp

Co-authored-by: Georgi Gerganov <redacted>
* fix gradient accumulation

* tensor flag for accumulators -> tensor hash set

* Update include/ggml.h

Co-authored-by: slaren <redacted>
* Update tests/test-backend-ops.cpp

Co-authored-by: slaren <redacted>
* Update tests/test-backend-ops.cpp

Co-authored-by: slaren <redacted>
* fix test prints

* Update src/ggml-backend.c

Co-authored-by: Georgi Gerganov <redacted>
* better CUDA support for noncontiguous out_prod

* add comment

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: slaren <redacted>
33 files changed:
examples/mnist/README.md
examples/mnist/mnist-common.cpp
examples/mnist/mnist-common.h
examples/mnist/mnist-eval.cpp
examples/mnist/mnist-train-cnn.py
examples/mnist/mnist-train-fc.py
examples/mnist/mnist-train.cpp
include/ggml-backend.h
include/ggml.h
src/ggml-backend-impl.h
src/ggml-backend.c
src/ggml-cann.cpp
src/ggml-cuda.cu
src/ggml-cuda/binbcast.cu
src/ggml-cuda/binbcast.cuh
src/ggml-cuda/cross-entropy-loss.cu
src/ggml-cuda/cross-entropy-loss.cuh
src/ggml-cuda/opt-step-adamw.cu [new file with mode: 0644]
src/ggml-cuda/opt-step-adamw.cuh [new file with mode: 0644]
src/ggml-cuda/out-prod.cu [new file with mode: 0644]
src/ggml-cuda/out-prod.cuh [new file with mode: 0644]
src/ggml-cuda/unary.cu
src/ggml-cuda/unary.cuh
src/ggml-kompute.cpp
src/ggml-metal.m
src/ggml-rpc.cpp
src/ggml-sycl.cpp
src/ggml-vulkan.cpp
src/ggml.c
tests/test-backend-ops.cpp
tests/test-grad0.cpp
tests/test-mul-mat0.c
tests/test1.c

index 0e8f07909e88ede5f1cbbdcfe48778627146ba54..9e0966f441974c2e7428b00f1259af91627bbce4 100644 (file)
@@ -18,7 +18,7 @@ $ python3 mnist-train-fc.py mnist-fc-f32.gguf
 
 ...
 
-Test loss: 0.069983+-0.009196, Test accuracy: 97.94+-0.14%
+Test loss: 0.066051+-0.011630, Test accuracy: 98.07+-0.14%
 
 Model tensors saved to mnist-fc-f32.gguf:
 fc1.weight       (500, 784)
@@ -28,7 +28,7 @@ fc2.bias         (10,)
 ```
 
 The training script includes an evaluation of the model on the test set.
-To evaluate the model using GGML, run:
+To evaluate the model on the CPU using GGML, run:
 
 ```bash
 $ ../../build/bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte
@@ -37,26 +37,26 @@ ________________________________________________________
 ________________________________________________________
 ________________________________________________________
 ________________________________________________________
-________________________________######__________________
-____________________________########____________________
-________________________########________________________
-____________________########________________##__________
-__________________######____________________##__________
-________________######______________________####________
-______________######________________________####________
-____________######__________________________####________
-____________####____________________________####________
-__________####______________________________####________
-__________####______________________________####________
-__________##________________________________####________
-__________##______________________________####__________
-__________##____________________________######__________
-__________##__________________________######____________
-____________##____________________########______________
-____________##########################__________________
-______________##################________________________
-________________________________________________________
-________________________________________________________
+__________________________________####__________________
+______________________________########__________________
+__________________________##########____________________
+______________________##############____________________
+____________________######________####__________________
+__________________________________####__________________
+__________________________________####__________________
+________________________________####____________________
+______________________________####______________________
+________________________##########______________________
+______________________########__####____________________
+________________________##__________##__________________
+____________________________________##__________________
+__________________________________##____________________
+__________________________________##____________________
+________________________________##______________________
+____________________________####________________________
+__________##____________######__________________________
+__________##############________________________________
+________________####____________________________________
 ________________________________________________________
 ________________________________________________________
 ________________________________________________________
@@ -64,18 +64,23 @@ ________________________________________________________
 mnist_graph_eval: trying to load a ggml graph from mnist-fc-f32.gguf
 ggml_graph_import: invalid magic number, got 46554747
 mnist_graph_eval: could not load a ggml graph from mnist-fc-f32.gguf
+ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
+ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
+ggml_cuda_init: found 1 CUDA devices:
+  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
+mnist_model: using CPU backend
 mnist_model_init_from_file: loading model weights from 'mnist-fc-f32.gguf'
 mnist_model_init_from_file: model arch is mnist-fc
 mnist_model_init_from_file: successfully loaded weights from mnist-fc-f32.gguf
-main: loaded model in 1.52 ms
-mnist_model_eval: model evaluation on 10000 images took 26.65 ms, 2.66 us/image
-main: predicted digit is 0
-main: test_loss=0.069983+-0.009196
-main: test_acc=97.94+-0.14%
+main: loaded model in 13.03 ms
+mnist_model_eval: model evaluation on 10000 images took 95.02 ms, 9.50 us/image
+main: predicted digit is 3
+main: test_loss=0.066051+-0.009343
+main: test_acc=98.07+-0.14%
 ```
 
 In addition to the evaluation on the test set the GGML evaluation also prints a random image from the test set as well as the model prediction for said image.
-To train a fully connected model using GGML run:
+To train a fully connected model on the CPU using GGML run:
 
 ``` bash
 $ ../../build/bin/mnist-train mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
@@ -96,12 +101,12 @@ $ python3 mnist-train-cnn.py mnist-cnn-f32.gguf
 
 ...
 
-Test loss: 0.046456
-Test accuracy: 98.40%
+Test loss: 0.045483
+Test accuracy: 98.56%
 GGUF model saved to 'mnist-cnn-f32.gguf'
 ```
 
-The saved model can be evaluated using the `mnist-eval` binary:
+The saved model can be evaluated on the CPU using the `mnist-eval` binary:
 
 ```bash
 $ ../../build/bin/mnist-eval mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte
@@ -111,25 +116,25 @@ ________________________________________________________
 ________________________________________________________
 ________________________________________________________
 ________________________________________________________
-________________________________________________________
-________________________________________________________
-________________________####____________________________
-__________________________##____________________________
-__________________________##____________________________
-__________________________##____________________________
-__________________________##____________________________
-__________________________##____________________________
-____________________________##__________________________
-____________________________##__________________________
-____________________________##__________________________
-______________________________##________________________
-______________________________##________________________
-______________________________####______________________
-________________________________##______________________
-________________________________##______________________
-________________________________####____________________
+______________________________________##________________
+______________________________________##________________
+______________________________________##________________
+____________________________________##__________________
+__________________________________####__________________
 __________________________________##____________________
 ________________________________##______________________
+______________________________##________________________
+____________________________####________________________
+____________________________##__________________________
+__________________________##____________________________
+________________________##______________________________
+______________________##________________________________
+____________________####________________________________
+____________________##__________________________________
+__________________##____________________________________
+________________##______________________________________
+________________________________________________________
+________________________________________________________
 ________________________________________________________
 ________________________________________________________
 ________________________________________________________
@@ -137,17 +142,22 @@ ________________________________________________________
 mnist_graph_eval: trying to load a ggml graph from mnist-cnn-f32.gguf
 ggml_graph_import: invalid magic number, got 46554747
 mnist_graph_eval: could not load a ggml graph from mnist-cnn-f32.gguf
+ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
+ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
+ggml_cuda_init: found 1 CUDA devices:
+  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
+mnist_model: using CPU backend
 mnist_model_init_from_file: loading model weights from 'mnist-cnn-f32.gguf'
 mnist_model_init_from_file: model arch is mnist-cnn
 mnist_model_init_from_file: successfully loaded weights from mnist-cnn-f32.gguf
-main: loaded model in 5.45 ms
-mnist_model_eval: model evaluation on 10000 images took 605.60 ms, 60.56 us/image
+main: loaded model in 11.88 ms
+mnist_model_eval: model evaluation on 10000 images took 1074.09 ms, 107.41 us/image
 main: predicted digit is 1
-main: test_loss=0.046456+-0.007354
-main: test_acc=98.40+-0.13%
+main: test_loss=0.045483+-0.006884
+main: test_acc=98.56+-0.12%
 ```
 
-Like with the fully connected network the convolutional network can also be trained using GGML:
+Like with the fully connected network the convolutional network can also be trained on the CPU using GGML:
 
 ``` bash
 $ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte
@@ -155,6 +165,12 @@ $ ../../build/bin/mnist-train mnist-cnn mnist-cnn-f32.gguf data/MNIST/raw/train-
 
 As always, the evaluation is done using `mnist-eval` and like with the fully connected network the GGML graph is exported to `mnist-cnn-f32.ggml`.
 
+## CUDA
+
+The fully connected model can be trained and evaluated using CUDA.
+`mnist-train` and `mnist-eval` accept an additional, optional argument behind those listed so far to specify the backend.
+The default is `CPU`, by specifying `CUDA0` the first available CUDA device can be used instead (make sure to compile GGML with CUDA cupport).
+
 ## Web demo
 
 The evaluation code can be compiled to WebAssembly using [Emscripten](https://emscripten.org/) (may need to re-login to update `$PATH` after installation).
index f0e90bb029e52bb1b7e95008995af9b058ded93e..cc56d2736a18a2158c19c8f5f2c3c38c2af2a6e0 100644 (file)
@@ -1,3 +1,5 @@
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 #include "ggml.h"
 
 #include "mnist-common.h"
@@ -111,15 +113,23 @@ mnist_eval_result mnist_graph_eval(const std::string & fname, const float * imag
     GGML_ASSERT(images_batch);
     GGML_ASSERT(images_batch->ne[0] == MNIST_NINPUT || (images_batch->ne[0] == MNIST_HW && images_batch->ne[1] == MNIST_HW));
 
+    struct ggml_tensor * labels_batch = ggml_graph_get_tensor(gf, "labels");
+    GGML_ASSERT(labels_batch);
+    GGML_ASSERT(labels_batch->ne[0] == MNIST_NCLASSES);
+    GGML_ASSERT(labels_batch->ne[2] == 1);
+    GGML_ASSERT(labels_batch->ne[3] == 1);
+
+    const int nbatch = labels_batch->ne[1];
+    GGML_ASSERT(nex % nbatch == 0);
+
     struct ggml_tensor * logits_batch = ggml_graph_get_tensor(gf, "logits");
     GGML_ASSERT(logits_batch);
     GGML_ASSERT(logits_batch->ne[0] == MNIST_NCLASSES);
+    GGML_ASSERT(logits_batch->ne[1] == nbatch);
     GGML_ASSERT(logits_batch->ne[2] == 1);
     GGML_ASSERT(logits_batch->ne[3] == 1);
 
     GGML_ASSERT(images_batch->ne[1] == logits_batch->ne[1] || images_batch->ne[3] == logits_batch->ne[1]);
-    const int nbatch = logits_batch->ne[1];
-    GGML_ASSERT(nex % nbatch == 0);
 
     struct ggml_tensor * loss = ggml_graph_get_tensor(gf, "loss");
 
@@ -127,7 +137,8 @@ mnist_eval_result mnist_graph_eval(const std::string & fname, const float * imag
         const int64_t t_start_us = ggml_time_us();
 
         for (int iex0; iex0 < nex; iex0 += nbatch) {
-            memcpy(images_batch->data, images + iex0*MNIST_NINPUT, ggml_nbytes(images_batch));
+            memcpy(images_batch->data, images + iex0*MNIST_NINPUT,   ggml_nbytes(images_batch));
+            memcpy(labels_batch->data, labels + iex0*MNIST_NCLASSES, ggml_nbytes(labels_batch));
             ggml_graph_compute_with_ctx(ctx_compute, gf, nthreads);
 
             for (int iexb = 0; iexb < nbatch; ++iexb) {
@@ -154,18 +165,67 @@ mnist_eval_result mnist_graph_eval(const std::string & fname, const float * imag
     return result;
 }
 
-mnist_model mnist_model_init_from_file(const std::string & fname) {
-    mnist_model model;
+// Temporary util function for loading data from GGUF to a backend != CPU until GGML itself provides this functionality:
+bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf) {
+    FILE * f = ggml_fopen(fname, "rb");
+    if (!f) {
+        return false;
+    }
+
+    const size_t buf_size = 4*1024*1024;
+    void * buf = malloc(buf_size);
+
+    const int n_tensors = gguf_get_n_tensors(ctx_gguf);
+    for (int i = 0; i < n_tensors; i++) {
+        const char * name = gguf_get_tensor_name(ctx_gguf, i);
+
+        struct ggml_tensor * tensor = ggml_get_tensor(ctx_ggml, name);
+        if (!tensor) {
+            continue;
+        }
+
+        const size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);
+
+        if (fseek(f, offs, SEEK_SET) != 0) {
+            fclose(f);
+            free(buf);
+            return false;
+        }
+
+        const size_t nbytes = ggml_nbytes(tensor);
+        for (size_t pos = 0; pos < nbytes; pos += buf_size) {
+            const size_t nbytes_cpy = buf_size < nbytes - pos ? buf_size : nbytes - pos;
+
+            if (fread(buf, 1, nbytes_cpy, f) != nbytes_cpy) {
+                fclose(f);
+                free(buf);
+                return false;
+            }
+
+            ggml_backend_tensor_set(tensor, buf, pos, nbytes_cpy);
+        }
+    }
+
+    fclose(f);
+    free(buf);
+    return true;
+}
+
+mnist_model mnist_model_init_from_file(const std::string & fname, const std::string & backend) {
+    mnist_model model(backend);
     fprintf(stderr, "%s: loading model weights from '%s'\n", __func__, fname.c_str());
 
-    struct gguf_init_params params = {
-        /*.no_alloc   =*/ false,
-        /*.ctx        =*/ &model.ctx_weight,
-    };
-    gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
-    if (!ctx) {
-        fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
-        exit(1);
+    struct gguf_context * ctx;
+    {
+        struct gguf_init_params params = {
+            /*.no_alloc   =*/ true,
+            /*.ctx        =*/ &model.ctx_weight,
+        };
+        ctx = gguf_init_from_file(fname.c_str(), params);
+        if (!ctx) {
+            fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__);
+            exit(1);
+        }
     }
     model.arch = gguf_get_val_str(ctx, gguf_find_key(ctx, "general.architecture"));
     fprintf(stderr, "%s: model arch is %s\n", __func__, model.arch.c_str());
@@ -204,8 +264,8 @@ mnist_model mnist_model_init_from_file(const std::string & fname) {
 
         model.conv1_bias = ggml_get_tensor(model.ctx_weight, "conv1.bias");
         GGML_ASSERT(model.conv1_bias->type == GGML_TYPE_F32);
-        GGML_ASSERT(model.conv1_bias->ne[0] == MNIST_HW);
-        GGML_ASSERT(model.conv1_bias->ne[1] == MNIST_HW);
+        GGML_ASSERT(model.conv1_bias->ne[0] == 1);
+        GGML_ASSERT(model.conv1_bias->ne[1] == 1);
         GGML_ASSERT(model.conv1_bias->ne[2] == MNIST_CNN_NCB);
         GGML_ASSERT(model.conv1_bias->ne[3] == 1);
 
@@ -218,8 +278,8 @@ mnist_model mnist_model_init_from_file(const std::string & fname) {
 
         model.conv2_bias = ggml_get_tensor(model.ctx_weight, "conv2.bias");
         GGML_ASSERT(model.conv2_bias->type == GGML_TYPE_F32);
-        GGML_ASSERT(model.conv2_bias->ne[0] == MNIST_HW/2);
-        GGML_ASSERT(model.conv2_bias->ne[1] == MNIST_HW/2);
+        GGML_ASSERT(model.conv2_bias->ne[0] == 1);
+        GGML_ASSERT(model.conv2_bias->ne[1] == 1);
         GGML_ASSERT(model.conv2_bias->ne[2] == MNIST_CNN_NCB*2);
         GGML_ASSERT(model.conv2_bias->ne[3] == 1);
 
@@ -239,12 +299,19 @@ mnist_model mnist_model_init_from_file(const std::string & fname) {
     } else {
         fprintf(stderr, "%s: unknown model arch: %s\n", __func__, model.arch.c_str());
     }
+    model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend);
+
+    if(!load_from_gguf(fname.c_str(), model.ctx_weight, ctx)) {
+        fprintf(stderr, "%s: loading weights from %s failed\n", __func__, fname.c_str());
+        exit(1);
+    }
+
     fprintf(stderr, "%s: successfully loaded weights from %s\n", __func__, fname.c_str());
     return model;
 }
 
-mnist_model mnist_model_init_random(const std::string & arch) {
-    mnist_model model;
+mnist_model mnist_model_init_random(const std::string & arch, const std::string & backend) {
+    mnist_model model(backend);
     model.arch = arch;
 
     std::random_device rd{};
@@ -294,21 +361,25 @@ mnist_model mnist_model_init_random(const std::string & arch) {
         fprintf(stderr, "%s: unknown model arch: %s\n", __func__, model.arch.c_str());
     }
 
+    model.buf_weight = ggml_backend_alloc_ctx_tensors(model.ctx_weight, model.backend);
+
     for (ggml_tensor * t : init_tensors) {
         GGML_ASSERT(t->type == GGML_TYPE_F32);
-        float * data = ggml_get_data_f32(t);
         const int64_t ne = ggml_nelements(t);
+        std::vector<float> tmp(ne);
 
         for (int64_t i = 0; i < ne; ++i) {
-            data[i] = nd(gen);
+            tmp[i] = nd(gen);
         }
+        ggml_backend_tensor_set(t, tmp.data(), 0, ggml_nbytes(t));
     }
 
     return model;
 }
 
-void mnist_model_build(mnist_model & model, const int nbatch) {
-    model.nbatch = nbatch;
+void mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical) {
+    model.nbatch_logical  = nbatch_logical;
+    model.nbatch_physical = nbatch_physical;
 
     if (model.arch == "mnist-fc") {
         ggml_set_param(model.ctx_compute, model.fc1_weight);
@@ -316,9 +387,9 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
         ggml_set_param(model.ctx_compute, model.fc2_weight);
         ggml_set_param(model.ctx_compute, model.fc2_bias);
 
-        model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, model.nbatch);
-        ggml_set_input(model.images);
+        model.images = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NINPUT, model.nbatch_physical);
         ggml_set_name(model.images, "images");
+        ggml_set_input(model.images);
 
         ggml_tensor * fc1 = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
             ggml_mul_mat(model.ctx_compute, model.fc1_weight, model.images),
@@ -334,9 +405,9 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
         ggml_set_param(model.ctx_compute, model.dense_weight);
         ggml_set_param(model.ctx_compute, model.dense_bias);
 
-        model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, model.nbatch);
-        ggml_set_input(model.images);
+        model.images = ggml_new_tensor_4d(model.ctx_compute, GGML_TYPE_F32, 28, 28, 1, model.nbatch_physical);
         ggml_set_name(model.images, "images");
+        ggml_set_input(model.images);
 
         struct ggml_tensor * conv1_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
             ggml_conv_2d(model.ctx_compute, model.conv1_kernel, model.images, 1, 1, 1, 1, 1, 1),
@@ -344,13 +415,13 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
         GGML_ASSERT(conv1_out->ne[0] == MNIST_HW);
         GGML_ASSERT(conv1_out->ne[1] == MNIST_HW);
         GGML_ASSERT(conv1_out->ne[2] == MNIST_CNN_NCB);
-        GGML_ASSERT(conv1_out->ne[3] == model.nbatch);
+        GGML_ASSERT(conv1_out->ne[3] == model.nbatch_physical);
 
         struct ggml_tensor * conv2_in = ggml_pool_2d(model.ctx_compute, conv1_out, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
         GGML_ASSERT(conv2_in->ne[0] == MNIST_HW/2);
         GGML_ASSERT(conv2_in->ne[1] == MNIST_HW/2);
         GGML_ASSERT(conv2_in->ne[2] == MNIST_CNN_NCB);
-        GGML_ASSERT(conv2_in->ne[3] == model.nbatch);
+        GGML_ASSERT(conv2_in->ne[3] == model.nbatch_physical);
 
         struct ggml_tensor * conv2_out = ggml_relu(model.ctx_compute, ggml_add(model.ctx_compute,
             ggml_conv_2d(model.ctx_compute, model.conv2_kernel, conv2_in, 1, 1, 1, 1, 1, 1),
@@ -358,19 +429,19 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
         GGML_ASSERT(conv2_out->ne[0] == MNIST_HW/2);
         GGML_ASSERT(conv2_out->ne[1] == MNIST_HW/2);
         GGML_ASSERT(conv2_out->ne[2] == MNIST_CNN_NCB*2);
-        GGML_ASSERT(conv2_out->ne[3] == model.nbatch);
+        GGML_ASSERT(conv2_out->ne[3] == model.nbatch_physical);
 
         struct ggml_tensor * dense_in = ggml_pool_2d(model.ctx_compute, conv2_out, GGML_OP_POOL_MAX, 2, 2, 2, 2, 0, 0);
         GGML_ASSERT(dense_in->ne[0] == MNIST_HW/4);
         GGML_ASSERT(dense_in->ne[1] == MNIST_HW/4);
         GGML_ASSERT(dense_in->ne[2] == MNIST_CNN_NCB*2);
-        GGML_ASSERT(dense_in->ne[3] == model.nbatch);
+        GGML_ASSERT(dense_in->ne[3] == model.nbatch_physical);
 
         dense_in = ggml_reshape_2d(model.ctx_compute,
             ggml_cont(model.ctx_compute, ggml_permute(model.ctx_compute, dense_in, 1, 2, 0, 3)),
-            (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), model.nbatch);
+            (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2), model.nbatch_physical);
         GGML_ASSERT(dense_in->ne[0] == (MNIST_HW/4)*(MNIST_HW/4)*(MNIST_CNN_NCB*2));
-        GGML_ASSERT(dense_in->ne[1] == model.nbatch);
+        GGML_ASSERT(dense_in->ne[1] == model.nbatch_physical);
         GGML_ASSERT(dense_in->ne[2] == 1);
         GGML_ASSERT(dense_in->ne[3] == 1);
 
@@ -379,30 +450,31 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
         GGML_ASSERT(false);
     }
 
-    ggml_set_output(model.logits);
     ggml_set_name(model.logits, "logits");
+    ggml_set_output(model.logits);
     GGML_ASSERT(model.logits->type == GGML_TYPE_F32);
     GGML_ASSERT(model.logits->ne[0] == MNIST_NCLASSES);
-    GGML_ASSERT(model.logits->ne[1] == model.nbatch);
+    GGML_ASSERT(model.logits->ne[1] == model.nbatch_physical);
     GGML_ASSERT(model.logits->ne[2] == 1);
     GGML_ASSERT(model.logits->ne[3] == 1);
 
     model.probs = ggml_soft_max(model.ctx_compute, model.logits);
-    ggml_set_output(model.probs);
     ggml_set_name(model.probs, "probs");
+    ggml_set_output(model.probs);
     GGML_ASSERT(model.probs->type == GGML_TYPE_F32);
     GGML_ASSERT(model.probs->ne[0] == MNIST_NCLASSES);
-    GGML_ASSERT(model.probs->ne[1] == model.nbatch);
+    GGML_ASSERT(model.probs->ne[1] == model.nbatch_physical);
     GGML_ASSERT(model.probs->ne[2] == 1);
     GGML_ASSERT(model.probs->ne[3] == 1);
 
-    model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch);
-    ggml_set_input(model.labels);
+    model.labels = ggml_new_tensor_2d(model.ctx_compute, GGML_TYPE_F32, MNIST_NCLASSES, model.nbatch_physical);
     ggml_set_name(model.labels, "labels");
+    ggml_set_input(model.labels);
 
     model.loss = ggml_cross_entropy_loss(model.ctx_compute, model.logits, model.labels);
-    ggml_set_output(model.loss);
     ggml_set_name(model.loss, "loss");
+    ggml_set_output(model.loss);
+    ggml_set_loss(model.loss);
     GGML_ASSERT(model.loss->type == GGML_TYPE_F32);
     GGML_ASSERT(model.loss->ne[0] == 1);
     GGML_ASSERT(model.loss->ne[1] == 1);
@@ -410,26 +482,38 @@ void mnist_model_build(mnist_model & model, const int nbatch) {
     GGML_ASSERT(model.loss->ne[3] == 1);
 }
 
-mnist_eval_result mnist_model_eval(const mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads) {
+mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex) {
     mnist_eval_result result;
 
     struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute);
     ggml_build_forward_expand(gf, model.loss);
 
+    model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
+
     {
         const int64_t t_start_us = ggml_time_us();
 
-        GGML_ASSERT(nex % model.nbatch == 0);
-        for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch) {
-            memcpy(model.images->data, images + iex0*MNIST_NINPUT,   ggml_nbytes(model.images));
-            memcpy(model.labels->data, labels + iex0*MNIST_NCLASSES, ggml_nbytes(model.labels));
-            ggml_graph_compute_with_ctx(model.ctx_compute, gf, nthreads);
+        float loss;
+        std::vector<float> logits(model.nbatch_physical*MNIST_NCLASSES);
+
+        GGML_ASSERT(sizeof(loss)  == ggml_nbytes(model.loss));
+        GGML_ASSERT(logits.size() == ggml_nelements(model.logits));
+
+        GGML_ASSERT(nex % model.nbatch_physical == 0);
+        for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch_physical) {
+            ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
+            ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
+
+            ggml_backend_graph_compute(model.backend, gf);
+
+            ggml_backend_tensor_get(model.loss,   &loss,         0, ggml_nbytes(model.loss));
+            ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
 
-            result.loss.push_back(*ggml_get_data_f32(model.loss));
+            result.loss.push_back(loss);
 
-            for (int iexb = 0; iexb < model.nbatch; ++iexb) {
-                const float * logits_data = ggml_get_data_f32(model.logits) + iexb*MNIST_NCLASSES;
-                result.pred.push_back(std::max_element(logits_data, logits_data + MNIST_NCLASSES) - logits_data);
+            for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) {
+                const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES;
+                result.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb);
             }
         }
 
@@ -443,81 +527,143 @@ mnist_eval_result mnist_model_eval(const mnist_model & model, const float * imag
     return result;
 }
 
-void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads) {
+void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split) {
     const int64_t t_start_us = ggml_time_us();
 
-    struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, 16384, true);
+    // gf == graph forward, forward pass only.
+    struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
     ggml_build_forward_expand(gf, model.loss);
 
-    struct ggml_cgraph * gb = ggml_graph_dup(model.ctx_compute, gf);
-    ggml_build_backward_expand(model.ctx_compute, gf, gb, true);
+    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
+    struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
+    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true, false);
 
-    struct ggml_opt_context opt_ctx;
-    struct ggml_opt_params  opt_pars = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
-    opt_pars.print_forward_graph = false;
-    opt_pars.print_backward_graph = false;
-    opt_pars.n_threads = nthreads;
-    opt_pars.adam.n_iter = 1; // per call of ggml_opt_resume_g
-    ggml_opt_init(model.ctx_compute, &opt_ctx, opt_pars, 0);
+    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
+    struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
+    ggml_build_opt_adamw(model.ctx_compute, gf, gb_opt, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f);
 
-    for (int epoch = 0; epoch < 20; ++epoch) {
-        fprintf(stderr, "%s: epoch %d start...", __func__, epoch);
+    model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
+    ggml_graph_reset(gb_opt); // Set gradients to zero, reset optimizer.
+
+    const int iex_split = ((int)((1.0f - val_split)*nex) / model.nbatch_logical) * model.nbatch_logical;
+
+    for (int epoch = 0; epoch < nepoch; ++epoch) {
+        fprintf(stderr, "%s: epoch %02d start...", __func__, epoch);
         const int64_t t_start_us = ggml_time_us();
-        mnist_eval_result result;
-        for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch) {
-            memcpy(model.images->data, images + iex0*MNIST_NINPUT,   ggml_nbytes(model.images));
-            memcpy(model.labels->data, labels + iex0*MNIST_NCLASSES, ggml_nbytes(model.labels));
 
-            enum ggml_opt_result opt_result = ggml_opt_resume_g(model.ctx_compute, &opt_ctx, model.loss, gf, gb, NULL, NULL);
-            GGML_ASSERT(opt_result == GGML_OPT_RESULT_OK || opt_result == GGML_OPT_RESULT_DID_NOT_CONVERGE);
+        float loss;
+        std::vector<float> logits(model.nbatch_physical*MNIST_NCLASSES);
+        int iex0 = 0;
+
+        mnist_eval_result result_train;
+        for (; iex0 < iex_split; iex0 += model.nbatch_physical) {
+            ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
+            ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
+
+            // With a period of nbatch_logical/nbatch_physical iterations:
+            if ((iex0 + model.nbatch_physical) % model.nbatch_logical != 0) {
+                // For the first nbatch_logical/nbatch_physical - 1 iterations, only calculate gradients and accumulate them:
+                ggml_backend_graph_compute(model.backend, gb_grad);
+            } else {
+                // For the last iteration, calculate gradients and also apply the optimizer:
+                ggml_backend_graph_compute(model.backend, gb_opt); // gb_opt contains all nodes of gb_grad so no extra call for gb_grad is needed.
+                ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer.
+            }
+
+            ggml_backend_tensor_get(model.loss,   &loss,         0, ggml_nbytes(model.loss));
+            ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
 
-            result.loss.push_back(*ggml_get_data_f32(model.loss));
+            result_train.loss.push_back(loss);
 
-            for (int iexb = 0; iexb < model.nbatch; ++iexb) {
-                const float * ptr_p = (const float *) model.logits->data + iexb*MNIST_NCLASSES;
-                result.pred.push_back(std::max_element(ptr_p, ptr_p + MNIST_NCLASSES) - ptr_p);
+            for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) {
+                const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES;
+                result_train.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb);
             }
         }
 
-        const double loss_mean = mnist_loss(result).first;
-        const double percent_correct = 100.0 * mnist_accuracy(result, labels).first;
+        mnist_eval_result result_val;
+        for (; iex0 < nex; iex0 += model.nbatch_physical) {
+            ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
+            ggml_backend_tensor_set(model.labels, labels + iex0*MNIST_NCLASSES, 0, ggml_nbytes(model.labels));
 
-        const int64_t t_epoch_us = ggml_time_us() - t_start_us;
-        const double t_epoch_s = 1e-6*t_epoch_us;
-        fprintf(stderr, "done, took %.2lfs, train_loss=%.6lf, train_acc=%.2f%%\n", t_epoch_s, loss_mean, percent_correct);
+            ggml_backend_graph_compute(model.backend, gf); // For the validation set, only the forward pass is needed.
+
+            ggml_backend_tensor_get(model.loss,   &loss,         0, ggml_nbytes(model.loss));
+            ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
+
+            result_val.loss.push_back(loss);
+
+            for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) {
+                const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES;
+                result_val.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb);
+            }
+        }
+
+        {
+            const double loss_mean = mnist_loss(result_train).first;
+            const double percent_correct = 100.0 * mnist_accuracy(result_train, labels + 0*MNIST_NCLASSES).first;
+
+            const int64_t t_epoch_us = ggml_time_us() - t_start_us;
+            const double t_epoch_s = 1e-6*t_epoch_us;
+            fprintf(stderr, "done, took %.2lfs, train_loss=%.6lf, train_acc=%.2f%%", t_epoch_s, loss_mean, percent_correct);
+        }
+
+        if (iex_split < nex) {
+            const std::pair<double, double> loss = mnist_loss(result_val);
+            const std::pair<double, double> acc  = mnist_accuracy(result_val, labels + iex_split*MNIST_NCLASSES);
+
+            fprintf(stderr, ", val_loss=%.6lf+-%.6lf, val_acc=%.2f+-%.2f%%", loss.first, loss.second, 100.0*acc.first, 100.0*acc.second);
+        }
+        fprintf(stderr, "\n");
     }
 
     const int64_t t_total_us = ggml_time_us() - t_start_us;
     const double t_total_s = 1e-6*t_total_us;
     fprintf(stderr, "%s: training took %.2lfs\n", __func__, t_total_s);
 
-    std::string fname = model.arch + "-f32.ggml";
-    fprintf(stderr, "%s: saving the ggml graph for the forward pass to %s\n", __func__, fname.c_str());
-    ggml_graph_export(gf, fname.c_str());
+    if (ggml_backend_is_cpu(model.backend)) {
+        std::string fname = model.arch + "-f32.ggml";
+        fprintf(stderr, "%s: saving the GGML graph for the forward pass to %s\n", __func__, fname.c_str());
+        ggml_graph_export(gf, fname.c_str());
+    } else {
+        fprintf(stderr, "%s: not saving the GGML graph for the forward pass because this is only supported for the CPU backend\n", __func__);
+    }
 }
 
 void mnist_model_save(mnist_model & model, const std::string & fname) {
     printf("%s: saving model to '%s'\n", __func__, fname.c_str());
 
+    struct ggml_context * ggml_ctx;
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ 100 * 1024*1024,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ false,
+        };
+        ggml_ctx = ggml_init(params);
+    }
+
     gguf_context * gguf_ctx = gguf_init_empty();
     gguf_set_val_str(gguf_ctx, "general.architecture", model.arch.c_str());
 
+    std::vector<struct ggml_tensor *> weights;
     if (model.arch == "mnist-fc") {
-        gguf_add_tensor(gguf_ctx, model.fc1_weight);
-        gguf_add_tensor(gguf_ctx, model.fc1_bias);
-        gguf_add_tensor(gguf_ctx, model.fc2_weight);
-        gguf_add_tensor(gguf_ctx, model.fc2_bias);
+        weights = {model.fc1_weight, model.fc1_bias, model.fc2_weight, model.fc2_bias};
     } else if (model.arch == "mnist-cnn") {
-        gguf_add_tensor(gguf_ctx, model.conv1_kernel);
-        gguf_add_tensor(gguf_ctx, model.conv1_bias);
-        gguf_add_tensor(gguf_ctx, model.conv2_kernel);
-        gguf_add_tensor(gguf_ctx, model.conv2_bias);
-        gguf_add_tensor(gguf_ctx, model.dense_weight);
-        gguf_add_tensor(gguf_ctx, model.dense_bias);
+        weights = {model.conv1_kernel, model.conv1_bias, model.conv2_kernel, model.conv2_bias, model.dense_weight, model.dense_bias};
     } else {
         GGML_ASSERT(false);
     }
+    for (struct ggml_tensor * t : weights) {
+        struct ggml_tensor * copy = ggml_dup_tensor(ggml_ctx, t);
+        ggml_set_name(copy, t->name);
+        ggml_backend_tensor_get(t, copy->data, 0, ggml_nbytes(t));
+        gguf_add_tensor(gguf_ctx, copy);
+    }
     gguf_write_to_file(gguf_ctx, fname.c_str(), false);
+
+    ggml_free(ggml_ctx);
+    gguf_free(gguf_ctx);
 }
 
 std::pair<double, double> mnist_loss(const mnist_eval_result & result) {
@@ -564,9 +710,9 @@ int wasm_eval(uint8_t * digitPtr) {
     std::vector<float> digit(digitPtr, digitPtr + MNIST_NINPUT);
     std::vector<float> labels(MNIST_NCLASSES);
 
-    mnist_model model = mnist_model_init_from_file("mnist-f32.gguf");
-    mnist_model_build(model, 1);
-    mnist_eval_result result = mnist_model_eval(model, digit.data(), labels.data(), 1, 1);
+    mnist_model model = mnist_model_init_from_file("mnist-f32.gguf", "CPU");
+    mnist_model_build(model, 1, 1);
+    mnist_eval_result result = mnist_model_eval(model, digit.data(), labels.data(), 1);
 
     return result.pred[0];
 }
index 1e15c5add3e46421daff024fdeccf618dbbe86bc..72638a4094c47519bc7878a244c98c479f22e9db 100644 (file)
@@ -1,14 +1,20 @@
+#include <cstdint>
 #include <string>
+#include <thread>
 #include <vector>
 
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
 #include "ggml.h"
 
-#define MNIST_NTRAIN 60000
-#define MNIST_NTEST  10000
-#define MNIST_NBATCH 500
+#define MNIST_NTRAIN          60000
+#define MNIST_NTEST           10000
+#define MNIST_NBATCH_LOGICAL   1000
+#define MNIST_NBATCH_PHYSICAL   500
 
-static_assert(MNIST_NTRAIN % MNIST_NBATCH == 0, "MNIST_NTRAIN % MNIST_BATCH != 0");
-static_assert(MNIST_NTEST  % MNIST_NBATCH == 0, "MNIST_NTRAIN % MNIST_BATCH != 0");
+static_assert(MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL == 0, "MNIST_NBATCH_LOGICAL % MNIST_NBATCH_PHYSICAL != 0");
+static_assert(MNIST_NTRAIN % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0");
+static_assert(MNIST_NTEST  % MNIST_NBATCH_LOGICAL == 0, "MNIST_NTRAIN % MNIST_NBATCH_LOGICAL != 0");
 
 #define MNIST_HW       28
 #define MNIST_NINPUT   (MNIST_HW*MNIST_HW)
@@ -21,7 +27,9 @@ static_assert(MNIST_NTEST  % MNIST_NBATCH == 0, "MNIST_NTRAIN % MNIST_BATCH != 0
 
 struct mnist_model {
     std::string arch;
-    int nbatch;
+    ggml_backend_t backend;
+    int nbatch_logical;
+    int nbatch_physical;
 
     struct ggml_tensor  * images = nullptr;
     struct ggml_tensor  * labels = nullptr;
@@ -41,31 +49,45 @@ struct mnist_model {
     struct ggml_tensor * dense_weight = nullptr;
     struct ggml_tensor * dense_bias   = nullptr;
 
-    static const size_t size_weight  = 100 *      1024*1024;
-    static const size_t size_compute =   1 * 1024*1024*1024;
-
-    void                * buf_weight  = nullptr;
     struct ggml_context * ctx_weight  = nullptr;
-    void                * buf_compute = nullptr;
     struct ggml_context * ctx_compute = nullptr;
+    ggml_backend_buffer_t buf_weight  = nullptr;
+    ggml_backend_buffer_t buf_compute = nullptr;
+
+    mnist_model(const std::string & backend_name) {
+        const size_t backend_index = ggml_backend_reg_find_by_name(backend_name.c_str());
+        if (backend_index == SIZE_MAX) {
+            fprintf(stderr, "%s: ERROR: backend %s not found, available:\n", __func__, backend_name.c_str());
+            for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
+                fprintf(stderr, "  - %s\n", ggml_backend_reg_get_name(i));
+            }
+            exit(1);
+        }
+
+        fprintf(stderr, "%s: using %s backend\n", __func__, backend_name.c_str());
+        backend = ggml_backend_reg_init_backend(backend_index, nullptr);
+        if (ggml_backend_is_cpu(backend)) {
+            const int ncores_logical = std::thread::hardware_concurrency();
+            ggml_backend_cpu_set_n_threads(backend, std::min(ncores_logical, (ncores_logical + 4)/2));
+        }
 
-    mnist_model() {
-        buf_weight = malloc(size_weight);
         {
+            const size_t size_meta = 1024*ggml_tensor_overhead();
             struct ggml_init_params params = {
-                /*.mem_size   =*/ size_weight,
-                /*.mem_buffer =*/ buf_weight,
-                /*.no_alloc   =*/ false,
+                /*.mem_size   =*/ size_meta,
+                /*.mem_buffer =*/ nullptr,
+                /*.no_alloc   =*/ true,
             };
             ctx_weight = ggml_init(params);
         }
 
-        buf_compute = malloc(size_compute);
         {
+            // The compute context needs a total of 3 compute graphs: forward pass + backwards pass (with/without optimizer step).
+            const size_t size_meta = GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead();
             struct ggml_init_params params = {
-                /*.mem_size   =*/ size_compute,
-                /*.mem_buffer =*/ buf_compute,
-                /*.no_alloc   =*/ false,
+                /*.mem_size   =*/ size_meta,
+                /*.mem_buffer =*/ nullptr,
+                /*.no_alloc   =*/ true,
             };
             ctx_compute = ggml_init(params);
         }
@@ -75,8 +97,9 @@ struct mnist_model {
         ggml_free(ctx_weight);
         ggml_free(ctx_compute);
 
-        free(buf_weight);
-        free(buf_compute);
+        ggml_backend_buffer_free(buf_weight);
+        ggml_backend_buffer_free(buf_compute);
+        ggml_backend_free(backend);
     }
 };
 
@@ -93,11 +116,11 @@ bool mnist_label_load(const std::string & fname, float * buf, const int nex);
 
 mnist_eval_result mnist_graph_eval(const std::string & fname, const float * images, const float * labels, const int nex, const int nthreads);
 
-mnist_model       mnist_model_init_from_file(const std::string & fname);
-mnist_model       mnist_model_init_random(const std::string & arch);
-void              mnist_model_build(mnist_model & model, const int nbatch);
-mnist_eval_result mnist_model_eval(const mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads);
-void              mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nthreads);
+mnist_model       mnist_model_init_from_file(const std::string & fname, const std::string & backend);
+mnist_model       mnist_model_init_random(const std::string & arch, const std::string & backend);
+void              mnist_model_build(mnist_model & model, const int nbatch_logical, const int nbatch_physical);
+mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex);
+void              mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split);
 void              mnist_model_save(mnist_model & model, const std::string & fname);
 
 std::pair<double, double> mnist_loss(const mnist_eval_result & result);
index dbd0e0ddbd902a74fd2a22dc57d3c4e6b74fddd8..ae2cadbc5f38670c21655fe5aa29cee0ecf5e020 100644 (file)
@@ -19,8 +19,8 @@ int main(int argc, char ** argv) {
     srand(time(NULL));
     ggml_time_init();
 
-    if (argc != 4) {
-        fprintf(stderr, "Usage: %s mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte\n", argv[0]);
+    if (argc != 4 && argc != 5) {
+        fprintf(stderr, "Usage: %s mnist-fc-f32.gguf data/MNIST/raw/t10k-images-idx3-ubyte data/MNIST/raw/t10k-labels-idx1-ubyte [CPU/CUDA0]\n", argv[0]);
         exit(1);
     }
 
@@ -36,36 +36,44 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const int nthreads = std::thread::hardware_concurrency();
-
     const int iex = rand() % MNIST_NTEST;
     const std::vector<float> digit(images.begin() + iex*MNIST_NINPUT, images.begin() + (iex+1)*MNIST_NINPUT);
 
     mnist_image_print(stdout, images.data() + iex*MNIST_NINPUT);
 
-    mnist_eval_result result_eval = mnist_graph_eval(argv[1], images.data(), labels.data(), MNIST_NTEST, nthreads);
-    if (result_eval.success) {
-        fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]);
+    const std::string backend = argc >= 5 ? argv[4] : "CPU";
+
+    mnist_eval_result result_eval;
+
+    if (backend == "CPU") {
+        const int ncores_logical = std::thread::hardware_concurrency();
+        result_eval = mnist_graph_eval(
+            argv[1], images.data(), labels.data(), MNIST_NTEST, std::min(ncores_logical, (ncores_logical + 4)/2));
+        if (result_eval.success) {
+            fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]);
 
-        std::pair<double, double> result_loss = mnist_loss(result_eval);
-        fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second);
+            std::pair<double, double> result_loss = mnist_loss(result_eval);
+            fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second);
 
-        std::pair<double, double> result_acc = mnist_accuracy(result_eval, labels.data());
-        fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second);
+            std::pair<double, double> result_acc = mnist_accuracy(result_eval, labels.data());
+            fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second);
 
-        return 0;
+            return 0;
+        }
+    } else {
+        fprintf(stdout, "%s: not trying to load a GGML graph from %s because this is only supported for the CPU backend\n", __func__, argv[1]);
     }
 
     const int64_t t_start_us = ggml_time_us();
 
-    mnist_model model = mnist_model_init_from_file(argv[1]);
+    mnist_model model = mnist_model_init_from_file(argv[1], backend);
 
-    mnist_model_build(model, MNIST_NBATCH);
+    mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);
 
     const int64_t t_load_us = ggml_time_us() - t_start_us;
 
     fprintf(stdout, "%s: loaded model in %.2lf ms\n", __func__, t_load_us / 1000.0);
-    result_eval = mnist_model_eval(model, images.data(), labels.data(), MNIST_NTEST, nthreads);
+    result_eval = mnist_model_eval(model, images.data(), labels.data(), MNIST_NTEST);
     fprintf(stdout, "%s: predicted digit is %d\n", __func__, result_eval.pred[iex]);
 
     std::pair<double, double> result_loss = mnist_loss(result_eval);
index 697e0b30c8e880e47c28bdb83dcc81394c02adaf..b91fe815f91d781a57493359070f343c6d33b701 100755 (executable)
@@ -42,8 +42,8 @@ def train(model_path):
     )
 
     model.summary()
-    batch_size = 500
-    epochs = 20
+    batch_size = 1000
+    epochs = 30
     model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
 
     t_start = time()
@@ -61,16 +61,14 @@ def train(model_path):
     gguf_writer.add_tensor("conv1.kernel", conv1_kernel, raw_shape=(8, 1, 3, 3))
 
     conv1_bias = model.layers[0].weights[1].numpy()
-    conv1_bias = np.repeat(conv1_bias, 28*28)
-    gguf_writer.add_tensor("conv1.bias", conv1_bias, raw_shape=(1, 8, 28, 28))
+    gguf_writer.add_tensor("conv1.bias", conv1_bias, raw_shape=(1, 8, 1, 1))
 
     conv2_kernel = model.layers[2].weights[0].numpy()
     conv2_kernel = np.moveaxis(conv2_kernel, [0, 1, 2, 3], [2, 3, 1, 0])
     gguf_writer.add_tensor("conv2.kernel", conv2_kernel, raw_shape=(16, 8, 3, 3))
 
     conv2_bias = model.layers[2].weights[1].numpy()
-    conv2_bias = np.repeat(conv2_bias, 14*14)
-    gguf_writer.add_tensor("conv2.bias", conv2_bias, raw_shape=(1, 16, 14, 14))
+    gguf_writer.add_tensor("conv2.bias", conv2_bias, raw_shape=(1, 16, 1, 1))
 
     dense_weight = model.layers[-1].weights[0].numpy()
     dense_weight = dense_weight.transpose()
index 3f52bac5cc6cdcec830296155db2ebc5dfeaf82c..6d8abc5164de4bf1652df406621728b8d850cc9f 100644 (file)
@@ -12,8 +12,8 @@ from time import time
 input_size  = 784  # img_size = (28,28) ---> 28*28=784 in total
 hidden_size = 500  # number of nodes at hidden layer
 num_classes = 10   # number of output classes discrete range [0,9]
-num_epochs  = 20   # number of times which the entire dataset is passed throughout the model
-batch_size  = 500  # the size of input data took for one iteration
+num_epochs  = 30   # number of times which the entire dataset is passed throughout the model
+batch_size  = 1000 # the size of input data used for one iteration
 lr          = 1e-3 # size of step
 
 
@@ -38,8 +38,9 @@ def train(model_path):
     assert len(train_data) == 60000
     assert len(test_data)  == 10000
 
-    train_gen = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
-    test_gen  = torch.utils.data.DataLoader(dataset=test_data,  batch_size=batch_size, shuffle=False)
+    kwargs_train_test = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
+    train_gen = torch.utils.data.DataLoader(dataset=train_data, shuffle=True,  **kwargs_train_test)
+    test_gen  = torch.utils.data.DataLoader(dataset=test_data,  shuffle=False, **kwargs_train_test)
 
     net = Net(input_size, hidden_size, num_classes)
 
index b3f5cbbc3b17a2f8032092719a0ad8b930994517..41a16221a8ebc773eb8da608182de177d0988247 100644 (file)
@@ -12,8 +12,8 @@
 #endif
 
 int main(int argc, char ** argv) {
-    if (argc != 5) {
-        fprintf(stderr, "Usage: %s mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte\n", argv[0]);
+    if (argc != 5 && argc != 6) {
+        fprintf(stderr, "Usage: %s mnist-fc mnist-fc-f32.gguf data/MNIST/raw/train-images-idx3-ubyte data/MNIST/raw/train-labels-idx1-ubyte [CPU/CUDA0]\n", argv[0]);
         exit(0);
     }
 
@@ -29,11 +29,11 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    mnist_model model = mnist_model_init_random(argv[1]);
+    mnist_model model = mnist_model_init_random(argv[1], argc >= 6 ? argv[5] : "CPU");
 
-    mnist_model_build(model, MNIST_NBATCH);
+    mnist_model_build(model, MNIST_NBATCH_LOGICAL, MNIST_NBATCH_PHYSICAL);
 
-    mnist_model_train(model, images.data(), labels.data(), MNIST_NTRAIN, std::thread::hardware_concurrency());
+    mnist_model_train(model, images.data(), labels.data(), MNIST_NTRAIN, /*nepoch =*/ 30, /*val_split =*/ 0.05f);
 
     mnist_model_save(model, argv[2]);
 }
index e497b6d02388a14ed28aa15fbbb7bdf8b9f85d62..71c0bef8ee7ee69a7d59eb881d9a9c67551630a9 100644 (file)
@@ -66,6 +66,7 @@ extern "C" {
     // "offset" refers to the offset of the tensor data for setting/getting data
     GGML_API GGML_CALL void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+    GGML_API GGML_CALL void ggml_backend_tensor_memset(   struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
 
     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
 
@@ -122,7 +123,7 @@ extern "C" {
     // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
 
     GGML_API size_t                     ggml_backend_reg_get_count(void);
-    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
+    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found
     GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
     GGML_API const char *               ggml_backend_reg_get_name(size_t i);
     GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
index 536018b669d3d7646ee67f88966427ce7ee7569e..6f2f00a1ea5062fc13ac8163f3d979f855786fe5 100644 (file)
@@ -533,6 +533,7 @@ extern "C" {
 
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+        GGML_OP_OPT_STEP_ADAMW,
 
         GGML_OP_COUNT,
     };
@@ -569,10 +570,12 @@ extern "C" {
         GGML_LOG_LEVEL_DEBUG = 5
     };
 
+    // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1,
-        GGML_TENSOR_FLAG_OUTPUT = 2,
-        GGML_TENSOR_FLAG_PARAM  = 4,
+        GGML_TENSOR_FLAG_INPUT    = 1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT   = 2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM    = 4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS     = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
     // ggml object
@@ -2080,17 +2083,38 @@ extern "C" {
             struct ggml_tensor          * b,
             struct ggml_tensor          * c);
 
+    // AdamW optimizer step
+    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+    GGML_API struct ggml_tensor * ggml_opt_step_adamw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay
+
     //
     // automatic differentiation
     //
 
-    GGML_API void ggml_set_param(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * tensor);
+    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
 
     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
+
+    GGML_API void ggml_build_opt_adamw(
+            struct ggml_context * ctx,
+            struct ggml_cgraph  * gf,
+            struct ggml_cgraph  * gb,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay
 
     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph         (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
@@ -2098,7 +2122,7 @@ extern "C" {
     GGML_API struct ggml_cgraph * ggml_graph_dup         (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
     GGML_API struct ggml_cgraph   ggml_graph_view        (struct ggml_cgraph * cgraph, int i0, int i1);
     GGML_API void                 ggml_graph_cpy         (struct ggml_cgraph * src, struct ggml_cgraph * dst);
-    GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph);  // zero grads
+    GGML_API void                 ggml_graph_reset       (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void                 ggml_graph_clear       (struct ggml_cgraph * cgraph);
 
     GGML_API size_t ggml_graph_overhead(void);
index 36ca370867c9e7bcccbd58a86febd5bc82704086..b0d4141cc4363dae92235c8915b3c2df796f5d8c 100644 (file)
@@ -38,15 +38,16 @@ extern "C" {
     typedef void * ggml_backend_buffer_context_t;
 
     struct ggml_backend_buffer_i {
-        const char * (*GGML_CALL get_name)   (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer);
-        void *       (*GGML_CALL get_base)   (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-        void         (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void         (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        bool         (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
-        void         (*GGML_CALL clear)      (ggml_backend_buffer_t buffer, uint8_t value);
-        void         (*GGML_CALL reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        const char * (*GGML_CALL get_name)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL free_buffer)   (ggml_backend_buffer_t buffer);
+        void *       (*GGML_CALL get_base)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void         (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
+        void         (*GGML_CALL set_tensor)    (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void         (*GGML_CALL get_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        bool         (*GGML_CALL cpy_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
+        void         (*GGML_CALL clear)         (ggml_backend_buffer_t buffer, uint8_t value);
+        void         (*GGML_CALL reset)         (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
     };
 
     struct ggml_backend_buffer {
index b5d9301a787629de5260cd9e8e8e591c430c9c88..97ca5a1f32f744e9d7f0afe98aa361d1c6a6d722 100644 (file)
@@ -246,6 +246,22 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void *
     buf->iface.get_tensor(buf, tensor, data, offset, size);
 }
 
+GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    if (!size) {
+        return;
+    }
+    
+    GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer");
+
+    buf->iface.memset_tensor(buf, tensor, value, offset, size);
+}
+
 void ggml_backend_synchronize(ggml_backend_t backend) {
     if (backend->iface.synchronize == NULL) {
         return;
@@ -569,6 +585,12 @@ GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t
     free(buffer->context);
 }
 
+GGML_CALL static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
 GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);
 
@@ -600,6 +622,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
     /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor     = */ NULL, // no initialization required
+    /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -613,6 +636,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
     /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
     /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor     = */ NULL, // no initialization required
+    /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -980,6 +1004,7 @@ static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(
         /* .free_buffer     = */ ggml_backend_multi_buffer_free_buffer,
         /* .get_base        = */ NULL,
         /* .init_tensor     = */ NULL,
+        /* .memset_tensor   = */ NULL,
         /* .set_tensor      = */ NULL,
         /* .get_tensor      = */ NULL,
         /* .cpy_tensor      = */ NULL,
index 06930ba2e5bee04870bbc5165a097f42e0ce7bd1..a2d64da59a4c4404a508572983d1d2d80d20d9a7 100644 (file)
@@ -1036,6 +1036,7 @@ static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cann_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cann_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cann_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_cann_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cann_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cann_buffer_cpy_tensor,
index 982316f565e9c168ab18da500756dbf25fc65af5..e8a340713b2d958c94f20aa7ae64fbee20b03935 100644 (file)
@@ -21,6 +21,8 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
@@ -493,6 +495,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
     }
 }
 
+GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
 GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
@@ -544,6 +554,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .memset_tensor   = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cuda_buffer_cpy_tensor,
@@ -860,6 +871,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_cuda_split_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cuda_split_buffer_get_tensor,
     /* .cpy_tensor      = */ NULL,
@@ -2168,6 +2180,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_cuda_op_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;
@@ -2201,6 +2216,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_NEG:
                     ggml_cuda_op_neg(ctx, dst);
                     break;
+                case GGML_UNARY_OP_STEP:
+                    ggml_cuda_op_step(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_GELU:
                     ggml_cuda_op_gelu(ctx, dst);
                     break;
@@ -2267,6 +2285,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
             break;
+        case GGML_OP_OUT_PROD:
+            ggml_cuda_out_prod(ctx, dst);
+            break;
         case GGML_OP_SCALE:
             ggml_cuda_op_scale(ctx, dst);
             break;
@@ -2324,6 +2345,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            ggml_cuda_cross_entropy_loss_back(ctx, dst);
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            ggml_cuda_opt_step_adamw(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2757,6 +2784,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -2809,6 +2837,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                         return false;
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -2865,6 +2895,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             } break;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_REPEAT_BACK:
+                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2931,9 +2967,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             return true;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }
index e1390a0414559fca9bc0a7ae05fd394de3c65615..c7b6be4e2905c1d927f0bc3860a0617b1ab16ce4 100644 (file)
@@ -1,4 +1,5 @@
 #include "binbcast.cuh"
+#include <cstdint>
 
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
@@ -90,6 +91,30 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
+template <typename T>
+static __global__ void k_repeat_back(
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+
+    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
+    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+
+    if (tid0 >= ne0) {
+        return;
+    }
+
+    T sum = 0;
+    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+            }
+        }
+    }
+    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+}
+
 template<float (*bin_op)(const float, const float)>
 struct bin_bcast_cuda {
     template<typename src0_t, typename src1_t, typename dst_t>
@@ -247,6 +272,16 @@ struct bin_bcast_cuda {
     }
 };
 
+template <typename T>
+static void repeat_back_cuda(
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+}
+
 template<class op>
 static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
@@ -286,3 +321,35 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    GGML_ASSERT(dst->ne[3] == 1);
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            float       * dst_d  = (float       *) dst->data;
+            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+        } break;
+        default: {
+            GGML_ASSERT(false);
+        } break;
+    }
+}
index 198c9ef6fd8ea73c3e9e85f5ef0e60676365ca91..3ac1c9b03fcea7255a124e4a0d9e1a18b3bd8b64 100644 (file)
@@ -5,3 +5,5 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 5575a90f643266bf5eb233596b731e2a938cc1a0..ed09406a88bacb119441e457ffa6c01c60b7f5ce 100644 (file)
@@ -71,6 +71,32 @@ static __global__ void cross_entropy_loss_f32(const float * logits, const float
     dst[blockIdx.x] = loss;
 }
 
+static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+    extern __shared__ float tmp[];
+
+    float maxval = -INFINITY;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = logits[blockIdx.x*nclasses + i];
+        maxval = fmaxf(maxval, val);
+        tmp[i] = val;
+    }
+    maxval = warp_reduce_max(maxval);
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = expf(tmp[i] - maxval);
+        sum += val;
+        tmp[i] = val;
+    }
+    sum = warp_reduce_sum(sum);
+    const float sm_scale = 1.0f/sum;
+
+    const float d_by_nrows = *loss/gridDim.x;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+    }
+}
+
 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -104,3 +130,37 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     // Combine results from individual blocks:
     sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * opt0 = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(opt0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    const float * opt0_d = (const float *) opt0->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(nrows, 1, 1);
+    const int shmem = ne00*sizeof(float);
+
+    cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+}
index 9d7b8b0f0082ba52fe90da130d58a0b13597a967..9ec7152ff4518607a01f61ab2f4e209ec8cef6dd 100644 (file)
@@ -3,3 +3,5 @@
 #define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
 
 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/src/ggml-cuda/opt-step-adamw.cu b/src/ggml-cuda/opt-step-adamw.cu
new file mode 100644 (file)
index 0000000..d6f13a9
--- /dev/null
@@ -0,0 +1,80 @@
+#include "opt-step-adamw.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_adamw_f32(
+    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k,
+    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+    const float beta1h, const float beta2h) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float gi = g[i];
+    const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);
+    const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
+
+    g_m[i] = gmi;
+    g_v[i] = gvi;
+
+    const float mh =       gmi*beta1h;
+    const float vh = sqrtf(gvi*beta2h) + eps;
+
+    x[i] = x[i]*(1.0f - alpha*wd) - mh/vh;
+}
+
+static void opt_step_adamw_f32_cuda(
+    float * x, const float * g, float * g_m, float * g_v, const int64_t k,
+    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+    const float beta1h, const float beta2h, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h);
+}
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0        = dst->src[0];
+    const ggml_tensor * src0_grad   = dst->src[1];
+    const ggml_tensor * src0_grad_m = dst->src[2];
+    const ggml_tensor * src0_grad_v = dst->src[3];
+
+    GGML_ASSERT(src0->type        == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type   == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+
+    float       * src0_d        = (float       *) src0->data;
+    const float * src0_grad_d   = (const float *) src0_grad->data;
+    float       * src0_grad_m_d = (float       *) src0_grad_m->data;
+    float       * src0_grad_v_d = (float       *) src0_grad_v->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    int64_t iter;  memcpy(&iter,  &dst->op_params[0], sizeof(int64_t));
+    float   alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float));
+    float   beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float));
+    float   beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float));
+    float   eps;   memcpy(&eps,   &dst->op_params[5], sizeof(float));
+    float   wd;    memcpy(&wd,    &dst->op_params[6], sizeof(float));
+
+    const float beta1h  = alpha/(1.0f - powf(beta1, iter));
+    const float beta2h  =  1.0f/(1.0f - powf(beta2, iter));
+
+    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream);
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
diff --git a/src/ggml-cuda/opt-step-adamw.cuh b/src/ggml-cuda/opt-step-adamw.cuh
new file mode 100644 (file)
index 0000000..58d6f6e
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/src/ggml-cuda/out-prod.cu b/src/ggml-cuda/out-prod.cu
new file mode 100644 (file)
index 0000000..657d50e
--- /dev/null
@@ -0,0 +1,52 @@
+#include "out-prod.cuh"
+#include "vendors/cuda.h"
+
+#include <cstdint>
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(ne01 == ne11);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+
+    GGML_ASSERT(ne2 == src0->ne[2]);
+    GGML_ASSERT(ne2 == src1->ne[2]);
+    GGML_ASSERT(ne3 == src0->ne[3]);
+    GGML_ASSERT(ne3 == src1->ne[3]);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       *  dst_d = (float       *)  dst->data;
+
+    cudaStream_t   stream = ctx.stream();
+    cublasHandle_t handle = ctx.cublas_handle();
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
+    GGML_ASSERT(ne2 == 1);
+    GGML_ASSERT(ne3 == 1);
+    CUBLAS_CHECK(cublasSetStream(handle, stream));
+
+    const bool src1_T = ggml_is_transposed(src1);
+    const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
+    const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
+    GGML_ASSERT(                             (src1_T ?        nb11 :        nb10) == sizeof(float));
+
+    CUBLAS_CHECK(
+        cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                ne0, ne1, ne01,
+                &alpha, src0_d, ne00,
+                        src1_d, ldb,
+                &beta,  dst_d,  ne0));
+}
diff --git a/src/ggml-cuda/out-prod.cuh b/src/ggml-cuda/out-prod.cuh
new file mode 100644 (file)
index 0000000..a0046f5
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 8ac669f94e2de23ed18a999041cab0861f8c5480..163b5a8ffec6b72f1ffac66a6b4c31f94e139808 100644 (file)
@@ -10,6 +10,16 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) {
     dst[i] = -x[i];
 }
 
+static __global__ void step_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] > 0.0f;
+}
+
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -134,6 +144,11 @@ static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
+    step_f32<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -213,6 +228,20 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
index ed2ffc461e8102caafcf1c2d48fa82892a8eafc4..fe519f6a232dfa053bd704d6f3f7c0942eda5510 100644 (file)
@@ -1,6 +1,7 @@
 #include "common.cuh"
 
 #define CUDA_NEG_BLOCK_SIZE 256
+#define CUDA_STEP_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
@@ -15,6 +16,8 @@
 
 void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 41ac63fa48e0fadc7565c409c9cffbf654804c59..d0395ff9f8ba135edc70074bfd0588f741cbe773 100644 (file)
@@ -1872,6 +1872,7 @@ static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
     /* .free_buffer     = */ ggml_backend_kompute_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_kompute_buffer_get_base,
     /* .init_tensor     = */ NULL,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_kompute_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_kompute_buffer_get_tensor,
     /* .cpy_tensor      = */ NULL,
index f04e5af71f9ebdc57dd39fe39d9bb4834d00c664..87f2e16f3229eacad50cd4637914de1804b790a9 100644 (file)
@@ -3165,6 +3165,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
     /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_metal_buffer_get_base,
     /* .init_tensor     = */ NULL,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_metal_buffer_cpy_tensor,
index 8f9d0a46019691a6d66cb8c33b13a4335ea1f09a..6ed428cefe87152b1b9fbf8f6e79ab40459e0192 100644 (file)
@@ -469,6 +469,7 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_rpc_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_rpc_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_rpc_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_rpc_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_rpc_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_rpc_buffer_cpy_tensor,
index 0d884f89a4e7bac03c179d14f2cb28b3712f3a53..62eafd01de5f4952843e19a1b70c1732fd1499cc 100644 (file)
@@ -4318,6 +4318,7 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_sycl_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_sycl_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
index d6f647c89fbff8dbf1086255b3eb039748b3fed0..611fe69acb473e6924e4fd3ad24307cff7b1dfd2 100644 (file)
@@ -6221,6 +6221,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_vk_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_vk_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_vk_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_vk_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_vk_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_vk_buffer_cpy_tensor,
index f188e448e7a3c5b59554a544e81aea9794c01601..a5334e267b2de8a1431caf85651fce484093e974 100644 (file)
@@ -1,6 +1,7 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC
 
+#include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-quants.h"
 #include "ggml.h"
@@ -2977,9 +2978,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
+    "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3070,9 +3072,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
+    "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4079,7 +4082,11 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
 }
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
-    memset(tensor->data, 0, ggml_nbytes(tensor));
+    if (tensor->buffer) {
+        ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
+    } else {
+        memset(tensor->data, 0, ggml_nbytes(tensor));
+    }
     return tensor;
 }
 
@@ -8305,11 +8312,46 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
     return result;
 }
 
-////////////////////////////////////////////////////////////////////////////////
+// opt_step_adamw
 
-void ggml_set_param(
+struct ggml_tensor * ggml_opt_step_adamw(
         struct ggml_context * ctx,
-        struct ggml_tensor * tensor) {
+        struct ggml_tensor  * a,
+        float                 alpha,
+        float                 beta1,
+        float                 beta2,
+        float                 eps,
+        float                 wd) {
+    GGML_ASSERT(a->grad);
+    GGML_ASSERT(alpha >  0.0f);
+    GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
+    GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
+    GGML_ASSERT(eps   >= 0.0f);
+    GGML_ASSERT(wd    >= 0.0f && wd    <= 1.0f);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op   = GGML_OP_OPT_STEP_ADAMW;
+    result->grad = NULL;
+    result->src[0] = a;
+    result->src[1] = a->grad;
+    result->src[2] = ggml_dup_tensor(ctx, a->grad);
+    result->src[3] = ggml_dup_tensor(ctx, a->grad);
+
+    const int64_t iter = 1;
+    memcpy(&result->op_params[0], &iter, sizeof(int64_t));
+    ggml_set_op_params_f32(result, 2, alpha);
+    ggml_set_op_params_f32(result, 3, beta1);
+    ggml_set_op_params_f32(result, 4, beta2);
+    ggml_set_op_params_f32(result, 5, eps);
+    ggml_set_op_params_f32(result, 6, wd);
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 
     GGML_ASSERT(tensor->grad == NULL);
@@ -8317,6 +8359,13 @@ void ggml_set_param(
     ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
 }
 
+void ggml_set_loss(struct ggml_tensor * tensor) {
+    GGML_ASSERT(ggml_is_scalar(tensor));
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(tensor->grad);
+    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
+}
+
 // ggml_compute_forward_dup
 
 static void ggml_compute_forward_dup_same_cont(
@@ -17391,7 +17440,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    float * d   = (float *) opt0->data;
+    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
         float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
@@ -17415,7 +17464,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
 
         // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
-        ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
+        ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -17444,6 +17493,94 @@ static void ggml_compute_forward_cross_entropy_loss_back(
     }
 }
 
+static void ggml_compute_forward_opt_step_adamw_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0        = dst->src[0];
+    const struct ggml_tensor * src0_grad   = dst->src[1];
+    const struct ggml_tensor * src0_grad_m = dst->src[2];
+    const struct ggml_tensor * src0_grad_v = dst->src[3];
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    /* const float   gnorm = 1.0f; */
+    int64_t       iter;   memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
+    const float   alpha = ggml_get_op_params_f32(dst, 2);
+    const float   beta1 = ggml_get_op_params_f32(dst, 3);
+    const float   beta2 = ggml_get_op_params_f32(dst, 4);
+    const float   eps   = ggml_get_op_params_f32(dst, 5);
+    const float   wd    = ggml_get_op_params_f32(dst, 6);
+
+    const float beta1h  = alpha/(1.0f - powf(beta1, iter));
+    const float beta2h  =  1.0f/(1.0f - powf(beta2, iter));
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
+        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
+        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
+            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
+
+            const float mh =       m[i00]*beta1h;
+            const float vh = sqrtf(v[i00]*beta2h) + eps;
+
+            // The weight decay is applied independently of the Adam momenta m and v.
+            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
+            // See: https://arxiv.org/pdf/1711.05101v3.pdf
+            w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
+        }
+    }
+
+    ggml_barrier(params->threadpool);
+    if (ith != 0) {
+        return;
+    }
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
+
+static void ggml_compute_forward_opt_step_adamw(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adamw_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -17789,6 +17926,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_cross_entropy_loss_back(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                ggml_compute_forward_opt_step_adamw(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -17943,7 +18085,7 @@ void ggml_build_backward_gradient_checkpointing(
         struct ggml_tensor  * * checkpoints,
         int                     n_checkpoints) {
     ggml_graph_cpy(gf, gb_tmp);
-    ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+    ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);
 
     if (n_checkpoints <= 0) {
         ggml_graph_cpy(gb_tmp, gb);
@@ -17981,42 +18123,93 @@ void ggml_build_backward_gradient_checkpointing(
     ggml_hash_map_free(replacements);
 }
 
-// functions to change gradients considering the case that input a might be initial gradient with zero value
-
-static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+// utility functions to change gradients
+// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
+// else if a is in zero_table, replace a
+// else, just add/subtract/etc. the gradients
+
+static struct ggml_tensor * ggml_add_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return b;
-    } else {
-        return ggml_add_impl(ctx, a, b, false);
     }
+    return ggml_add_impl(ctx, a, b, false);
 }
 
-static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_acc_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        const  size_t          nb1,
+        const  size_t          nb2,
+        const  size_t          nb3,
+        const  size_t          offset,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
+        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
         return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
-    } else {
-        return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
     }
+    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
 }
 
-static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_add1_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_repeat(ctx, b, a);
-    } else {
-        return ggml_add1_impl(ctx, a, b, false);
     }
+    return ggml_add1_impl(ctx, a, b, false);
 }
 
-static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_sub_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_neg(ctx, b);
-    } else {
-        return ggml_sub_impl(ctx, a, b, false);
     }
+    return ggml_sub_impl(ctx, a, b, false);
 }
 
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) {
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
@@ -18025,38 +18218,38 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_DUP:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_ADD:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     if (ggml_are_same_shape(src0, src1)) {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad,                       tensor->grad,        zero_table);
+                        src1->grad = ggml_add_or_set(ctx, src1->grad,                       tensor->grad,        zero_table, acc_table);
                     } else {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
+                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
                     }
                 }
             } break;
         case GGML_OP_ADD1:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     src1->grad = ggml_add_or_set(ctx,
                         src1->grad,
                         ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_ACC:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     const size_t nb1     = ((int32_t *) tensor->op_params)[0];
@@ -18078,16 +18271,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_reshape(ctx,
                                 ggml_cont(ctx, tensor_grad_view),
                                 src1->grad),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SUB:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
+                    src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_MUL:
@@ -18097,14 +18290,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_mul(ctx, src1, tensor->grad),
-                                zero_table);
+                                zero_table, acc_table);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_or_set(ctx,
                                 src1->grad,
                                 ggml_mul(ctx, src0, tensor->grad),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_DIV:
@@ -18114,7 +18307,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_div(ctx, tensor->grad, src1),
-                                zero_table);
+                                zero_table, acc_table);
                 }
                 if (src1->grad) {
                     src1->grad =
@@ -18123,7 +18316,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_mul(ctx,
                                     tensor->grad,
                                     ggml_div(ctx, tensor, src1)),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SQR:
@@ -18135,7 +18328,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_scale(ctx,
                                     ggml_mul(ctx, src0, tensor->grad),
                                     2.0f),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SQRT:
@@ -18149,7 +18342,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         tensor->grad,
                                         tensor),
                                     0.5f),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_LOG:
@@ -18161,7 +18354,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_div(ctx,
                                     tensor->grad,
                                     src0),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SIN:
@@ -18173,7 +18366,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_mul(ctx,
                                     tensor->grad,
                                     ggml_cos(ctx, src0)),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_COS:
@@ -18185,7 +18378,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_mul(ctx,
                                     tensor->grad,
                                     ggml_sin(ctx, src0)),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SUM:
@@ -18195,7 +18388,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add1_or_set(ctx,
                                 src0->grad,
                                 tensor->grad,
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SUM_ROWS:
@@ -18207,7 +18400,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_repeat(ctx,
                                     tensor->grad,
                                     src0->grad),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_MEAN:
@@ -18222,7 +18415,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_repeat_back(ctx, tensor->grad, src0->grad),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_REPEAT_BACK:
@@ -18232,7 +18425,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_repeat(ctx, tensor->grad, src0->grad),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_CONCAT:
@@ -18257,7 +18450,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_RMS_NORM_BACK:
@@ -18305,7 +18498,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add_or_set(ctx,
                                 src0->grad, // [n,m,q1,r1]
                                 s1_tg,      // [n,m,q1,r1]
-                                zero_table);
+                                zero_table, acc_table);
                 }
                 if (src1->grad) {
                     src1->grad =
@@ -18323,7 +18516,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     src0,                           // [n,m,q1,r1]
                                     ggml_transpose(ctx,             // [p,m,qq,rr]
                                         tensor->grad)),             // [m,p,qq,rr]
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_MUL_MAT_ID:
@@ -18345,7 +18538,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_scale_impl(ctx, tensor->grad, s, false),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SET:
@@ -18374,7 +18567,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             tensor->grad,
                             ggml_neg(ctx, tensor_grad_view),
                             nb1, nb2, nb3, offset, false),
-                        zero_table);
+                        zero_table, acc_table);
                 }
 
                 if (src1->grad) {
@@ -18384,7 +18577,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_reshape(ctx,
                                 ggml_cont(ctx, tensor_grad_view),
                                 src1->grad),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_CPY:
@@ -18395,7 +18588,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // tensor = src0 * 1 + src1 * 0
                 if (src0->grad) {
                     // dsrc0 = dtensor * 1
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     // dsrc1 = dtensor * 0 -> noop
@@ -18407,7 +18600,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     GGML_ASSERT(ggml_is_contiguous(src0->grad));
                     GGML_ASSERT(ggml_is_contiguous(tensor->grad));
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_RESHAPE:
@@ -18421,7 +18614,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     ? tensor->grad
                                     : ggml_cont(ctx, tensor->grad),
                                 src0->grad),
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_VIEW:
@@ -18450,7 +18643,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         nb3 = (nb3 / n0) * ng;
                     }
 
-                    src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
+                    src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_PERMUTE:
@@ -18475,7 +18668,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 axes_backward[1],
                                 axes_backward[2],
                                 axes_backward[3]),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_TRANSPOSE:
@@ -18485,7 +18678,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_or_set(ctx, src0->grad,
                             ggml_transpose(ctx, tensor->grad),
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_GET_ROWS:
@@ -18497,7 +18690,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             // last ggml_get_rows_back argument src0->grad is only
                             // necessary to setup correct output shape
                             ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
-                        zero_table);
+                        zero_table, acc_table);
                 }
                 if (src1->grad) {
                     // noop
@@ -18521,7 +18714,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             /* ggml_diag_mask_inf_impl() shouldn't be here */
                             /* ref:  https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
@@ -18532,7 +18725,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_or_set(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SOFT_MAX:
@@ -18542,7 +18735,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_or_set(ctx, src0->grad,
                             ggml_soft_max_back(ctx, tensor->grad, tensor),
-                        zero_table);
+                        zero_table, acc_table);
                 }
 
             } break;
@@ -18583,7 +18776,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 attn_factor,
                                 beta_fast,
                                 beta_slow),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_ROPE_BACK:
@@ -18619,7 +18812,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 beta_fast,
                                 beta_slow,
                                 false),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_CLAMP:
@@ -18644,7 +18837,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad = ggml_add_or_set(ctx,
                             src1->grad,
                             ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_IM2COL_BACK:
@@ -18673,7 +18866,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_POOL_2D_BACK:
@@ -18738,7 +18931,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             grad_q,
-                            zero_table);
+                            zero_table, acc_table);
                 }
                 if (src1->grad) {
                     struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
@@ -18746,7 +18939,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad = ggml_add_or_set(ctx,
                             src1->grad,
                             grad_k,
-                            zero_table);
+                            zero_table, acc_table);
                 }
                 if (src2->grad) {
                     struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
@@ -18754,7 +18947,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src2->grad = ggml_add_or_set(ctx,
                             src2->grad,
                             grad_v,
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
@@ -18780,7 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                             ggml_mul(ctx,
                                                 ggml_sgn(ctx, src0),
                                                 tensor->grad),
-                                            zero_table);
+                                            zero_table, acc_table);
                             }
                         } break;
                     case GGML_UNARY_OP_SGN:
@@ -18792,7 +18985,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     case GGML_UNARY_OP_NEG:
                         {
                             if (src0->grad) {
-                                src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                                src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                             }
                         } break;
                     case GGML_UNARY_OP_STEP:
@@ -18817,7 +19010,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         ggml_mul(ctx,
                                             ggml_step(ctx, src0),
                                             tensor->grad),
-                                        zero_table);
+                                        zero_table, acc_table);
                             }
                         } break;
                     case GGML_UNARY_OP_SIGMOID:
@@ -18839,7 +19032,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src0->grad = ggml_add_or_set(ctx,
                                         src0->grad,
                                         ggml_silu_back(ctx, src0, tensor->grad),
-                                        zero_table);
+                                        zero_table, acc_table);
                             }
                         } break;
                     case GGML_UNARY_OP_EXP:
@@ -18848,7 +19041,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src0->grad = ggml_add_or_set(ctx,
                                         src0->grad,
                                         ggml_mul(ctx, tensor, tensor->grad),
-                                        zero_table);
+                                        zero_table, acc_table);
                             }
                         } break;
                     default:
@@ -18878,13 +19071,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     src0,
                                     src1,
                                     tensor->grad),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 GGML_ABORT("fatal error"); // not supported
             }
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                GGML_ABORT("fatal error"); // not supported
+            }
         case GGML_OP_NONE:
             {
                 // nop
@@ -18974,7 +19171,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
     ggml_build_forward_impl(cgraph, tensor, true);
 }
 
-void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
+void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) {
     GGML_ASSERT(gf->n_nodes > 0);
     GGML_ASSERT(gf->grads);
 
@@ -18990,21 +19187,35 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
-    // remember original gradients which start with zero values
+    // keep tables of original gradients for replacement/accumulation logic
     struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
+    struct ggml_hash_set acc_table  = ggml_hash_set_new(gf->size);
     for (int i = 0; i < gf->n_nodes; i++) {
-        if (gf->grads[i]) {
-            ggml_hash_insert(&zero_table, gf->grads[i]);
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->grad) {
+            {
+                const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
+                GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+                GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            }
+
+            // only gradients of trainable parameters should be accumulated
+            if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+                const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
+                GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+                GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            }
         }
     }
 
     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];
 
-        // inplace operations to add gradients are not created by ggml_compute_backward
+        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
         if (node->grad) {
-            ggml_compute_backward(ctx, node, &zero_table);
+            ggml_compute_backward(ctx, node, &zero_table, &acc_table);
         }
     }
 
@@ -19018,8 +19229,30 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
     }
 
     ggml_hash_set_free(&zero_table);
+    ggml_hash_set_free(&acc_table);
+}
+
+void ggml_build_opt_adamw(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * gf,
+        struct ggml_cgraph  * gb,
+        float                 alpha,
+        float                 beta1,
+        float                 beta2,
+        float                 eps,
+        float                 wd) {
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+            ggml_build_forward_expand(gb, opt_step);
+        }
+    }
 }
 
+
 static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
     void * ptr = *p;
     ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
@@ -19147,10 +19380,28 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     GGML_ASSERT(cgraph->grads != NULL);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * grad = cgraph->grads[i];
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        // initial gradients of loss should be 1, 0 otherwise
+        if (node->grad) {
+            if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+                GGML_ASSERT(node->grad->buffer);
+                GGML_ASSERT(node->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_is_scalar(node));
+
+                const float onef = 1.0f;
+                ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
+            } else {
+                ggml_set_zero(node->grad);
+            }
+        }
 
-        if (grad) {
-            ggml_set_zero(grad);
+        GGML_ASSERT(node);
+        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
+            // set iteration to 1 and clear momenta
+            ggml_set_op_params_i32(node, 0, 1);
+            ggml_set_zero(node->src[2]);
+            ggml_set_zero(node->src[3]);
         }
     }
 }
@@ -19415,6 +19666,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             {
                 n_tasks = n_threads;
             } break;
@@ -21777,7 +22029,7 @@ enum ggml_opt_result ggml_opt_resume(
     ggml_build_forward_expand(gf, f);
 
     struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
-    ggml_build_backward_expand(ctx, gf, gb, true);
+    ggml_build_backward_expand(ctx, gf, gb, false, true);
 
     return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }
index 635de01d70439c60d1c7a91b7f0ef89f2a236740..6c4de06c9db5f10eefb46b389564cf4de4ed419b 100644 (file)
@@ -799,10 +799,11 @@ struct test_case {
             out = ggml_sum(ctx, out);
             ggml_set_name(out, "sum_of_out");
         }
+        ggml_set_loss(out);
 
         ggml_build_forward_expand(gf, out);
         ggml_graph_cpy(gf, gb);
-        ggml_build_backward_expand(ctx, gf, gb, false);
+        ggml_build_backward_expand(ctx, gf, gb, false, false);
         if (expect.size() != 1 || expect[0] != 0.0f) {
             GGML_ASSERT(gb->n_nodes > gf->n_nodes);
             for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
@@ -837,22 +838,11 @@ struct test_case {
             return false;
         }
 
-        // randomize tensors
-        initialize_tensors(ctx);
-
-        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
-            if (!t->grad) {
-                continue;
-            }
 
-            std::vector<float> tmp(ggml_nelements(t->grad));
-            ggml_backend_tensor_set(t->grad, tmp.data(), 0, ggml_nbytes(t->grad));
-        }
+        initialize_tensors(ctx); // Randomizes all tensors (including gradients).
+        ggml_graph_reset(gb);    // Sets gradients to 1 if loss, 0 otherwise.
 
-        // build graphs
-        const float onef = 1.0f;
         ggml_backend_graph_compute(backend, gf);
-        ggml_backend_tensor_set(out->grad, &onef, 0, ggml_nbytes(out->grad));
         ggml_backend_graph_compute(backend, gb);
 
         bool ok = true;
@@ -1681,6 +1671,50 @@ struct test_mul_mat_id : public test_case {
     }
 };
 
+// GGML_OP_OUT_PROD
+struct test_out_prod : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array<int64_t, 2> bs; // dims 3 and 4
+    const bool trans_b;
+
+    std::string vars() override {
+        return VARS_TO_STR7(type_a, type_b, m, n, k, bs, trans_b);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array<int64_t, 2> bs = {10, 10},
+            bool trans_b = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), trans_b(trans_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b;
+        if (trans_b) {
+            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0], bs[1]);
+            b = ggml_transpose(ctx, b);
+        } else {
+            b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0], bs[1]);
+        }
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_out_prod(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_SQR
 struct test_sqr : public test_case {
     const ggml_type type;
@@ -2666,6 +2700,51 @@ struct test_cross_entropy_loss : public test_case {
     }
 };
 
+// GGML_OP_OPT_STEP_ADAMW
+struct test_opt_step_adamw : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float alpha;
+    const float beta1;
+    const float beta2;
+    const float eps;
+    const float wd;
+
+    std::string vars() override {
+        return VARS_TO_STR7(type, ne, alpha, beta1, beta2, eps, wd);
+    }
+
+    test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
+            float alpha = 1e-3f,
+            float beta1 = 0.9f,
+            float beta2 = 0.999f,
+            float eps = 1e-8f,
+            float wd = 0.0f)
+        : type(type), ne(ne), alpha(alpha), beta1(beta1), beta2(beta2), eps(eps), wd(wd) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, alpha, beta1, beta2, eps, wd);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 1.0f); // grad_v needs non-negative values.
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -3159,14 +3238,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, 3}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, 3}, {1, 1, 1, 2}));
+    for (int ne3 : {1, 3}) { // CUDA backwards pass only supports ne3 == 1
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 2, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 2}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
+    }
 
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
@@ -3350,6 +3430,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    for (ggml_type type_a : base_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, { 1,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
+
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1,  1}, true));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+        }
+    }
+
     test_cases.emplace_back(new test_sqr());
     test_cases.emplace_back(new test_sqrt());
     test_cases.emplace_back(new test_log());
@@ -3476,6 +3577,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     test_cases.emplace_back(new test_cross_entropy_loss());
+    for (float wd : {0.0f, 1e-2f}) {
+        test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}, 1.0f, 1e-3f, 0.9f, 0.999f, wd));
+    }
 
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
index 1834c11d894b4ce4279715598fe35c0c8c1c83a2..2ef606d2c3591cbdd85d3ee120430e740875b2fc 100644 (file)
@@ -240,7 +240,7 @@ static bool check_gradient(
     struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
     ggml_build_forward_expand(gf, f);
     ggml_graph_cpy(gf, gb);
-    ggml_build_backward_expand(ctx0, gf, gb, false);
+    ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
     ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
 
index ebd004b0473baa2b16813ece95b8f4a22bb8224d..74192bdd6b44672165602dc975cb68a6f6088630 100644 (file)
@@ -100,7 +100,7 @@ bool check_gradient(
     struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
     ggml_build_forward_expand(gf, f);
     struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-    ggml_build_backward_expand(ctx0, gf, gb, false);
+    ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
     ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
     ggml_graph_reset  (gf);
index 7b5a546a8e1900d2c110fa6258072fcf07cb545b..15119a7d18dedee0dbdadb1eca22d391ca2e1563 100644 (file)
@@ -31,7 +31,7 @@ int main(int argc, const char ** argv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
         ggml_build_forward_expand(gf, f);
         struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
+        ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
         ggml_set_f32(x, 2.0f);
         ggml_set_f32(a, 3.0f);
@@ -83,7 +83,7 @@ int main(int argc, const char ** argv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
         ggml_build_forward_expand(gf, y);
         struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
+        ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
         ggml_graph_reset(gf);
         ggml_set_f32(y->grad, 1.0f);
@@ -103,7 +103,7 @@ int main(int argc, const char ** argv) {
 
         struct ggml_cgraph * gbb = ggml_graph_dup(ctx0, gb);
 
-        ggml_build_backward_expand(ctx0, gb, gbb, true);
+        ggml_build_backward_expand(ctx0, gb, gbb, false, true);
 
         ggml_graph_reset(gb);
         ggml_set_f32(g1->grad, 1.0f);
@@ -134,7 +134,7 @@ int main(int argc, const char ** argv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
         ggml_build_forward_expand(gf, y);
         struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
+        ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 4.0f);
@@ -172,7 +172,7 @@ int main(int argc, const char ** argv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
         ggml_build_forward_expand(gf, y);
         struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
+        ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
         ggml_set_f32(x1, 1.0f);
         ggml_set_f32(x2, 2.0f);
@@ -199,7 +199,7 @@ int main(int argc, const char ** argv) {
 
         struct ggml_cgraph * gbb = ggml_graph_dup(ctx0, gb);
 
-        ggml_build_backward_expand(ctx0, gb, gbb, true);
+        ggml_build_backward_expand(ctx0, gb, gbb, false, true);
 
         ggml_graph_reset(gb);
         ggml_set_f32(g1->grad, 1.0f);
@@ -235,7 +235,7 @@ int main(int argc, const char ** argv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
         ggml_build_forward_expand(gf, y);
         struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
+        ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 5.0f);
@@ -290,7 +290,7 @@ int main(int argc, const char ** argv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
         ggml_build_forward_expand(gf, y);
         struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
+        ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 5.0f);
@@ -345,7 +345,7 @@ int main(int argc, const char ** argv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
         ggml_build_forward_expand(gf, y);
         struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
+        ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 5.0f);
@@ -394,7 +394,7 @@ int main(int argc, const char ** argv) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
         ggml_build_forward_expand(gf, y);
         struct ggml_cgraph * gb = ggml_graph_dup(ctx0, gf);
-        ggml_build_backward_expand(ctx0, gf, gb, false);
+        ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
         ggml_set_f32(x1, 3.0f);
         ggml_set_f32(x2, 5.0f);