]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
tests: add gradient tests for all backends (#932)
authorJohannes Gäßler <redacted>
Tue, 3 Sep 2024 15:21:46 +0000 (17:21 +0200)
committerGitHub <redacted>
Tue, 3 Sep 2024 15:21:46 +0000 (17:21 +0200)
* tests: add gradient checking to test-backend-ops

* remove old comment

* reorder includes

* adjust SIN/COS parameters

* add documentation, use supports_op if possible

include/ggml.h
src/ggml-backend.c
src/ggml-cuda.cu
src/ggml-cuda/cross-entropy-loss.cu
src/ggml-cuda/sum.cu [new file with mode: 0644]
src/ggml-cuda/sum.cuh [new file with mode: 0644]
src/ggml-cuda/unary.cu
src/ggml-cuda/unary.cuh
src/ggml.c
tests/test-backend-ops.cpp

index 2d381f91c889bcdaaef05120b554e05d5616ab86..59fa80edb60bb8d8177f0404aa87cb28544e1347 100644 (file)
@@ -1234,7 +1234,7 @@ extern "C" {
             size_t                nb1,
             size_t                nb2,
             size_t                nb3,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     // b -> view(a,offset,nb1,nb2,3), return view(a)
     GGML_API struct ggml_tensor * ggml_set_inplace(
@@ -1244,19 +1244,19 @@ extern "C" {
             size_t                nb1,
             size_t                nb2,
             size_t                nb3,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     GGML_API struct ggml_tensor * ggml_set_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     GGML_API struct ggml_tensor * ggml_set_1d_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set_2d(
@@ -1264,7 +1264,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             size_t                nb1,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     // b -> view(a,offset,nb1,nb2,3), return view(a)
     GGML_API struct ggml_tensor * ggml_set_2d_inplace(
@@ -1272,7 +1272,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             size_t                nb1,
-            size_t                offset);
+            size_t                offset); // in bytes
 
     // a -> b, return view(b)
     GGML_API struct ggml_tensor * ggml_cpy(
index 8856967c911042428d4aac20c8c73eae0450c454..6ba5c0889f083c9ff5d3bf792fa808ff67d3402c 100644 (file)
@@ -825,6 +825,10 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+        case GGML_OP_ROPE_BACK:
+            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+        case GGML_OP_IM2COL_BACK:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         default:
             return true;
     }
index 0bc08d3e3cdbf5c23d933b045fe5ee67d252c827..dcb53224a00c7d5e72c775b125910c8540cf1e8a 100644 (file)
@@ -27,6 +27,7 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
@@ -2180,6 +2181,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
+        case GGML_OP_ADD1: // TODO: more efficient implementation
             ggml_cuda_op_add(ctx, dst);
             break;
         case GGML_OP_SUB:
@@ -2196,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(dst)) {
+                case GGML_UNARY_OP_NEG:
+                    ggml_cuda_op_neg(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_GELU:
                     ggml_cuda_op_gelu(ctx, dst);
                     break;
@@ -2304,6 +2309,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_POOL_2D:
             ggml_cuda_op_pool2d(ctx, dst);
             break;
+        case GGML_OP_SUM:
+            ggml_cuda_op_sum(ctx, dst);
+            break;
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
@@ -2741,6 +2749,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -2867,6 +2876,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
@@ -2886,7 +2896,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ROPE:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
+            return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_2D:
+        case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
index a14043e70451a0595ece1e34495f5bd7eef152d3..5575a90f643266bf5eb233596b731e2a938cc1a0 100644 (file)
@@ -1,6 +1,6 @@
 #include "common.cuh"
 #include "cross-entropy-loss.cuh"
-#include "sumrows.cuh"
+#include "sum.cuh"
 
 #include <cmath>
 #include <cstdint>
@@ -102,5 +102,5 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
 
     // Combine results from individual blocks:
-    sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
+    sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
diff --git a/src/ggml-cuda/sum.cu b/src/ggml-cuda/sum.cu
new file mode 100644 (file)
index 0000000..0d5e953
--- /dev/null
@@ -0,0 +1,41 @@
+#include "sumrows.cuh"
+#include "sum.cuh"
+
+#include <cstdint>
+
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+    size_t tmp_size = 0;
+    DeviceReduce::Sum(nullptr,       tmp_size, x, dst, ne, stream);
+    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+    DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
+#else
+    // Use (inefficient) sum_rows implementation as a fallback.
+    // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
+    sum_rows_f32_cuda(x, dst, ne, 1, stream);
+    GGML_UNUSED(pool);
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+}
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+
+    const int64_t ne = ggml_nelements(src0);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t stream = ctx.stream();
+
+    sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
+}
diff --git a/src/ggml-cuda/sum.cuh b/src/ggml-cuda/sum.cuh
new file mode 100644 (file)
index 0000000..8cadc37
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 89abfc21d8a56c85bb191fd75270cf713a4a7b64..8ac669f94e2de23ed18a999041cab0861f8c5480 100644 (file)
@@ -1,5 +1,15 @@
 #include "unary.cuh"
 
+static __global__ void neg_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = -x[i];
+}
+
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -119,6 +129,11 @@ static __global__ void cos_f32(const float * x, float * dst, const int k) {
     dst[i] = cosf(x[i]);
 }
 
+static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+    neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -184,6 +199,20 @@ static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
index c610e996abeb62a2fdcc2ea424774513f7e5c443..ed2ffc461e8102caafcf1c2d48fa82892a8eafc4 100644 (file)
@@ -1,5 +1,6 @@
 #include "common.cuh"
 
+#define CUDA_NEG_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
@@ -12,6 +13,8 @@
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
 
+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 1e0dde6e8bb64db933640e930e13f05720bfc4e8..0c820015fed9f34d40713f4caef2a04e06966aec 100644 (file)
@@ -5131,6 +5131,7 @@ struct ggml_tensor * ggml_concat(
     bool is_node = false;
 
     if (a->grad || b->grad) {
+        GGML_ABORT("fatal error"); // TODO: implement
         is_node = true;
     }
 
@@ -5252,6 +5253,7 @@ struct ggml_tensor * ggml_leaky_relu(
     bool is_node = false;
 
     if (!inplace && (a->grad)) {
+        GGML_ABORT("fatal error"); // TODO: not implemented
         is_node = true;
     }
 
@@ -5677,6 +5679,7 @@ static struct ggml_tensor * ggml_set_impl(
     // make a view of the destination
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    GGML_ASSERT(offset < (size_t)(1 << 30));
     int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
     ggml_set_op_params(result, params, sizeof(params));
 
@@ -6634,14 +6637,12 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
-    GGML_ASSERT(c == NULL && "freq factors not implemented yet");
-
-    GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
     bool is_node = false;
 
     if (a->grad) {
-        is_node = false; // TODO: implement backward
+        GGML_ASSERT(false && "backwards pass not implemented");
+        is_node = false;
     }
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6659,6 +6660,7 @@ struct ggml_tensor * ggml_rope_back(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
+    result->src[2] = c;
 
     return result;
 }
@@ -7212,6 +7214,11 @@ struct ggml_tensor * ggml_argsort(
         enum ggml_sort_order  order) {
     bool is_node = false;
 
+    if (a->grad) {
+        GGML_ABORT("fatal error"); // TODO: not implemented
+        is_node = true;
+    }
+
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
 
     ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -10745,9 +10752,6 @@ static void ggml_compute_forward_sum_f32(
         return;
     }
 
-    assert(ggml_is_scalar(dst));
-
-
     assert(ggml_is_scalar(dst));
     assert(src0->nb[0] == sizeof(float));
 
@@ -18000,14 +18004,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad || src1->grad) {
                     GGML_ASSERT(src0->type == tensor->type);
                     GGML_ASSERT(tensor->grad->type == tensor->type);
-                    GGML_ASSERT(tensor->grad->type == src1->grad->type);
+                    GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
 
                     tensor_grad_view = ggml_view_4d(ctx,
-                        tensor->grad,
-                        src1->grad->ne[0],
-                        src1->grad->ne[1],
-                        src1->grad->ne[2],
-                        src1->grad->ne[3],
+                        tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
                         nb1, nb2, nb3, offset);
                 }
 
@@ -18076,9 +18076,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
 
                     memcpy(&offset, tensor->op_params, sizeof(offset));
 
-                    size_t nb1     = tensor->nb[1];
-                    size_t nb2     = tensor->nb[2];
-                    size_t nb3     = tensor->nb[3];
+                    size_t nb1 = tensor->nb[1];
+                    size_t nb2 = tensor->nb[2];
+                    size_t nb3 = tensor->nb[3];
 
                     if (src0->type != src0->grad->type) {
                         // gradient is typically F32, but src0 could be other type
index 318848962f59f4aaf3e1ea23d32954979686dfd0..0a88781bade9b87662cb9f14511cd61f0bbc37d8 100644 (file)
@@ -1,3 +1,20 @@
+// This file defines tests for various GGML ops and backends.
+// For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.
+// For the backwards pass it asserts that the gradients from backpropagation are consistent
+// with the gradients obtained via the method of finite differences ("grad" mode, this is optional).
+// It is also possible to check the performance ("perf" mode).
+//
+// this file has three sections: Section 1 does general setup, section 2 defines the GGML ops to be tested,
+// and section 3 defines which tests to run.
+// Quick start for adding a new GGML op: Go to section 2 and create a struct that inherits from test_case,
+// then go to section 3 and add an instantiation of your struct.
+
+
+// ##############################
+// ## Section 1: General Setup ##
+// ##############################
+
+
 #include <ggml.h>
 #include <ggml-alloc.h>
 #include <ggml-backend.h>
@@ -5,6 +22,7 @@
 #include <algorithm>
 #include <array>
 #include <cfloat>
+#include <cstdint>
 #include <cstring>
 #include <functional>
 #include <memory>
@@ -212,6 +230,39 @@ static double nmse(const float * a, const float * b, size_t n) {
     return mse_a_b / mse_a_0;
 }
 
+// maximum absolute asymmetry between a and b
+// asymmetry: (a - b) / (a + b)
+// This is more stable than relative error if one of the values fluctuates towards zero.
+// n: number of values to compare.
+// expected_vals: optional vector of expected values for a. If expected_vals is not empty, filter out all comparisons where
+//     a does not match any of the expected values. Needed for noncontinuous gradients where the numerical calculation can fail.
+static double mean_abs_asymm(const float * a, const float * b, const size_t n, const std::vector<float> & expected_vals) {
+    double sum = 0.0f;
+
+    size_t nvalid = 0;
+    for (size_t i = 0; i < n; i++) {
+        if (!expected_vals.empty()) {
+            bool matches_any = false;
+            for (const float & ev : expected_vals) {
+                if (fabsf(a[i] - ev) < 1e-3f) {
+                    matches_any = true;
+                    break;
+                }
+            }
+            if (!matches_any) {
+                continue;
+            }
+        }
+
+        const float asymm = (a[i] - b[i]) / (a[i] + b[i]);
+
+        sum += fabsf(asymm);
+        nvalid++;
+    }
+
+    return sum/nvalid;
+}
+
 // utils for printing the variables of the test cases
 #define VAR_TO_STR(x) (#x "=" + var_to_str(x))
 
@@ -295,6 +346,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
 enum test_mode {
     MODE_TEST,
     MODE_PERF,
+    MODE_GRAD,
 };
 
 struct test_case {
@@ -314,6 +366,32 @@ struct test_case {
         return 1e-7;
     }
 
+    virtual double max_maa_err() {
+        return 1e-4;
+    }
+
+    virtual float grad_eps(){
+        return 1e-1f;
+    }
+
+    // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher.
+    // If true,  estimate gradient with 4 points, neglects 5th order derivative and higher.
+    virtual bool grad_precise(){
+        return false;
+    }
+
+    // Skip gradient checks if total number of gradients to be checked is larger than this (to speed up the tests).
+    virtual int64_t grad_nmax() {
+        return 10000;
+    }
+
+    // No effect if empty.
+    // If not empty, skip all gradient checks where the numerical result does not match any of the values.
+    // Needed for dealing with noncontinuous gradients (e.g. ReLU) where estimation using finite differences is unreliable.
+    virtual std::vector<float> grad_expect() {
+        return {};
+    }
+
     virtual void initialize_tensors(ggml_context * ctx) {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
             init_tensor_uniform(t);
@@ -332,6 +410,7 @@ struct test_case {
     }
 
     ggml_cgraph * gf = nullptr;
+    ggml_cgraph * gb = nullptr;
 
     static const int sentinel_size = 1024;
 
@@ -340,7 +419,7 @@ struct test_case {
     std::vector<ggml_tensor *> sentinels;
 
     void add_sentinel(ggml_context * ctx) {
-        if (mode == MODE_PERF) {
+        if (mode == MODE_PERF || mode == MODE_GRAD) {
             return;
         }
         ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
@@ -389,6 +468,7 @@ struct test_case {
             /* .no_alloc = */ true,
         };
         ggml_context * ctx = ggml_init(params);
+        GGML_ASSERT(ctx);
 
         gf = ggml_new_graph(ctx);
 
@@ -550,6 +630,7 @@ struct test_case {
             /* .no_alloc = */ true,
         };
         ggml_context * ctx = ggml_init(params);
+        GGML_ASSERT(ctx);
 
         ggml_tensor * out = build_graph(ctx);
 
@@ -643,8 +724,282 @@ struct test_case {
 
         return true;
     }
+
+    bool eval_grad(ggml_backend_t backend, const char * op_name) {
+        mode = MODE_GRAD;
+        const std::vector<float> expect = grad_expect();
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + 2*ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, true),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+        GGML_ASSERT(ctx);
+
+        gf = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+        gb = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
+
+        ggml_tensor * out = build_graph(ctx);
+
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
+            ggml_free(ctx);
+            return true;
+        }
+
+        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
+        fflush(stdout);
+
+        if (out->grad == nullptr) {
+            printf("backwards pass not supported \n");
+            ggml_free(ctx);
+            return true;
+        }
+        if (out->type != GGML_TYPE_F32) {
+            ggml_free(ctx);
+            printf("not supported [%s->type != FP32]\n", out->name);
+            return true;
+        }
+
+        // check if the backend supports the ops
+        bool supported = true;
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (!ggml_backend_supports_op(backend, t)) {
+                printf("not supported [%s] ", ggml_backend_name(backend));
+                supported = false;
+                break;
+            }
+            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
+                printf("not supported [%s->type != FP32] ", t->name);
+                supported = false;
+                break;
+            }
+        }
+        if (!supported) {
+            printf("\n");
+            ggml_free(ctx);
+            return true;
+        }
+
+        int64_t ngrads = 0;
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->flags & GGML_TENSOR_FLAG_PARAM) {
+                ngrads += ggml_nelements(t);
+            }
+        }
+        if (ngrads > grad_nmax()) {
+            printf("skipping large tensors for speed \n");
+            ggml_free(ctx);
+            return true;
+        }
+
+
+        if (!ggml_is_scalar(out)) {
+            out = ggml_sum(ctx, out);
+            ggml_set_name(out, "sum_of_out");
+        }
+
+        ggml_build_forward_expand(gf, out);
+        ggml_graph_cpy(gf, gb);
+        ggml_build_backward_expand(ctx, gf, gb, false);
+        if (expect.size() != 1 || expect[0] != 0.0f) {
+            GGML_ASSERT(gb->n_nodes > gf->n_nodes);
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || t->grad->op != GGML_OP_NONE);
+            }
+        }
+
+        // TODO: refactor so that this check is only needed once
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (!ggml_backend_supports_op(backend, t)) {
+                printf("not supported [%s] ", ggml_backend_name(backend));
+                supported = false;
+                break;
+            }
+            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
+                printf("not supported [%s->type != FP32] ", t->name);
+                supported = false;
+                break;
+            }
+        }
+        if (!supported) {
+            printf("\n");
+            ggml_free(ctx);
+            return true;
+        }
+
+        // allocate
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+        if (buf == NULL) {
+            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend));
+            ggml_free(ctx);
+            return false;
+        }
+
+        // randomize tensors
+        initialize_tensors(ctx);
+
+        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            if (!t->grad) {
+                continue;
+            }
+
+            std::vector<float> tmp(ggml_nelements(t->grad));
+            ggml_backend_tensor_set(t->grad, tmp.data(), 0, ggml_nbytes(t->grad));
+        }
+
+        // build graphs
+        const float onef = 1.0f;
+        ggml_backend_graph_compute(backend, gf);
+        ggml_backend_tensor_set(out->grad, &onef, 0, ggml_nbytes(out->grad));
+        ggml_backend_graph_compute(backend, gb);
+
+        bool ok = true;
+        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+            if (!(t->flags & GGML_TENSOR_FLAG_PARAM)) {
+                continue;
+            }
+
+            const char * bn = ggml_backend_name(backend);
+            const int64_t ne = ggml_nelements(t);
+
+            std::vector<float> ga = tensor_to_float(t->grad);
+
+            for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
+                // check for nans
+                if (!std::isfinite(ga[i])) {
+                    printf("[%s] nonfinite gradient at index %zu (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
+                    ok = false;
+                    break;
+                }
+            }
+            if (!ok) {
+                break;
+            }
+
+            std::vector<float> gn(ne); // gradient numeric
+            GGML_ASSERT(ga.size() == gn.size());
+
+            std::vector<float> x0 = tensor_to_float(t); // original t data
+            GGML_ASSERT(ggml_is_scalar(out));
+            GGML_ASSERT(out->type == GGML_TYPE_F32);
+
+            const float eps = grad_eps();
+            for (int64_t i = 0; i < ne; ++i) {
+                const float xiu  = x0[i] + 1.0f*eps; // x, index i, up
+                const float xiuh = x0[i] + 0.5f*eps; // x, index i, up half
+                const float xidh = x0[i] - 0.5f*eps; // x, index i, down half
+                const float xid  = x0[i] - 1.0f*eps; // x, index i, down
+
+                float fu, fuh, fdh, fd; // output values for xiu, xiuh, xid, xidh
+
+                ggml_backend_tensor_set(t, &xiu, i*sizeof(float), sizeof(float));
+                ggml_backend_graph_compute(backend, gf);
+                ggml_backend_tensor_get(out, &fu, 0, ggml_nbytes(out));
+
+                ggml_backend_tensor_set(t, &xid, i*sizeof(float), sizeof(float));
+                ggml_backend_graph_compute(backend, gf);
+                ggml_backend_tensor_get(out, &fd, 0, ggml_nbytes(out));
+
+                if (grad_precise()) {
+                    ggml_backend_tensor_set(t, &xiuh, i*sizeof(float), sizeof(float));
+                    ggml_backend_graph_compute(backend, gf);
+                    ggml_backend_tensor_get(out, &fuh, 0, ggml_nbytes(out));
+
+                    ggml_backend_tensor_set(t, &xidh, i*sizeof(float), sizeof(float));
+                    ggml_backend_graph_compute(backend, gf);
+                    ggml_backend_tensor_get(out, &fdh, 0, ggml_nbytes(out));
+
+                    gn[i] = (8.0*(double)fuh + (double)fd - (8.0*(double)fdh + (double)fu)) / (6.0*(double)eps);
+                } else {
+                    gn[i] = (fu - fd) / (2.0f*eps);
+                }
+
+                ggml_backend_tensor_set(t, x0.data(), 0, ggml_nbytes(t));
+            }
+
+            const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect);
+            if (err > max_maa_err()) {
+                printf("[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err());
+                ok = false;
+                break;
+            }
+            if (!ok) {
+                break;
+            }
+        }
+
+        if (!ok) {
+            printf("compare failed ");
+        }
+
+        ggml_backend_buffer_free(buf);
+
+        ggml_free(ctx);
+
+        if (ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            return true;
+        }
+
+        printf("\033[1;31mFAIL\033[0m\n");
+        return false;
+    }
+};
+
+
+// ###################################
+// ## Section 2: GGML Op Defintions ##
+// ###################################
+
+
+// The following is an example showing the bare minimum for creating a test for a GGML op.
+
+// GGML_OP_EXAMPLE
+struct test_example : public test_case {
+    // Always define these 2 or variants thereof:
+    const ggml_type type; // The type of the input tensors.
+    const std::array<int64_t, 4> ne; // The shape of the input tensors.
+    // For some ops it's necessary to define multiple types or shapes for the inputs.
+    // Or they may need additional parameters.
+
+    // Put all parameters needed to fully define the test into one of the VARS_TO_STR macros.
+    // In most cases these are just the properties of the struct that you defined above.
+    // This is needed for info prints.
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    // Define a constructor for the struct.
+    // In most cases it will be sufficient to have the same arguments as the struct has properties
+    // and just use initializer lists.
+    test_example(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    // Define how a simple GGML compute graph can be constructed for the new GGML op.
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // Step 1: create input tensors that don't depend on any other tensors:
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a"); // Setting names is optional but it's useful for debugging.
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        // Step 2: use the op that you want to test in the GGML compute graph.
+        ggml_tensor * out = ggml_add(ctx, a, b); // For this example we're just doing a simple addition.
+        ggml_set_name(out, "out");
+
+        // Step 3: return the output tensor.
+        return out;
+    }
+    // In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
+    // immediately after you create the tensors.
+    // This is optional and only makes sense if a backwards pass has actually been implemented for the new op.
 };
 
