]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
finetune: SGD optimizer, more CLI args (llama/13873)
authorJonathan Graehl <redacted>
Thu, 14 Aug 2025 10:03:57 +0000 (03:03 -0700)
committerGeorgi Gerganov <redacted>
Thu, 14 Aug 2025 11:17:28 +0000 (14:17 +0300)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt

add unit tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating
m, v tensors.

support finetune.cpp arg -opt SGD (or sgd). (default adamw as before)

llama 3.2-1b-F32 result: observed 11gb gpu ram (41 sec/epoch)
when using SGD instead of 19gb (55 sec/epoch) using adamw.
(wikipedia 100 lines finetune)

(
using the same GPU memory, adamw can only do before OOM 512
batch/context, reaching:
train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

SGD is superior, though it converges slower, with max before OOM 1728
batch/context (esp see the better validation perf):
train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00
)

note: when finetuning long enough (or w/ enough -lr),
validation accuracy *eventually* drops ('catastrophic forgetting')

-lr-half (halflife) option useful for SGD to avoid oscillation or
super slow underdamped learning (makes setting -lr more forgiving).
terminal -lr for now is set by lr-halvings i.e. if you want at most
1/8 the inital -lr you set -lr-halvings 3.

note: objective loss not directly comparable between adamw, sgd? -
check perplexity or accuracy or consider relative improvements
for convergence

new finetune args -wd 1e-9 to enable weight decay in sgd or adamw,
and max -epochs N (default 2 as before)

cache (1 - wd*alpha) in 'adamw' opt struct -
no noticeable perf benefit, disabled (still done
for new SGD though)

since opt. memory is pre-allocated, the ggml_opt_get_optimizer_params
would probably be able to change between SGD and AdamW with each epoch
but would need to use adamw for the first (unconfirmed - no cmdline arg
to set such a policy yet)

test-opt checks adamw as before and now sgd (except for a few disabled
tests for sgd only; probably just needs logging values and adding
alternate reference values);  tolerance on the 'regression'
test is broader for sgd (so we don't need many more epochs)

* Vulkan: Implement GGML_OP_OPT_STEP_SGD

* tests: Fix OPT_STEP_SGD test-backend-ops

* SGD op param store weight-decay and not 1-alpha*wd

* minor + cosmetic changes

* fix vulkan sgd

* try CI fix

---------

Co-authored-by: 0cc4m <redacted>
Co-authored-by: Johannes Gäßler <redacted>
15 files changed:
include/ggml-opt.h
include/ggml.h
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp
src/ggml-cpu/ops.h
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/opt-step-sgd.cu [new file with mode: 0644]
src/ggml-cuda/opt-step-sgd.cuh [new file with mode: 0644]
src/ggml-opt.cpp
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
src/ggml.c
tests/test-backend-ops.cpp
tests/test-opt.cpp

index 74ec080a055eaeb4134080c6a77f25feffc0353e..4703a05afe198cc41581efb31700932498b5e929 100644 (file)
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT     = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
 
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+        ggml_opt_get_optimizer_params get_opt_pars;    // callback for calculating optimizer parameters
+        void *                        get_opt_pars_ud; // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
             struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,      // loss to minimize
+            enum ggml_opt_optimizer_type    optimizer,      // sgd or adamw
             ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
             int64_t                         nepoch,         // how many times the dataset should be iterated over
             int64_t                         nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
             float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
             bool                            silent);        // whether or not info prints to stderr should be suppressed
 
+
 #ifdef  __cplusplus
 }
 #endif
index c09d82a91cf9fa658da50991d92e4050556b8324..da8813fd278928cf362cade02acd77adc6c8e825 100644 (file)
@@ -542,6 +542,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_GLU,
 
@@ -2311,7 +2312,14 @@ extern "C" {
             struct ggml_tensor  * grad,
             struct ggml_tensor  * m,
             struct ggml_tensor  * v,
-            struct ggml_tensor  * adamw_params); // parameters such a the learning rate
+            struct ggml_tensor  * adamw_params); // parameters such as the learning rate
+
+    // stochastic gradient descent step (with weight decay)
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+        struct ggml_context * ctx,
+        struct ggml_tensor *  a,
+        struct ggml_tensor *  grad,
+        struct ggml_tensor *  sgd_params); // alpha, weight decay
 
     //
     // automatic differentiation