+
 // GGML_OP_UNARY
 struct test_unary : public test_case {
     const ggml_unary_op op;
@@ -658,20 +1013,36 @@ struct test_unary : public test_case {
 
     test_unary(ggml_unary_op op,
             ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {128, 10, 10, 10},
+            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
             int v = 0)
         : op(op), type(type), ne_a(ne_a), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
+        const bool grad_supported = op == GGML_UNARY_OP_ABS || op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_NEG ||
+            op == GGML_UNARY_OP_STEP || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU;
+
         ggml_tensor * a;
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
+            if (grad_supported) {
+                ggml_set_param(ctx, a);
+            }
+            ggml_set_name(a, "a");
+
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            if (grad_supported) {
+                ggml_set_param(ctx, a);
+            }
+            ggml_set_name(a, "a");
         }
+
         ggml_tensor * out = ggml_unary(ctx, a, op);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
@@ -681,6 +1052,24 @@ struct test_unary : public test_case {
             init_tensor_uniform(t, -150.f, 150.f);
         }
     }
+
+    float grad_eps() override {
+        return 15.0f;
+    }
+
+    std::vector<float> grad_expect() override {
+        if (op == GGML_UNARY_OP_ABS) {
+            return {-1.0f, 1.0f};
+        }
+        if (op == GGML_UNARY_OP_SGN || op == GGML_UNARY_OP_STEP) {
+            return {0.0f};
+        }
+        if (op == GGML_UNARY_OP_RELU) {
+            return {0.0f, 1.0f};
+        }
+        return {};
+    }
+
 };
 
 // GGML_OP_GET_ROWS