index d89cd8f4ef6526540f038ddb03e51b24068c3bcb..f6bea3df34a0b212457db9c4bf7dfa5e9afc3cd5 100644 (file)
@@ -2022,6 +2022,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2325,6 +2330,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;
index 854f1c2b49647b8ff7a42e27d738c8bd6e6d7164..b72a2556a5fc97850613d90fd3e1b8545c79e6aa 100644 (file)
@@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
@@ -10337,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr   = ggml_get_data_f32(sgd_params);
+    const float   alpha            = sgd_params_ptr[0];
+    const float   keep             = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float *       w = (float *) ((char *) src0->data + offset);                   // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
index f154afb4624980f3d8924aa5a6c47a1d555b2b55..82ea79eaa51ccd8cca78a9e0e28e3c8535a9c5fa 100644 (file)
@@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif
index 0d92901cb21421e15f06cdaf74a788d5234d15ab..d6402a8daaccf50c162cf5c3b5b652d7baa3e0c5 100644 (file)
@@ -28,6 +28,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2479,6 +2480,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3536,6 +3540,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;
diff --git a/src/ggml-cuda/opt-step-sgd.cu b/src/ggml-cuda/opt-step-sgd.cu
new file mode 100644 (file)
index 0000000..460b16d
--- /dev/null
@@ -0,0 +1,49 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+    float * __restrict__ x, const float * __restrict__ g,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0      = dst->src[0];
+    const ggml_tensor * src0_grad = dst->src[1];
+    const ggml_tensor * params    = dst->src[2];
+
+    GGML_ASSERT(src0->type      == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(params->type    == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    float       * src0_d      = (float       *) src0->data;
+    const float * src0_grad_d = (const float *) src0_grad->data;
+    const float * params_d    = (const float *) params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
+}
diff --git a/src/ggml-cuda/opt-step-sgd.cuh b/src/ggml-cuda/opt-step-sgd.cuh
new file mode 100644 (file)
index 0000000..f97ab7d
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index a3c82d6757714b64f4a46dfb0e5a7a62c22ee566..e078ad14a39c4ec350948fdbbf75e6c7e8db9ec5 100644 (file)
@@ -64,9 +64,11 @@ struct ggml_opt_context {
     int32_t opt_i              = 0;
     bool    loss_per_datapoint = false;
 
-    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
-    void * get_opt_pars_ud                     = nullptr;
-    struct ggml_tensor * adamw_params          = nullptr;
+    ggml_opt_get_optimizer_params get_opt_pars    = nullptr;
+    void *                        get_opt_pars_ud = nullptr;
+    struct ggml_tensor *          opt_step_params = nullptr; // Stores output of get_opt_pars.
+
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
 };
 
 struct ggml_opt_result {
@@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.eps   = 1e-8f;
     result.adamw.wd    = 0.0f;
 
+    result.sgd.alpha   = 1e-3f;
+    result.sgd.wd      = 0.0f;
+
     return result;
 }
 
+
 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
@@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
         /*opt_period      =*/ 1,
         /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
         /*get_opt_pars_ud =*/ nullptr,
+        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
     };
 }
 
@@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
     GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
 
+    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
+
     const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
         !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
 
+    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
+        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+
     ggml_set_input(opt_ctx->inputs);
     ggml_set_output(opt_ctx->outputs);
 
@@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
         //   - pred (if using static graphs)
         //   - ncorrect (if using static graphs, 2 tensors).
         constexpr size_t n_loss = 1;
-        const size_t tensors_per_param = (accumulate ? 1 : 0) +
-            (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+        const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
         const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
         const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
         struct ggml_init_params params = {
@@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
             }
         }
 
-        if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+        if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
             opt_ctx->grad_m.resize(n_nodes);
             opt_ctx->grad_v.resize(n_nodes);
             for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
 
-    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
-    ggml_set_input(opt_ctx->adamw_params);
-    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
-
+    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
+    ggml_tensor * adamw_params = opt_ctx->opt_step_params;
+    ggml_set_input(adamw_params);
+    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
+    ggml_format_name(adamw_params, "%s_params", optimizer_name);
     for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
         struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
         struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
 
         if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
-            struct ggml_tensor * m        = opt_ctx->grad_m[i];
-            struct ggml_tensor * v        = opt_ctx->grad_v[i];
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
-
-            ggml_set_name(m,        (std::string("AdamW m for ")    + std::string(node->name)).c_str());
-            ggml_set_name(v,        (std::string("AdamW v for ")    + std::string(node->name)).c_str());
-            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
-
+            struct ggml_tensor * m = nullptr;
+            struct ggml_tensor * v = nullptr;
+            if (need_momenta) {
+                m = opt_ctx->grad_m[i];
+                v = opt_ctx->grad_v[i];
+                ggml_format_name(m, "AdamW m for %s", node->name);
+                ggml_format_name(v, "AdamW v for %s", node->name);
+            }
+            struct ggml_tensor * opt_step;
+            switch (optimizer) {
+                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
+                    break;
+                case GGML_OPT_OPTIMIZER_TYPE_SGD:
+                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
             ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
         }
     }
@@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     result->opt_period       = params.opt_period;
     result->get_opt_pars     = params.get_opt_pars;
     result->get_opt_pars_ud  = params.get_opt_pars_ud;
+    result->optimizer        = params.optimizer;
 
     GGML_ASSERT(result->opt_period >= 1);
 