@@ -701,11 +1090,24 @@ struct test_get_rows : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
+        ggml_set_name(in, "in");
+
         ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+        ggml_set_name(rows, "rows");
         if (v) {
             rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+            ggml_set_name(rows, "view_of_rows");
         }
+
+        const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);
+        if (grad_supported) {
+            ggml_set_param(ctx, in);
+            // rows is a constant input -> no gradients
+        }
+
         ggml_tensor * out = ggml_get_rows(ctx, in, rows);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
@@ -741,14 +1143,21 @@ struct test_repeat : public test_case {
     }
 
     test_repeat(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
             std::array<int, 4> nr = {2, 2, 2, 2})
         : type(type), ne(ne), nr(nr) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * target = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(target, "target");
+
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
         ggml_tensor * out = ggml_repeat(ctx, src, target);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -774,10 +1183,62 @@ struct test_dup : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
         if (_use_permute) {
             src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(src, "src_permuted");
         }
+
         ggml_tensor * out = ggml_dup(ctx, src);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_SET
+struct test_set : public test_case {
+    const ggml_type type_src;
+    const ggml_type type_dst;
+    const std::array<int64_t, 4> ne;
+    const int dim;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type_src, type_dst, ne, dim);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
+    }
+
+    test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1)
+        : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
+        auto ne_dst = ne;
+        for (int i = 0; i < dim; ++i) {
+            ne_dst[i] *= 2;
+        }
+        ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
+        ggml_set_param(ctx, dst);
+        ggml_set_name(dst, "dst");
+
+        size_t offset = 0;
+        for (int i = 0; i < dim; ++i) {
+            offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
+        }
+        ggml_tensor * out = ggml_set(ctx, dst, src,
+            // The backwards pass requires setting a contiguous region:
+            src->nb[1], src->nb[2], src->nb[3], offset);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -810,11 +1271,20 @@ struct test_cpy : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
         if (_src_use_permute) {
             src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(src, "src_permuted");
         }