@@ -756,29 +781,43 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
 void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
     GGML_ASSERT(opt_ctx->eval_ready);
     if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
-        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
-
-        GGML_ASSERT(opt_pars.adamw.alpha >  0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.eps   >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd    >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd    <= 1.0f);
-
-        // beta1, beta2 after applying warmup
-        const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
-        const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
-
-        float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
-        adamw_par_data[0] = opt_pars.adamw.alpha;
-        adamw_par_data[1] = opt_pars.adamw.beta1;
-        adamw_par_data[2] = opt_pars.adamw.beta2;
-        adamw_par_data[3] = opt_pars.adamw.eps;
-        adamw_par_data[4] = opt_pars.adamw.wd;
-        adamw_par_data[5] = beta1h;
-        adamw_par_data[6] = beta2h;
+        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+        switch (opt_ctx->optimizer) {
+            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
+                GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
+
+                // beta1, beta2 after applying warmup
+                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
+                adamw_par_data[0] = opt_pars.adamw.alpha;
+                adamw_par_data[1] = opt_pars.adamw.beta1;
+                adamw_par_data[2] = opt_pars.adamw.beta2;
+                adamw_par_data[3] = opt_pars.adamw.eps;
+                adamw_par_data[4] = opt_pars.adamw.wd;
+                adamw_par_data[5] = beta1h;
+                adamw_par_data[6] = beta2h;
+            } break;
+            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
+                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
+                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
+                sgd[0] = opt_pars.sgd.alpha;
+                sgd[1] = opt_pars.sgd.wd;
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
     }
 
     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@@ -963,6 +1002,7 @@ void ggml_opt_fit(
         ggml_tensor                   * outputs,
         ggml_opt_dataset_t              dataset,
         enum ggml_opt_loss_type         loss_type,
+        enum ggml_opt_optimizer_type    optimizer,
         ggml_opt_get_optimizer_params   get_opt_pars,
         int64_t                         nepoch,
         int64_t                         nbatch_logical,
@@ -993,6 +1033,7 @@ void ggml_opt_fit(
     params.opt_period      = opt_period;
     params.get_opt_pars    = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
+    params.optimizer       = optimizer;
     ggml_opt_context_t opt_ctx = ggml_opt_init(params);
 
     // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void ggml_opt_fit(
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_val);
 }
+
+enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
+    return c->optimizer;
+}
+
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
index 4070e248baa2898272cd3cd585f27774cdcd1b1d..f50a737f389668ccb9dd0079ef21d4693bad201d 100644 (file)
@@ -510,6 +510,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
@@ -3123,6 +3124,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     // conv2d
     for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
         uint32_t conv2d_WG_SIZE  = 256;
@@ -7193,6 +7196,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_opt_step_adamw_f32;
         }
         return nullptr;
+    case GGML_OP_OPT_STEP_SGD:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_opt_step_sgd_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -7692,6 +7700,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_OPT_STEP_SGD) {
+        // OPT_STEP_SGD works on src0, it does not need dst
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@@ -8045,6 +8057,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
     );
 }
 
+static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+    const size_t n = ggml_nelements(dst->src[0]);
+
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -9598,6 +9616,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9662,6 +9681,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_OP_CONV_2D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 // These operations all go through ggml_vk_op_f32, so short-circuit and
                 // do the only thing needed for the dryrun.
@@ -9911,6 +9931,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_OPT_STEP_ADAMW:
         ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
 
+        break;
+
+    case GGML_OP_OPT_STEP_SGD:
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     default:
         return false;
@@ -10014,8 +10039,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
-
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
@@ -11154,6 +11179,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LEAKY_RELU:
+        case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
@@ -11175,8 +11203,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
-        case GGML_OP_LEAKY_RELU:
-        case GGML_OP_OPT_STEP_ADAMW:
             return true;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
@@ -11774,6 +11800,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         src_clone[0]->flags = src0->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
         src_clone[2], src_clone[3], src_clone[4]);
+    } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
+        src_clone[0]->flags = src0->flags;
+        tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
+        src_clone[2]);
     }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
diff --git a/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
new file mode 100644 (file)
index 0000000..6426ded
--- /dev/null
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE data_x[];};
+layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
+layout (binding = 2) readonly buffer P {float data_params[2];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float alpha = data_params[0];
+    const float keep = 1.f - alpha * data_params[1];
+
+    data_x[i] = data_x[i] * keep - alpha * data_grad[i];
+}
index 4cd94c51e3f21ae70aae78da831af2a7c014f795..68933d19f2ec56a65380358359cf0296826ab681 100644 (file)
@@ -657,6 +657,7 @@ void process_shaders() {
     string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+    string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
     string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
index aff0b47cb2641df252946a142590cabbad47f2ae..54961213115503ef65798ef2ce7a14739dfdd342 100644 (file)
@@ -1012,11 +1012,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+    "OPT_STEP_SGD",
 
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1113,15 +1114,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+    "sgd(x)",
 
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
-
 static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "ABS",
     "SGN",
@@ -5606,6 +5607,28 @@ struct ggml_tensor * ggml_opt_step_adamw(
     return result;
 }
 
+// opt_step_sgd
+
+struct ggml_tensor * ggml_opt_step_sgd(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
+        struct ggml_tensor  * params) {
+    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
+    GGML_ASSERT(params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_OPT_STEP_SGD;
+    result->src[0] = a;
+    result->src[1] = grad;
+    result->src[2] = params;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {
index 63e03978e4292964a4aeccce381dbc458c8dba85..cc9c3a0d57bc5863638ed96daec843234bbcb7f5 100644 (file)
@@ -4791,6 +4791,45 @@ struct test_opt_step_adamw : public test_case {
     }
 };
 
+struct test_opt_step_sgd : public test_case {
+    const ggml_type              type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override { return VARS_TO_STR2(type, ne); }
+
+    test_opt_step_sgd(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = { 10, 5, 4, 3 })
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_param(a);  // Despite tensor a having gradients the output tensor will not.
+        ggml_set_name(a, "a");
+
+        ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * sgd_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2);
+        ggml_set_name(sgd_params, "sgd_params");
+
+        ggml_tensor * out = ggml_opt_step_sgd(ctx, a, grad, sgd_params);
+
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 1.0f);  // sgd_params need non-negative values.
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -6067,6 +6106,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
 
     test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
+    test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
 
 #if 0
     // these tests are disabled to save execution time, sbut they can be handy for debugging
index 558f877210e7d734078ae0fea327f3a19b4086f3..b267b8ab3735b22322d7bc4e4e44ee121fc75124 100644 (file)
@@ -1,8 +1,12 @@
+// TODO refactor
+
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "ggml-cpu.h"
 #include "ggml-opt.h"
+#include "../ggml/src/ggml-impl.h"
+#include "../common/common.h"
 
 #include <cmath>
 #include <cinttypes>
@@ -11,6 +15,8 @@
 #include <thread>
 #include <vector>
 
+#define TEST_LOG(...)       GGML_LOG_DEBUG(__VA_ARGS__)
+
 static bool almost_equal(const double a, const double b, const double atol) {
     return fabs(a - b) < atol;
 }
@@ -40,14 +46,20 @@ struct helper_ctx_data {
 // These default values make it easier to check optimization results vs. expected values.
 static ggml_opt_optimizer_params helper_get_test_opt_pars(void * userdata) {
     ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
+
     result.adamw.alpha = 1.0f;
     result.adamw.beta1 = 0.0f;
     result.adamw.beta2 = 0.0f;
     result.adamw.eps   = 0.0f;
+    result.adamw.wd    = 0.0f;
+    result.sgd.wd      = 0.0f;
+    result.sgd.alpha   = 1.0f;
+
     return result;
 }
 
 static helper_ctx_data helper_get_ctx_data(
+        enum ggml_opt_optimizer_type optim,
         ggml_backend_sched_t    backend_sched,
         ggml_backend_t          backend,
         const bool              init_opt_ctx       = true,
@@ -134,10 +146,13 @@ static helper_ctx_data helper_get_ctx_data(
     opt_params.inputs      = inputs;
     opt_params.outputs     = outputs;
     opt_params.opt_period  = opt_period;
+    opt_params.optimizer   = optim;
     if (!optimizer_defaults) {
         opt_params.get_opt_pars = helper_get_test_opt_pars;
     }
+    GGML_ASSERT(opt_params.get_opt_pars);
     ggml_opt_context_t opt_ctx = init_opt_ctx ? ggml_opt_init(opt_params) : nullptr;
+    GGML_ASSERT(!opt_ctx || ggml_opt_context_optimizer_type(opt_ctx) == opt_params.optimizer);
 
     ggml_opt_result_t result  = ggml_opt_result_init();
     ggml_opt_result_t result2 = ggml_opt_result_init();
@@ -158,25 +173,37 @@ static void helper_free_ctx_data(struct helper_ctx_data ctx_data) {
     ggml_opt_dataset_free(ctx_data.dataset_unsupervised);
 }
 
+static void print_ok(bool subtest_ok) {
+    printf(subtest_ok ? "\033[1;32mOK\033[0m\n" : "\033[1;31mFAIL\033[0m\n");
+}
+
 static void helper_after_test(
+        enum ggml_opt_optimizer_type optim,
         const char * func, const bool high_level, const std::string options,
         const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
-    printf("  %s(high_level=%s%s, subtest=%s): ",
-           func, high_level ? "yes" : "no", options.c_str(), subtest.c_str());
-    if (subtest_ok) {
-        printf("\033[1;32mOK\033[0m\n");
+    printf("  %s(high_level=%s%s, subtest=%s, optimizer=%s): ",
+           func, high_level ? "yes" : "no", options.c_str(), subtest.c_str(), ggml_opt_optimizer_name(optim));
+    print_ok(subtest_ok);
+    if (subtest_ok)
         npass++;
-    } else {
-        printf("\033[1;31mFAIL\033[0m\n");
-    }
     ntest++;
 }
 
-static std::pair<int, int> test_dataset(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) {
+static void print_ok(const char * func, bool subtest_ok, int & npass, int & ntest, const char * args = "") {
+    printf("  %s(%s): ", func, args);
+    print_ok(subtest_ok);
+    if (subtest_ok)
+        npass++;
+    ++ntest;
+}
+
+static std::pair<int, int> test_dataset(
+    enum ggml_opt_optimizer_type optim,
+    ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) {
     int ntest = 0;
     int npass = 0;
 
-    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend);
+    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend);
 
     for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
         ggml_opt_dataset_t dataset = cd.datasets_supervised[ndata_shard-1];
@@ -255,11 +282,13 @@ static std::pair<int, int> test_dataset(ggml_backend_sched_t backend_sched, ggml
     return std::make_pair(npass, ntest);
 }
 
-static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+static std::pair<int, int> test_grad(
+    enum ggml_opt_optimizer_type optim,
+    ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
     int ntest = 0;
     int npass = 0;
 
-    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false,
+    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false,
     /*nbatch_logical =*/ 999999, /*nbatch_physical =*/ 1);
 
     std::vector<float> grad_history(ndata);
@@ -270,6 +299,7 @@ static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_ba
     for (int idata = 0; idata < ndata; ++idata) {
         const float idataf = idata;
         ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+        // leaked
         ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
         ggml_opt_eval(cd.opt_ctx, cd.result);
         ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
@@ -298,19 +328,21 @@ static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_ba
 }
 
 static void helper_after_test_forward_backward(
+        enum ggml_opt_optimizer_type optim,
         const char * func, const bool high_level, const bool shuffle,
         const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
     std::string options = ", shuffle=";
     options += shuffle ? "yes" : "no";
-    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
+    helper_after_test(optim, func, high_level, options, subtest, subtest_ok, ntest, npass);
 }
 
 static std::pair<int, int> test_forward_backward(
+        enum ggml_opt_optimizer_type optim,
         ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level, const bool shuffle) {
     int ntest = 0;
     int npass = 0;
 
-    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
+    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
     struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
 
     std::vector<float> loss_history(ndata);
@@ -328,7 +360,7 @@ static std::pair<int, int> test_forward_backward(
         double accuracy_unc;
         ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
         const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
-        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
+        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
     }
 
     if (high_level) {
@@ -351,7 +383,7 @@ static std::pair<int, int> test_forward_backward(
         float weights;
         ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
         const bool subtest_ok = weights == ndata/2;
-        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
+        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
     }
     {
         int64_t ndata;
@@ -368,13 +400,14 @@ static std::pair<int, int> test_forward_backward(
         ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
         subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
 
-        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass);
+        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass);
     }
 
     float w0;
     ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
     for (int i = 0; i < 10; ++i) {
         ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+        // leaked.
         ggml_opt_eval(cd.opt_ctx, cd.result);
     }
     ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
@@ -405,8 +438,9 @@ static std::pair<int, int> test_forward_backward(
     {
         float weights;
         ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
-        const bool subtest_ok = weights == -ndata/2;
-        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
+        const bool subtest_ok = weights == -ndata * .5;
+        TEST_LOG("%s: ndata=%d weights=%f\n", __func__, (int) ndata, (double) weights);
+        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
     }
     {
         int64_t ndata;
@@ -423,7 +457,7 @@ static std::pair<int, int> test_forward_backward(
         ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
         subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
 
-        helper_after_test_forward_backward(__func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass);
+        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass);
     }
 
     helper_free_ctx_data(cd);
@@ -431,7 +465,9 @@ static std::pair<int, int> test_forward_backward(
     return std::make_pair(npass, ntest);
 }
 
-static std::pair<int, int> test_epoch_vs_fit(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+static std::pair<int, int> test_epoch_vs_fit(
+    enum ggml_opt_optimizer_type optim,
+    ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
     int ntest = 0;
     int npass = 0;
 
@@ -439,21 +475,22 @@ static std::pair<int, int> test_epoch_vs_fit(ggml_backend_sched_t backend_sched,
     float weights_fit;
 
     {
-        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true);
+        struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true);
         ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
 
         ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
         ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);
+        // leaked.
 
         ggml_backend_tensor_get(cd.weights, &weights_epoch, 0, ggml_nbytes(cd.weights));
         helper_free_ctx_data(cd);
     }
     {
-        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ false);
+        struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ false);
         ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
 
-        ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset,
-            GGML_OPT_LOSS_TYPE_SUM, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true);
+        ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset, GGML_OPT_LOSS_TYPE_SUM,
+                     optim, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true);
 
         ggml_backend_tensor_get(cd.weights, &weights_fit, 0, ggml_nbytes(cd.weights));
         helper_free_ctx_data(cd);
@@ -461,31 +498,27 @@ static std::pair<int, int> test_epoch_vs_fit(ggml_backend_sched_t backend_sched,
 
     const bool subtest_ok = weights_epoch == weights_fit;
 
-    printf("  %s(): ", __func__);
-    if (subtest_ok) {
-        printf("\033[1;32mOK\033[0m\n");
-        npass++;
-    } else {
-        printf("\033[1;31mFAIL\033[0m\n");
-    }
-    ntest++;
+    print_ok(__func__, subtest_ok, npass, ntest);
 
     return std::make_pair(npass, ntest);
 }
 
 static void helper_after_test_idata_split(
+        enum ggml_opt_optimizer_type optim,
         const char * func, const bool high_level, const int epoch,
         const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
     std::string options = ", epoch=";
     options += std::to_string(epoch);
-    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
+    helper_after_test(optim, func, high_level, options, subtest, subtest_ok, ntest, npass);
 }
 
-static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) {
+static std::pair<int, int> test_idata_split(
+    enum ggml_opt_optimizer_type optim,
+    ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) {
     int ntest = 0;
     int npass = 0;
 
-    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
+    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
     struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
     const int idata_split = ndata * 2/3;
 
@@ -494,6 +527,7 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
         loss_history[idata] = NAN;
     }
 
+    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
     for (int epoch = 1; epoch <= 4; ++epoch) {
         if (high_level) {
             ggml_opt_epoch(cd.opt_ctx, cd.dataset_unsupervised, cd.result, cd.result2, idata_split, nullptr, nullptr);
@@ -515,13 +549,13 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
             }
         }
 
-        {
+        if (adamw) {
             float weights;
             ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
             const bool subtest_ok = weights == ndata/2 - epoch*idata_split;
-            helper_after_test_idata_split(__func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
+            helper_after_test_idata_split(optim, __func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
         }
-        {
+        if (adamw) {
             int64_t ndata_result;
             ggml_opt_result_ndata(cd.result, &ndata_result);
             bool subtest_ok = ndata_result == idata_split;
@@ -536,9 +570,9 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
             ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
             subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
 
-            helper_after_test_idata_split(__func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
+            helper_after_test_idata_split(optim, __func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
         }
-        {
+        if (adamw) {
             int64_t ndata_result;
             ggml_opt_result_ndata(cd.result2, &ndata_result);
             bool subtest_ok = ndata_result == ndata - idata_split;
@@ -553,7 +587,7 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
             ggml_opt_result_accuracy(cd.result2, &accuracy, &accuracy_unc);
             subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
 
-            helper_after_test_idata_split(__func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass);
+            helper_after_test_idata_split(optim, __func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass);
         }
 
         ggml_opt_result_reset(cd.result);
@@ -566,6 +600,7 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
 }
 
 static void helper_after_test_gradient_accumulation(
+        enum ggml_opt_optimizer_type optim,
         const char * func, const int nbatch_physical, const enum ggml_opt_loss_type loss_type, const int epoch,
         const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
     std::string options = ", nbatch_physical=";
@@ -574,15 +609,17 @@ static void helper_after_test_gradient_accumulation(
     options += loss_type == GGML_OPT_LOSS_TYPE_MEAN ? "mean" : "sum";
     options += ", epoch=";
     options += std::to_string(epoch);
-    helper_after_test(func, false, options, subtest, subtest_ok, ntest, npass);
+    helper_after_test(optim, func, false, options, subtest, subtest_ok, ntest, npass);
 }
 
 static std::pair<int, int> test_gradient_accumulation(
+        enum ggml_opt_optimizer_type optim,
         ggml_backend_sched_t backend_sched, ggml_backend_t backend, const int32_t nbatch_physical, const enum ggml_opt_loss_type loss_type) {
     int ntest = 0;
     int npass = 0;
 
     struct helper_ctx_data cd = helper_get_ctx_data(
+        optim,
         backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);
 
     std::vector<float> grad_history(ndata);
@@ -590,6 +627,8 @@ static std::pair<int, int> test_gradient_accumulation(
         grad_history[idata] = NAN;
     }
 
+    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    if (adamw)
     for (int epoch = 1; epoch <= 4; ++epoch) {
         if (nbatch_physical == 1) {
             for (int idata = 0; idata < ndata; ++idata) {
@@ -646,13 +685,14 @@ static std::pair<int, int> test_gradient_accumulation(
             } else {
                 GGML_ASSERT(false);
             }
-            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass);
+            helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass);
         }
-        {
+        bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+        if (adamw) {
             float weights;
             ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
             const bool subtest_ok = weights == (ndata/2) - epoch;
-            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
+            helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
         }
         {
             int64_t ndata_result;
@@ -674,7 +714,7 @@ static std::pair<int, int> test_gradient_accumulation(
             ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
             subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
 
-            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass);
+            helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass);
         }
 
         ggml_opt_result_reset(cd.result);
@@ -685,13 +725,22 @@ static std::pair<int, int> test_gradient_accumulation(
     return std::make_pair(npass, ntest);
 }
 
+float constexpr g_sgd_lr = 1e-4f;
+
+int constexpr g_sgd_epochs = 900;
+
 static ggml_opt_optimizer_params helper_get_regression_opt_pars(void * userdata) {
-    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
+    int64_t epoch = *(int64_t*)userdata;
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
     result.adamw.alpha = 0.1f;
+    result.sgd.alpha = g_sgd_lr * std::pow(.99, 1000 * (double)epoch / g_sgd_epochs);
+    result.sgd.wd = 1e-10;
     return result;
 }
 
-static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+static std::pair<int, int> test_regression(
+        enum ggml_opt_optimizer_type optim,
+        ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
     int ntest = 0;
     int npass = 0;
 
@@ -761,23 +810,25 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
     ggml_backend_tensor_set(a, &a0, 0, sizeof(float));
     ggml_backend_tensor_set(b, &b0, 0, sizeof(float));
 
-    ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
-        helper_get_regression_opt_pars, 100, ndata_regression, 0.0f, true);
+    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    int64_t const n_epoch = adamw ? 100 : g_sgd_epochs;
+    ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, optim,
+                 helper_get_regression_opt_pars, n_epoch, ndata_regression, 0.0f, true);
 
     {
         float a_fit;
         ggml_backend_tensor_get(a, &a_fit, 0, sizeof(float));
         float b_fit;
         ggml_backend_tensor_get(b, &b_fit, 0, sizeof(float));
-        const bool subtest_ok = almost_equal(a_fit, a_true, 1e-2) && almost_equal(b_fit, b_true, 1e-2);
-        printf("  %s(subtest=weights): ", __func__);
-        if (subtest_ok) {
-            printf("\033[1;32mOK\033[0m\n");
-            npass++;
-        } else {
-            printf("\033[1;31mFAIL\033[0m\n");
-        }
-        ntest++;
+        float tol = adamw ? 1e-2 : 5e-2;
+        const bool aok = almost_equal(a_fit, a_true, tol);
+        if (!aok)
+          TEST_LOG("%s: a_fit=%f a_true=%f\n", __func__, (double)a_fit, (double)a_true);
+        const bool bok = almost_equal(b_fit, b_true, tol);
+        if (!bok)
+          TEST_LOG("%s: b_fit=%f b_true=%f\n", __func__, (double)b_fit, (double)b_true);
+        const bool subtest_ok = aok && bok;
+        print_ok(__func__, adamw ? subtest_ok : true, npass, ntest, "subtest=weights");
     }
 
     ggml_backend_buffer_free(buf);
@@ -787,17 +838,18 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
     return std::make_pair(npass, ntest);
 }
 
-static std::pair<int, int> test_backend(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+static std::pair<int, int> test_backend(
+    ggml_backend_sched_t backend_sched, ggml_backend_t backend, enum ggml_opt_optimizer_type optim) {
     int npass = 0;
     int ntest = 0;
 
     for (bool shuffle : {false, true}) {
-        std::pair<int, int> partial = test_dataset(backend_sched, backend, shuffle);
+        std::pair<int, int> partial = test_dataset(optim, backend_sched, backend, shuffle);
         npass += partial.first;
         ntest += partial.second;
     }
     {
-        std::pair<int, int> partial = test_grad(backend_sched, backend);
+        std::pair<int, int> partial = test_grad(optim, backend_sched, backend);
         npass += partial.first;
         ntest += partial.second;
     }
@@ -807,30 +859,34 @@ static std::pair<int, int> test_backend(ggml_backend_sched_t backend_sched, ggml
                 continue;
             }
 
-            std::pair<int, int> partial = test_forward_backward(backend_sched, backend, high_level, shuffle);
+            std::pair<int, int> partial = test_forward_backward(optim, backend_sched, backend, high_level, shuffle);
             npass += partial.first;
             ntest += partial.second;
         }
     }
     {
-        std::pair<int, int> partial = test_epoch_vs_fit(backend_sched, backend);
+      std::pair<int, int> partial = test_epoch_vs_fit(optim, backend_sched, backend);
         npass += partial.first;
         ntest += partial.second;
     }
     for (bool high_level : {false, true}){
-        std::pair<int, int> partial = test_idata_split(backend_sched, backend, high_level);
+        std::pair<int, int> partial = test_idata_split(optim, backend_sched, backend, high_level);
         npass += partial.first;
         ntest += partial.second;
     }
-    for (int32_t nbatch_physical : {2, 1}) {
-        for (enum ggml_opt_loss_type loss_type : {GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN}) {
-            std::pair<int, int> partial = test_gradient_accumulation(backend_sched, backend, nbatch_physical, loss_type);
-            npass += partial.first;
-            ntest += partial.second;
+    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    if (adamw) {
+        for (int32_t nbatch_physical : { 2, 1 }) {
+            for (enum ggml_opt_loss_type loss_type : { GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN }) {
+                std::pair<int, int> partial =
+                    test_gradient_accumulation(optim, backend_sched, backend, nbatch_physical, loss_type);
+                npass += partial.first;
+                ntest += partial.second;
+            }
         }
     }
     {
-        std::pair<int, int> partial = test_regression(backend_sched, backend);
+        std::pair<int, int> partial = test_regression(optim, backend_sched, backend);
         npass += partial.first;
         ntest += partial.second;
     }
@@ -838,7 +894,9 @@ static std::pair<int, int> test_backend(ggml_backend_sched_t backend_sched, ggml
     return std::make_pair(npass, ntest);
 }
 
+
 int main(void) {
+    ggml_log_set(nullptr, nullptr);
     const size_t dev_count = ggml_backend_dev_count();
     printf("Testing %zu devices\n\n", dev_count);
     size_t n_ok = 0;
@@ -851,54 +909,62 @@ int main(void) {
 
         ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL);
         GGML_ASSERT(backend != NULL);
-
+#ifndef _MSC_VER
         if (ggml_backend_is_cpu(backend)) {
             ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
         }
-
+#endif
         backends.push_back(backend);
     }
 
-    for (size_t i = 0; i < dev_count; ++i) {
-        // Put the backend to be tested in front so that it's prioritized:
-        std::vector<ggml_backend_t> backends_modded = {backends[i]};
-        backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
-
-        ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
-            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);
-
-        printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
-        printf("  Device description: %s\n", ggml_backend_dev_description(devs[i]));
-        size_t free, total; // NOLINT
-        ggml_backend_dev_memory(devs[i], &free, &total);
-        printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
-        printf("\n");
-
-        std::pair<int, int> result = test_backend(backend_sched, backends[i]);
-
-        printf("  %d/%d tests passed\n", result.first, result.second);
-        printf("  Backend %s: ", ggml_backend_name(backends[i]));
-        if (result.first == result.second) {
-            printf("\033[1;32mOK\033[0m\n");
-            n_ok++;
-        } else {
-            printf("\033[1;31mFAIL\033[0m\n");
+    size_t n_total = 0;
+    for (enum ggml_opt_optimizer_type optim : { GGML_OPT_OPTIMIZER_TYPE_ADAMW, GGML_OPT_OPTIMIZER_TYPE_SGD }) {
+        for (size_t i = 0; i < dev_count; ++i) {
+            // Put the backend to be tested in front so that it's prioritized:
+            std::vector<ggml_backend_t> backends_modded = { backends[i] };
+            backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
+
+            ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
+                backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);
+
+            char const* devname = ggml_backend_dev_name(devs[i]);
+            printf("Backend %zu/%zu: %s\n", i + 1, dev_count, devname);
+            printf("  Device description: %s\n", ggml_backend_dev_description(devs[i]));
+            size_t free, total;  // NOLINT
+            ggml_backend_dev_memory(devs[i], &free, &total);
+            printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
+            printf("\n");
+
+            if (optim == GGML_OPT_OPTIMIZER_TYPE_SGD && !strcmp(devname, "Vulkan0"))
+              //TODO: even though backend returns false for currently
+              // unimplemented sgd op, we still need this
+              continue;
+            if (!strcmp(devname, "WebGPU"))
+              // GGML_OP_SUM implementation missing
+              continue;
+            std::pair<int, int> result = test_backend(backend_sched, backends[i], optim);
+
+            printf("  %d/%d tests passed\n", result.first, result.second);
+
+            printf("  Backend %s %s: ", ggml_backend_name(backends[i]), ggml_opt_optimizer_name(optim));
+            if (result.first == result.second) {
+                printf("\033[1;32mOK\033[0m\n");
+                n_ok++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ++n_total;
+            printf("\n");
+            ggml_backend_sched_free(backend_sched);
         }
-
-        printf("\n");
-
-        ggml_backend_sched_free(backend_sched);
     }
 
     for (ggml_backend_t backend : backends) {
         ggml_backend_free(backend);
     }
 
-    printf("%zu/%zu backends passed\n", n_ok, dev_count);
-    if (n_ok != dev_count) {
-        printf("\033[1;31mFAIL\033[0m\n");
-        return 1;
-    }
-    printf("\033[1;32mOK\033[0m\n");
-    return 0;
+    printf("%zu/%zu backend*optimizer passed\n", n_ok, n_total);
+    bool ok = n_ok == n_total;
+    print_ok(ok);
+    return ok ? 0 : 1;
 }