+
         ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
+        ggml_set_name(dst, "dst");
+
         ggml_tensor * out = ggml_cpy(ctx, src, dst);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -834,8 +1304,14 @@ struct test_cont : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, src);
+        ggml_set_name(src, "src");
+
         src = ggml_transpose(ctx, src);
+        ggml_set_name(src, "src_transposed");
+
         ggml_tensor * out = ggml_cont(ctx, src);
+        ggml_set_name(out, "out");
 
         return out;
     }
@@ -866,21 +1342,79 @@ struct test_bin_bcast : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(a, "a");
+
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        // The backwards pass supports broadcasting only for GGML_ADD:
+        const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
+        if (grad_supported) {
+            ggml_set_param(ctx, a);
+            ggml_set_param(ctx, b);
+        }
+
         ggml_tensor * out = op(ctx, a, b);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            if (op == ggml_div) {
-                // avoid division by zero
-                init_tensor_uniform(t, 1.0f, 2.0f);
+            if (op == ggml_mul || op == ggml_div) {
+                // MUL and DIV have numerical issues around zero:
+                init_tensor_uniform(t, 0.9f, 1.1f);
             } else {
                 init_tensor_uniform(t);
             }
         }
     }
+
+    float grad_eps() override {
+        return 0.1f * (op == ggml_mul ? ne[0]*ne[1]*ne[2]*ne[3] : 1);
+    }
+
+    bool grad_precise() override {
+        return op == ggml_div;
+    }
+
+    double max_maa_err() override {
+        return op == ggml_add ? 1e-4 : 1e-3;
+    }
+};
+
+// GGML_OP_ADD1
+struct test_add1 : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_add1(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
+        // ggml_set_param(ctx, b); // TODO: implement
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_add1(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
+    }
 };
 
 // GGML_OP_SCALE
@@ -900,7 +1434,12 @@ struct test_scale : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_scale(ctx, a, scale);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -916,13 +1455,17 @@ struct test_norm : public test_case {
     }
 
     test_norm(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {64, 10, 10, 10},
+            std::array<int64_t, 4> ne = {64, 5, 4, 3},
             float eps = 1e-6f)
         : type(type), ne(ne), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -938,15 +1481,24 @@ struct test_rms_norm : public test_case {
     }
 
     test_rms_norm(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {64, 10, 10, 10},
+            std::array<int64_t, 4> ne = {64, 5, 4, 3},
             float eps = 1e-6f)
         : type(type), ne(ne), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_SSM_CONV
@@ -1038,7 +1590,14 @@ struct test_mul_mat : public test_case {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]      , bs[1]);
         ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_set_param(ctx, a);
+        ggml_set_param(ctx, b);
+        ggml_set_name(a, "a");
+        ggml_set_name(b, "b");
+
         ggml_tensor * out = ggml_mul_mat(ctx, a, b);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1082,12 +1641,21 @@ struct test_mul_mat_id : public test_case {
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
         ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
+        ggml_set_name(as, "as");
+
         ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+        ggml_set_name(ids, "ids");
         if (n_used != n_mats) {
             ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
+            ggml_set_name(ids, "view_of_ids");
         }
+
         ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
+        ggml_set_name(b, "b");
+
         ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
@@ -1123,14 +1691,23 @@ struct test_sqr : public test_case {
     }
 
     test_sqr(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_sqr(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    float grad_eps() override {
+        return 0.1f * 0.25f*ne[0]*ne[1]*ne[2]*ne[3]; // 10% of expected value of sum.
+    }
 };
 
 // GGML_OP_SQRT
@@ -1143,21 +1720,70 @@ struct test_sqrt : public test_case {
     }
 
     test_sqrt(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 3, 3, 2})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_sqrt(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
         // fill with positive values
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            init_tensor_uniform(t, 0.0f, 100.0f);
+            init_tensor_uniform(t, 50.0f, 100.0f);
+        }
+    }
+
+    float grad_eps() override {
+        return 20.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
+// GGML_OP_LOG
+struct test_log : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_log(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_log(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // log(1) == 0, cluster values there to keep the sum low for better precision in the backwards pass:
+            init_tensor_uniform(t, 0.9f, 1.1f);
         }
     }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_SIN
@@ -1170,20 +1796,37 @@ struct test_sin : public test_case {
     }
 
     test_sin(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 2, 2, 2})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_sin(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            init_tensor_uniform(t, -100.0f, 100.0f);
+            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
         }
     }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    float grad_eps() override {
+        return 0.2f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_COS
@@ -1196,20 +1839,37 @@ struct test_cos : public test_case {
     }
 
     test_cos(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 2, 2, 2})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_cos(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            init_tensor_uniform(t, -100.0f, 100.0f);
+            init_tensor_uniform(t, -6.5f, 6.5f); // Covers interval [-2*pi, 2*pi].
         }
     }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    float grad_eps() override {
+        return 0.2f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_CLAMP
@@ -1224,15 +1884,27 @@ struct test_clamp : public test_case {
     }
 
     test_clamp(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
             float min = -0.5f, float max = 0.5f)
         : type(type), ne(ne), min(min), max(max) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_clamp(ctx, a, min, max);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    float grad_eps() override {
+        return 1e-2f;
+    }
+
+    std::vector<float> grad_expect() override {
+        return {0.0f, 1.0f};
+    }
 };
 
 // GGML_OP_DIAG_MASK_INF
@@ -1246,13 +1918,18 @@ struct test_diag_mask_inf : public test_case {
     }
 
     test_diag_mask_inf(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne = {10, 10, 3, 2},
             int n_past = 5)
         : type(type), ne(ne), n_past(n_past) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1276,7 +1953,7 @@ struct test_soft_max : public test_case {
     }
 
     test_soft_max(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
             float scale = 1.0f,
             float max_bias = 0.0f)
@@ -1284,13 +1961,24 @@ struct test_soft_max : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * mask = nullptr;
         if (this->mask) {
             mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
+            ggml_set_name(mask, "mask");
         }
+
         ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 
@@ -1312,7 +2000,7 @@ struct test_rope : public test_case {
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
+            std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
             int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0)
         : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {}
 
@@ -1321,13 +2009,29 @@ struct test_rope : public test_case {
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(ctx, a);
+            ggml_set_name(a, "a");
+
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(ctx, a);
+            ggml_set_name(a, "a");
         }
+
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
-        ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
+        ggml_set_name(pos, "pos");
+
+        ggml_tensor * freq = nullptr;
+        if (ff) {
+            freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);
+            ggml_set_name(freq, "freq");
+        }
+
         ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
@@ -1350,6 +2054,14 @@ struct test_rope : public test_case {
             }
         }
     }
+
+    double max_maa_err() override {
+        return 1e-3;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_POOL2D
@@ -1381,7 +2093,12 @@ struct test_pool2d : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(ctx, input);
+        ggml_set_name(input, "input");
+
         ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1406,8 +2123,14 @@ struct test_conv_transpose_1d : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
         ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
         ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1446,8 +2169,15 @@ struct test_im2col : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(ctx, input);
+        ggml_set_name(input, "input");
+
         ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
         ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1465,8 +2195,8 @@ struct test_concat : public test_case {
     }
 
     test_concat(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
-            int64_t ne_b_d = 10,
+            std::array<int64_t, 4> ne_a = {10, 5, 5, 5},
+            int64_t ne_b_d = 5,
             int dim = 2, int v = 0)
         : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
 
@@ -1477,19 +2207,30 @@ struct test_concat : public test_case {
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
         }
         ggml_tensor * b;
         if (v & 2) {
             auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
             b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(b, "b");
+
             b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
+            ggml_set_name(b, "view_of_b");
         } else {
             b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+            ggml_set_name(b, "b");
         }
+
         ggml_tensor * out = ggml_concat(ctx, a, b, dim);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1511,7 +2252,11 @@ struct test_argsort : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_argsort(ctx, a, order);
+        ggml_set_name(out, "out");
+
         return out;
     }
 
@@ -1544,6 +2289,35 @@ struct test_argsort : public test_case {
     }
 };
 
+// GGML_OP_SUM
+struct test_sum : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sum(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_sum(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * sqrtf(ne[0]*ne[1]*ne[2]*ne[3]);
+    }
+};
+
 // GGML_OP_SUM_ROWS
 struct test_sum_rows : public test_case {
     const ggml_type type;
@@ -1554,12 +2328,17 @@ struct test_sum_rows : public test_case {
     }
 
     test_sum_rows(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_sum_rows(ctx, a);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1582,8 +2361,16 @@ struct test_upscale : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        if (transpose) a = ggml_transpose(ctx, a);
+        ggml_set_name(a, "a");
+
+        if (transpose) {
+            a = ggml_transpose(ctx, a);
+            ggml_set_name(a, "a_transposed");
+        }
+
         ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1605,7 +2392,11 @@ struct test_upscale_ext : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1629,7 +2420,11 @@ struct test_group_norm : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1645,14 +2440,22 @@ struct test_acc : public test_case {
     }
 
     test_acc(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {1024, 577, 1, 1},
-            std::array<int64_t, 4> ne_b = {1024, 576, 1, 1})
+            std::array<int64_t, 4> ne_a = {256, 17, 1, 1},
+            std::array<int64_t, 4> ne_b = {256, 16, 1, 1})
         : type(type), ne_a(ne_a), ne_b(ne_b) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_set_param(ctx, b);
+        ggml_set_name(b, "b");
+
         ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1675,7 +2478,11 @@ struct test_pad : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1697,6 +2504,8 @@ struct test_arange : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * out = ggml_arange(ctx, start, stop, step);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1719,7 +2528,11 @@ struct test_timestep_embedding : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1735,13 +2548,17 @@ struct test_leaky_relu : public test_case {
     }
 
     test_leaky_relu(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
+            std::array<int64_t, 4> ne_a = {10, 5, 4, 3},
             float negative_slope = 0.1f)
         : type(type), ne_a(ne_a), negative_slope(negative_slope)  {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_set_name(a, "a");
+
         ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);
+        ggml_set_name(out, "out");
+
         return out;
     }
 };
@@ -1768,19 +2585,37 @@ struct test_flash_attn_ext : public test_case {
         return 5e-4;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
+                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
         : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
 
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_set_name(q, "q");
+
         ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
+        ggml_set_name(k, "k");
+
         ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
-        ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
+        ggml_set_name(v, "v");
+
+        ggml_tensor * m = nullptr;
+        if (mask) {
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+            ggml_set_name(m, "m");
+        }
+
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 // GGML_OP_CROSS_ENTROPY_LOSS
@@ -1793,15 +2628,42 @@ struct test_cross_entropy_loss : public test_case {
     }
 
     test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+            std::array<int64_t, 4> ne = {10, 5, 4, 3})
         : type(type), ne(ne) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, logits);
+        ggml_set_name(logits, "logits");
+
         ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
+        // The labels are assumed to be constant -> no gradients.
+        ggml_set_name(labels, "labels");
+
+        // Ensure labels add up to 1:
+        labels = ggml_soft_max(ctx, labels);
+        ggml_set_name(labels, "labels_normalized");
+
         ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels);
+        ggml_set_name(out, "out");
+
         return out;
     }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        // For larger abs. diffs between logits softmax is more linear, therefore more precise num. gradients.
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -100.0f, 100.0f);
+        }
+    }
+
+    float grad_eps() override {
+        return 1.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
 };
 
 enum llm_norm_type {
@@ -2188,6 +3050,12 @@ struct test_falcon : public test_llm {
     }
 };
 
+
+// ###########################################
+// ## Section 3: GGML Op Test Instantiation ##
+// ###########################################
+
+
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
     std::default_random_engine rng(0);
@@ -2228,8 +3096,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     // unary ops
     for (int v : {0, 1}) {
         for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
-            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 10, 10, 10 }, v));
-            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }, v));
+            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 2, 2, 2 }, v));
+            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 5, 7, 11, 13 }, v));
         }
     }
 
@@ -2265,11 +3133,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
     // test cases for 1D im2col
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
 
     // sycl backend will limit task global_range < MAX_INT
     // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
@@ -2288,13 +3158,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
 
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 2, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 10, 10, 10}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 10, 10, 10}, {1, 1, 1, 2}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, 3}, {2, 1, 1, 1}));
+    test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, 3}, {1, 1, 1, 2}));
 
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
@@ -2307,6 +3177,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
 
+    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
+        test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
+    }
+
     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (ggml_type type_dst : all_types) {
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
@@ -2339,16 +3213,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
     add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2});
-    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 1, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 1}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2});
+    add_test_bin_bcast(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 2, 2, 2});
 
     // stable diffusion
     add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
@@ -2367,11 +3241,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
     //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
 
+    test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
 
     for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) {
-        test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
-        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
+        test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
@@ -2475,13 +3350,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_sqr());
     test_cases.emplace_back(new test_sqrt());
+    test_cases.emplace_back(new test_log());
     test_cases.emplace_back(new test_sin());
     test_cases.emplace_back(new test_cos());
     test_cases.emplace_back(new test_clamp());
 
-    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10,  1,  1}, 5));
-    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10,  1}, 5));
-    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));
+    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 2}, 5));
 
 #if 0
     std::uniform_int_distribution<> dist_ne1(1, 50);
@@ -2525,23 +3401,23 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                     for (float af : { 1.0f, 1.4245f }) {
                         for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
                             for (bool ff : {false, true}) { // freq_factors
-                                test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+                                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
 
                                 if (all) {
-                                    test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
-                                    test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
-                                    test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+                                    test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
+                                    test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
+                                    test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
                                 }
 
                                 if (all) {
-                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
-                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
-                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
-                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm)
-                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
+                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
                                 }
 
-                                test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
+                                test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
                             }
                         }
 
@@ -2565,6 +3441,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
     }
 
+    test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_upscale());
     test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
@@ -2607,6 +3484,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #endif
 
     // run tests
+    if (mode == MODE_GRAD) {
+        size_t n_ok = 0;
+        for (auto & test : test_cases) {
+            if (test->eval_grad(backend, op_name)) {
+                n_ok++;
+            }
+        }
+        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
+
+        return n_ok == test_cases.size();
+    }
+
     if (mode == MODE_TEST) {
         ggml_backend_t backend_cpu = ggml_backend_cpu_init();
 
@@ -2635,8 +3524,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
 static void usage(char ** argv) {
     printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
-    printf("  valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation)\n");
-    printf("  op names are as given by ggml_op_desc()\n");
+    printf("    valid modes:\n");
+    printf("      - test (default, compare with CPU backend for correctness)\n");
+    printf("      - perf (performance evaluation)\n");
+    printf("      - grad (compare gradients from backpropagation with method of finite differences)\n");
+    printf("    op names are as given by ggml_op_desc() (e.g. GGML_ADD)\n");
 }
 
 int main(int argc, char ** argv) {
@@ -2649,6 +3541,8 @@ int main(int argc, char ** argv) {
             mode = MODE_TEST;
         } else if (strcmp(argv[i], "perf") == 0) {
             mode = MODE_PERF;
+        } else if (strcmp(argv[i], "grad") == 0) {
+            mode = MODE_GRAD;
         } else if (strcmp(argv[i], "-o") == 0) {
             if (i + 1 < argc) {
                 op_name_filter = argv[++i];
@@ -2686,7 +3580,7 @@ int main(int argc, char ** argv) {
         ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL);
         GGML_ASSERT(backend != NULL);
 
-        if (backend_filter == NULL && ggml_backend_is_cpu(backend)) {
+        if (backend_filter == NULL && ggml_backend_is_cpu(backend) && mode != MODE_GRAD) {
             printf("  Skipping CPU backend\n");
             ggml_backend_free(backend);
             n_ok++;