]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml/examples: add backend support for numerical optimization (ggml/949)
authorJohannes Gäßler <redacted>
Fri, 20 Sep 2024 16:04:44 +0000 (19:04 +0300)
committerGeorgi Gerganov <redacted>
Fri, 20 Sep 2024 18:15:05 +0000 (21:15 +0300)
* CUDA eval works

* stochastic gradient descent op

* Adam except decay

* CUDA CROSS_ENTROPY_LOSS_BACK

* CUDA mnist-fc training works

* backend CLI arg

* refactor gguf load

* remove sched from opt_step_adam

* implement l1 regularization (weight decay)

* extra call to add optimizer

* initialize gradients with ggml_graph_reset

* gradient accumulation

* increment iter per eval instead of epoch

* adjust backend interfaces

* fix ggml_graph_reset without backend

* fix ggml graph export/import

* fixup

* rename

* revert ggml_opt changes

* more general CUDA repeat_back

* update documentation, fix CNN

* validation split

* add clarifying comment

* optimize PyTorch training

* adjust buffer size, thread count

* fix 0.0f validation split

* Update examples/mnist/mnist-common.cpp

Co-authored-by: Georgi Gerganov <redacted>
* fix gradient accumulation

* tensor flag for accumulators -> tensor hash set

* Update include/ggml.h

Co-authored-by: slaren <redacted>
* Update tests/test-backend-ops.cpp

Co-authored-by: slaren <redacted>
* Update tests/test-backend-ops.cpp

Co-authored-by: slaren <redacted>
* fix test prints

* Update src/ggml-backend.c

Co-authored-by: Georgi Gerganov <redacted>
* better CUDA support for noncontiguous out_prod

* add comment

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: slaren <redacted>
24 files changed:
ggml/include/ggml-backend.h
ggml/include/ggml.h
ggml/src/ggml-backend-impl.h
ggml/src/ggml-backend.c
ggml/src/ggml-cann.cpp
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/binbcast.cu
ggml/src/ggml-cuda/binbcast.cuh
ggml/src/ggml-cuda/cross-entropy-loss.cu
ggml/src/ggml-cuda/cross-entropy-loss.cuh
ggml/src/ggml-cuda/opt-step-adamw.cu [new file with mode: 0644]
ggml/src/ggml-cuda/opt-step-adamw.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/out-prod.cu [new file with mode: 0644]
ggml/src/ggml-cuda/out-prod.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh
ggml/src/ggml-kompute.cpp
ggml/src/ggml-metal.m
ggml/src/ggml-rpc.cpp
ggml/src/ggml-sycl.cpp
ggml/src/ggml-vulkan.cpp
ggml/src/ggml.c
tests/test-backend-ops.cpp
tests/test-grad0.cpp

index e497b6d02388a14ed28aa15fbbb7bdf8b9f85d62..71c0bef8ee7ee69a7d59eb881d9a9c67551630a9 100644 (file)
@@ -66,6 +66,7 @@ extern "C" {
     // "offset" refers to the offset of the tensor data for setting/getting data
     GGML_API GGML_CALL void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+    GGML_API GGML_CALL void ggml_backend_tensor_memset(   struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
 
     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
 
@@ -122,7 +123,7 @@ extern "C" {
     // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
 
     GGML_API size_t                     ggml_backend_reg_get_count(void);
-    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name);
+    GGML_API size_t                     ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found
     GGML_API ggml_backend_t             ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
     GGML_API const char *               ggml_backend_reg_get_name(size_t i);
     GGML_API ggml_backend_t             ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
index a413df35750b1e88dd9ede535b96043176df482e..2035001e97d7e1c3825ed63ab028be2cb582e45d 100644 (file)
@@ -534,6 +534,7 @@ extern "C" {
 
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+        GGML_OP_OPT_STEP_ADAMW,
 
         GGML_OP_COUNT,
     };
@@ -571,10 +572,12 @@ extern "C" {
         GGML_LOG_LEVEL_DEBUG = 4,
     };
 
+    // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1,
-        GGML_TENSOR_FLAG_OUTPUT = 2,
-        GGML_TENSOR_FLAG_PARAM  = 4,
+        GGML_TENSOR_FLAG_INPUT    = 1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT   = 2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM    = 4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS     = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
     // n-dimensional tensor
@@ -2037,23 +2040,44 @@ extern "C" {
             struct ggml_tensor          * b,
             struct ggml_tensor          * c);
 
+    // AdamW optimizer step
+    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+    GGML_API struct ggml_tensor * ggml_opt_step_adamw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay
+
     //
     // automatic differentiation
     //
 
-    GGML_API void ggml_set_param(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * tensor);
+    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
+
+    GGML_API void ggml_build_opt_adamw(
+            struct ggml_context * ctx,
+            struct ggml_cgraph  * gf,
+            struct ggml_cgraph  * gb,
+            float                 alpha,
+            float                 beta1,
+            float                 beta2,
+            float                 eps,
+            float                 wd); // weight decay
 
     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph       (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
     GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
     GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
     GGML_API void                 ggml_graph_cpy       (struct ggml_cgraph * src, struct ggml_cgraph * dst);
-    GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph);  // zero grads
+    GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
     GGML_API void                 ggml_graph_clear     (struct ggml_cgraph * cgraph);
 
     GGML_API int                   ggml_graph_size   (struct ggml_cgraph * cgraph);
index 36ca370867c9e7bcccbd58a86febd5bc82704086..b0d4141cc4363dae92235c8915b3c2df796f5d8c 100644 (file)
@@ -38,15 +38,16 @@ extern "C" {
     typedef void * ggml_backend_buffer_context_t;
 
     struct ggml_backend_buffer_i {
-        const char * (*GGML_CALL get_name)   (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer);
-        void *       (*GGML_CALL get_base)   (ggml_backend_buffer_t buffer);
-        void         (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
-        void         (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void         (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
-        bool         (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
-        void         (*GGML_CALL clear)      (ggml_backend_buffer_t buffer, uint8_t value);
-        void         (*GGML_CALL reset)      (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
+        const char * (*GGML_CALL get_name)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL free_buffer)   (ggml_backend_buffer_t buffer);
+        void *       (*GGML_CALL get_base)      (ggml_backend_buffer_t buffer);
+        void         (*GGML_CALL init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        void         (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
+        void         (*GGML_CALL set_tensor)    (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void         (*GGML_CALL get_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        bool         (*GGML_CALL cpy_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
+        void         (*GGML_CALL clear)         (ggml_backend_buffer_t buffer, uint8_t value);
+        void         (*GGML_CALL reset)         (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
     };
 
     struct ggml_backend_buffer {
index b5d9301a787629de5260cd9e8e8e591c430c9c88..97ca5a1f32f744e9d7f0afe98aa361d1c6a6d722 100644 (file)
@@ -246,6 +246,22 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void *
     buf->iface.get_tensor(buf, tensor, data, offset, size);
 }
 
+GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    if (!size) {
+        return;
+    }
+    
+    GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer");
+
+    buf->iface.memset_tensor(buf, tensor, value, offset, size);
+}
+
 void ggml_backend_synchronize(ggml_backend_t backend) {
     if (backend->iface.synchronize == NULL) {
         return;
@@ -569,6 +585,12 @@ GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t
     free(buffer->context);
 }
 
+GGML_CALL static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
 GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);
 
@@ -600,6 +622,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
     /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor     = */ NULL, // no initialization required
+    /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -613,6 +636,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
     /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
     /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor     = */ NULL, // no initialization required
+    /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
@@ -980,6 +1004,7 @@ static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(
         /* .free_buffer     = */ ggml_backend_multi_buffer_free_buffer,
         /* .get_base        = */ NULL,
         /* .init_tensor     = */ NULL,
+        /* .memset_tensor   = */ NULL,
         /* .set_tensor      = */ NULL,
         /* .get_tensor      = */ NULL,
         /* .cpy_tensor      = */ NULL,
index aa315b83f77aa9f8373810c7a3702f15d00ac953..d3ab78006ee23f1858d3189a34b627e5288b3648 100644 (file)
@@ -1037,6 +1037,7 @@ static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cann_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cann_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cann_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_cann_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cann_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cann_buffer_cpy_tensor,
index 54f1a7c2d3075a9e41aa3327360c1378cc91f1f7..b0843dc621cf54989ab7d01821ecd05683e376b4 100644 (file)
@@ -21,6 +21,8 @@
 #include "ggml-cuda/mmq.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
@@ -493,6 +495,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
     }
 }
 
+GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
 GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
 
@@ -544,6 +554,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .memset_tensor   = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_cuda_buffer_cpy_tensor,
@@ -860,6 +871,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_cuda_split_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_cuda_split_buffer_get_tensor,
     /* .cpy_tensor      = */ NULL,
@@ -2168,6 +2180,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_cuda_op_repeat_back(ctx, dst);
+            break;
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;
@@ -2201,6 +2216,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_NEG:
                     ggml_cuda_op_neg(ctx, dst);
                     break;
+                case GGML_UNARY_OP_STEP:
+                    ggml_cuda_op_step(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_GELU:
                     ggml_cuda_op_gelu(ctx, dst);
                     break;
@@ -2267,6 +2285,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
             break;
+        case GGML_OP_OUT_PROD:
+            ggml_cuda_out_prod(ctx, dst);
+            break;
         case GGML_OP_SCALE:
             ggml_cuda_op_scale(ctx, dst);
             break;
@@ -2324,6 +2345,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            ggml_cuda_cross_entropy_loss_back(ctx, dst);
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            ggml_cuda_opt_step_adamw(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2761,6 +2788,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -2813,6 +2841,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                         return false;
                 }
             } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -2869,6 +2899,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             } break;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_REPEAT_BACK:
+                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2935,9 +2971,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             return true;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }
index e1390a0414559fca9bc0a7ae05fd394de3c65615..c7b6be4e2905c1d927f0bc3860a0617b1ab16ce4 100644 (file)
@@ -1,4 +1,5 @@
 #include "binbcast.cuh"
+#include <cstdint>
 
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
@@ -90,6 +91,30 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
+template <typename T>
+static __global__ void k_repeat_back(
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+
+    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
+    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+
+    if (tid0 >= ne0) {
+        return;
+    }
+
+    T sum = 0;
+    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+            }
+        }
+    }
+    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+}
+
 template<float (*bin_op)(const float, const float)>
 struct bin_bcast_cuda {
     template<typename src0_t, typename src1_t, typename dst_t>
@@ -247,6 +272,16 @@ struct bin_bcast_cuda {
     }
 };
 
+template <typename T>
+static void repeat_back_cuda(
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+}
+
 template<class op>
 static void ggml_cuda_op_bin_bcast(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
@@ -286,3 +321,35 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    GGML_ASSERT(dst->ne[3] == 1);
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            float       * dst_d  = (float       *) dst->data;
+            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+        } break;
+        default: {
+            GGML_ASSERT(false);
+        } break;
+    }
+}
index 198c9ef6fd8ea73c3e9e85f5ef0e60676365ca91..3ac1c9b03fcea7255a124e4a0d9e1a18b3bd8b64 100644 (file)
@@ -5,3 +5,5 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 5575a90f643266bf5eb233596b731e2a938cc1a0..ed09406a88bacb119441e457ffa6c01c60b7f5ce 100644 (file)
@@ -71,6 +71,32 @@ static __global__ void cross_entropy_loss_f32(const float * logits, const float
     dst[blockIdx.x] = loss;
 }
 
+static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+    extern __shared__ float tmp[];
+
+    float maxval = -INFINITY;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = logits[blockIdx.x*nclasses + i];
+        maxval = fmaxf(maxval, val);
+        tmp[i] = val;
+    }
+    maxval = warp_reduce_max(maxval);
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = expf(tmp[i] - maxval);
+        sum += val;
+        tmp[i] = val;
+    }
+    sum = warp_reduce_sum(sum);
+    const float sm_scale = 1.0f/sum;
+
+    const float d_by_nrows = *loss/gridDim.x;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+    }
+}
+
 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -104,3 +130,37 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     // Combine results from individual blocks:
     sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * opt0 = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(opt0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    const float * opt0_d = (const float *) opt0->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(nrows, 1, 1);
+    const int shmem = ne00*sizeof(float);
+
+    cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+}
index 9d7b8b0f0082ba52fe90da130d58a0b13597a967..9ec7152ff4518607a01f61ab2f4e209ec8cef6dd 100644 (file)
@@ -3,3 +3,5 @@
 #define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
 
 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cu b/ggml/src/ggml-cuda/opt-step-adamw.cu
new file mode 100644 (file)
index 0000000..d6f13a9
--- /dev/null
@@ -0,0 +1,80 @@
+#include "opt-step-adamw.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_adamw_f32(
+    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k,
+    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+    const float beta1h, const float beta2h) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float gi = g[i];
+    const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);
+    const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
+
+    g_m[i] = gmi;
+    g_v[i] = gvi;
+
+    const float mh =       gmi*beta1h;
+    const float vh = sqrtf(gvi*beta2h) + eps;
+
+    x[i] = x[i]*(1.0f - alpha*wd) - mh/vh;
+}
+
+static void opt_step_adamw_f32_cuda(
+    float * x, const float * g, float * g_m, float * g_v, const int64_t k,
+    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
+    const float beta1h, const float beta2h, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h);
+}
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0        = dst->src[0];
+    const ggml_tensor * src0_grad   = dst->src[1];
+    const ggml_tensor * src0_grad_m = dst->src[2];
+    const ggml_tensor * src0_grad_v = dst->src[3];
+
+    GGML_ASSERT(src0->type        == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type   == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+
+    float       * src0_d        = (float       *) src0->data;
+    const float * src0_grad_d   = (const float *) src0_grad->data;
+    float       * src0_grad_m_d = (float       *) src0_grad_m->data;
+    float       * src0_grad_v_d = (float       *) src0_grad_v->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    int64_t iter;  memcpy(&iter,  &dst->op_params[0], sizeof(int64_t));
+    float   alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float));
+    float   beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float));
+    float   beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float));
+    float   eps;   memcpy(&eps,   &dst->op_params[5], sizeof(float));
+    float   wd;    memcpy(&wd,    &dst->op_params[6], sizeof(float));
+
+    const float beta1h  = alpha/(1.0f - powf(beta1, iter));
+    const float beta2h  =  1.0f/(1.0f - powf(beta2, iter));
+
+    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream);
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cuh b/ggml/src/ggml-cuda/opt-step-adamw.cuh
new file mode 100644 (file)
index 0000000..58d6f6e
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu
new file mode 100644 (file)
index 0000000..657d50e
--- /dev/null
@@ -0,0 +1,52 @@
+#include "out-prod.cuh"
+#include "vendors/cuda.h"
+
+#include <cstdint>
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(ne01 == ne11);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+
+    GGML_ASSERT(ne2 == src0->ne[2]);
+    GGML_ASSERT(ne2 == src1->ne[2]);
+    GGML_ASSERT(ne3 == src0->ne[3]);
+    GGML_ASSERT(ne3 == src1->ne[3]);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       *  dst_d = (float       *)  dst->data;
+
+    cudaStream_t   stream = ctx.stream();
+    cublasHandle_t handle = ctx.cublas_handle();
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
+    GGML_ASSERT(ne2 == 1);
+    GGML_ASSERT(ne3 == 1);
+    CUBLAS_CHECK(cublasSetStream(handle, stream));
+
+    const bool src1_T = ggml_is_transposed(src1);
+    const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
+    const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
+    GGML_ASSERT(                             (src1_T ?        nb11 :        nb10) == sizeof(float));
+
+    CUBLAS_CHECK(
+        cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                ne0, ne1, ne01,
+                &alpha, src0_d, ne00,
+                        src1_d, ldb,
+                &beta,  dst_d,  ne0));
+}
diff --git a/ggml/src/ggml-cuda/out-prod.cuh b/ggml/src/ggml-cuda/out-prod.cuh
new file mode 100644 (file)
index 0000000..a0046f5
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 8ac669f94e2de23ed18a999041cab0861f8c5480..163b5a8ffec6b72f1ffac66a6b4c31f94e139808 100644 (file)
@@ -10,6 +10,16 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) {
     dst[i] = -x[i];
 }
 
+static __global__ void step_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] > 0.0f;
+}
+
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -134,6 +144,11 @@ static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
+    step_f32<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -213,6 +228,20 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
index ed2ffc461e8102caafcf1c2d48fa82892a8eafc4..fe519f6a232dfa053bd704d6f3f7c0942eda5510 100644 (file)
@@ -1,6 +1,7 @@
 #include "common.cuh"
 
 #define CUDA_NEG_BLOCK_SIZE 256
+#define CUDA_STEP_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
@@ -15,6 +16,8 @@
 
 void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 7f0bd82d5de92db1651fd720200c7f2e6a50559f..9cbc57a647de5eb21b590680f24ead462c20229f 100644 (file)
@@ -1872,6 +1872,7 @@ static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
     /* .free_buffer     = */ ggml_backend_kompute_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_kompute_buffer_get_base,
     /* .init_tensor     = */ NULL,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_kompute_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_kompute_buffer_get_tensor,
     /* .cpy_tensor      = */ NULL,
index f87181d19332ea368cd81d2fa31cb091a3220dcc..ef3b7f0e824a913e3e934d93f31886155bc80357 100644 (file)
@@ -3167,6 +3167,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
     /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_metal_buffer_get_base,
     /* .init_tensor     = */ NULL,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_metal_buffer_cpy_tensor,
index a8a2eb85adc23bd8bae243adcfb863712907f13f..49b3fa91174e2209c07be1a3d1247b8d84c57e25 100644 (file)
@@ -469,6 +469,7 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_rpc_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_rpc_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_rpc_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_rpc_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_rpc_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_rpc_buffer_cpy_tensor,
index acef7c6d4e1eaea92fe6fffaa9ff44668a8af68d..2cf4450473f9f9d8176c8792f6e38e386d5f7f2b 100644 (file)
@@ -4323,6 +4323,7 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_sycl_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_sycl_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
index bad960510850ef139f79576f73e5d21fe96ecd21..f9da45881e9df4747b65059fa77973db756d8581 100644 (file)
@@ -6246,6 +6246,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_vk_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_vk_buffer_get_base,
     /* .init_tensor     = */ ggml_backend_vk_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
     /* .set_tensor      = */ ggml_backend_vk_buffer_set_tensor,
     /* .get_tensor      = */ ggml_backend_vk_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_vk_buffer_cpy_tensor,
index 51532c5eaaad5f3bcd3432f49d867baa04115508..201d5466a0e4b7e6cff37cb090796d4f1534e397 100644 (file)
@@ -1,6 +1,7 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #define _USE_MATH_DEFINES // For M_PI on MSVC
 
+#include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-cpu-impl.h"
 #include "ggml-quants.h"
@@ -2997,9 +2998,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
+    "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3090,9 +3092,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
+    "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
+static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4094,7 +4097,11 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
 }
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
-    memset(tensor->data, 0, ggml_nbytes(tensor));
+    if (tensor->buffer) {
+        ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
+    } else {
+        memset(tensor->data, 0, ggml_nbytes(tensor));
+    }
     return tensor;
 }
 
@@ -8320,11 +8327,46 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
     return result;
 }
 
-////////////////////////////////////////////////////////////////////////////////
+// opt_step_adamw
 
-void ggml_set_param(
+struct ggml_tensor * ggml_opt_step_adamw(
         struct ggml_context * ctx,
-        struct ggml_tensor * tensor) {
+        struct ggml_tensor  * a,
+        float                 alpha,
+        float                 beta1,
+        float                 beta2,
+        float                 eps,
+        float                 wd) {
+    GGML_ASSERT(a->grad);
+    GGML_ASSERT(alpha >  0.0f);
+    GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
+    GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
+    GGML_ASSERT(eps   >= 0.0f);
+    GGML_ASSERT(wd    >= 0.0f && wd    <= 1.0f);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op   = GGML_OP_OPT_STEP_ADAMW;
+    result->grad = NULL;
+    result->src[0] = a;
+    result->src[1] = a->grad;
+    result->src[2] = ggml_dup_tensor(ctx, a->grad);
+    result->src[3] = ggml_dup_tensor(ctx, a->grad);
+
+    const int64_t iter = 1;
+    memcpy(&result->op_params[0], &iter, sizeof(int64_t));
+    ggml_set_op_params_f32(result, 2, alpha);
+    ggml_set_op_params_f32(result, 3, beta1);
+    ggml_set_op_params_f32(result, 4, beta2);
+    ggml_set_op_params_f32(result, 5, eps);
+    ggml_set_op_params_f32(result, 6, wd);
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 
     GGML_ASSERT(tensor->grad == NULL);
@@ -8332,6 +8374,13 @@ void ggml_set_param(
     ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
 }
 
+void ggml_set_loss(struct ggml_tensor * tensor) {
+    GGML_ASSERT(ggml_is_scalar(tensor));
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(tensor->grad);
+    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
+}
+
 // ggml_compute_forward_dup
 
 static void ggml_compute_forward_dup_same_cont(
@@ -17406,7 +17455,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    float * d   = (float *) opt0->data;
+    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
         float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
@@ -17430,7 +17479,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
 
         // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
-        ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
+        ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -17459,6 +17508,94 @@ static void ggml_compute_forward_cross_entropy_loss_back(
     }
 }
 
+static void ggml_compute_forward_opt_step_adamw_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0        = dst->src[0];
+    const struct ggml_tensor * src0_grad   = dst->src[1];
+    const struct ggml_tensor * src0_grad_m = dst->src[2];
+    const struct ggml_tensor * src0_grad_v = dst->src[3];
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    /* const float   gnorm = 1.0f; */
+    int64_t       iter;   memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
+    const float   alpha = ggml_get_op_params_f32(dst, 2);
+    const float   beta1 = ggml_get_op_params_f32(dst, 3);
+    const float   beta2 = ggml_get_op_params_f32(dst, 4);
+    const float   eps   = ggml_get_op_params_f32(dst, 5);
+    const float   wd    = ggml_get_op_params_f32(dst, 6);
+
+    const float beta1h  = alpha/(1.0f - powf(beta1, iter));
+    const float beta2h  =  1.0f/(1.0f - powf(beta2, iter));
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
+        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
+        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
+            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
+
+            const float mh =       m[i00]*beta1h;
+            const float vh = sqrtf(v[i00]*beta2h) + eps;
+
+            // The weight decay is applied independently of the Adam momenta m and v.
+            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
+            // See: https://arxiv.org/pdf/1711.05101v3.pdf
+            w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
+        }
+    }
+
+    ggml_barrier(params->threadpool);
+    if (ith != 0) {
+        return;
+    }
+
+    iter++;
+    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+}
+
+static void ggml_compute_forward_opt_step_adamw(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adamw_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -17804,6 +17941,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_cross_entropy_loss_back(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                ggml_compute_forward_opt_step_adamw(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -17958,7 +18100,7 @@ void ggml_build_backward_gradient_checkpointing(
         struct ggml_tensor  * * checkpoints,
         int                     n_checkpoints) {
     ggml_graph_cpy(gf, gb_tmp);
-    ggml_build_backward_expand(ctx, gf, gb_tmp, true);
+    ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);
 
     if (n_checkpoints <= 0) {
         ggml_graph_cpy(gb_tmp, gb);
@@ -17996,42 +18138,93 @@ void ggml_build_backward_gradient_checkpointing(
     ggml_hash_map_free(replacements);
 }
 
-// functions to change gradients considering the case that input a might be initial gradient with zero value
-
-static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+// utility functions to change gradients
+// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
+// else if a is in zero_table, replace a
+// else, just add/subtract/etc. the gradients
+
+static struct ggml_tensor * ggml_add_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return b;
-    } else {
-        return ggml_add_impl(ctx, a, b, false);
     }
+    return ggml_add_impl(ctx, a, b, false);
 }
 
-static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_acc_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        const  size_t          nb1,
+        const  size_t          nb2,
+        const  size_t          nb3,
+        const  size_t          offset,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
+        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
         return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
-    } else {
-        return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
     }
+    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
 }
 
-static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_add1_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_repeat(ctx, b, a);
-    } else {
-        return ggml_add1_impl(ctx, a, b, false);
     }
+    return ggml_add1_impl(ctx, a, b, false);
 }
 
-static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
+static struct ggml_tensor * ggml_sub_or_set(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        struct ggml_tensor   * b,
+        struct ggml_hash_set * zero_table,
+        struct ggml_hash_set * acc_table) {
+    if (ggml_hash_contains(acc_table, a)) {
+        struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
+        const size_t insert_result = ggml_hash_insert(acc_table, ret);
+        GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+        GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+        return ret;
+    }
     if (ggml_hash_contains(zero_table, a)) {
         return ggml_neg(ctx, b);
-    } else {
-        return ggml_sub_impl(ctx, a, b, false);
     }
+    return ggml_sub_impl(ctx, a, b, false);
 }
 
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) {
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
@@ -18040,38 +18233,38 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_DUP:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_ADD:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     if (ggml_are_same_shape(src0, src1)) {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad,                       tensor->grad,        zero_table);
+                        src1->grad = ggml_add_or_set(ctx, src1->grad,                       tensor->grad,        zero_table, acc_table);
                     } else {
-                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
+                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
                     }
                 }
             } break;
         case GGML_OP_ADD1:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     src1->grad = ggml_add_or_set(ctx,
                         src1->grad,
                         ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_ACC:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     const size_t nb1     = ((int32_t *) tensor->op_params)[0];
@@ -18093,16 +18286,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_reshape(ctx,
                                 ggml_cont(ctx, tensor_grad_view),
                                 src1->grad),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SUB:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
+                    src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_MUL:
@@ -18112,14 +18305,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_mul(ctx, src1, tensor->grad),
-                                zero_table);
+                                zero_table, acc_table);
                 }
                 if (src1->grad) {
                     src1->grad =
                         ggml_add_or_set(ctx,
                                 src1->grad,
                                 ggml_mul(ctx, src0, tensor->grad),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_DIV:
@@ -18129,7 +18322,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_div(ctx, tensor->grad, src1),
-                                zero_table);
+                                zero_table, acc_table);
                 }
                 if (src1->grad) {
                     src1->grad =
@@ -18138,7 +18331,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_mul(ctx,
                                     tensor->grad,
                                     ggml_div(ctx, tensor, src1)),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SQR:
@@ -18150,7 +18343,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_scale(ctx,
                                     ggml_mul(ctx, src0, tensor->grad),
                                     2.0f),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SQRT:
@@ -18164,7 +18357,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         tensor->grad,
                                         tensor),
                                     0.5f),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_LOG:
@@ -18176,7 +18369,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_div(ctx,
                                     tensor->grad,
                                     src0),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SIN:
@@ -18188,7 +18381,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_mul(ctx,
                                     tensor->grad,
                                     ggml_cos(ctx, src0)),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_COS:
@@ -18200,7 +18393,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_mul(ctx,
                                     tensor->grad,
                                     ggml_sin(ctx, src0)),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SUM:
@@ -18210,7 +18403,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add1_or_set(ctx,
                                 src0->grad,
                                 tensor->grad,
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SUM_ROWS:
@@ -18222,7 +18415,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 ggml_repeat(ctx,
                                     tensor->grad,
                                     src0->grad),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_MEAN:
@@ -18237,7 +18430,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_repeat_back(ctx, tensor->grad, src0->grad),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_REPEAT_BACK:
@@ -18247,7 +18440,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_repeat(ctx, tensor->grad, src0->grad),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_CONCAT:
@@ -18272,7 +18465,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_RMS_NORM_BACK:
@@ -18320,7 +18513,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add_or_set(ctx,
                                 src0->grad, // [n,m,q1,r1]
                                 s1_tg,      // [n,m,q1,r1]
-                                zero_table);
+                                zero_table, acc_table);
                 }
                 if (src1->grad) {
                     src1->grad =
@@ -18338,7 +18531,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     src0,                           // [n,m,q1,r1]
                                     ggml_transpose(ctx,             // [p,m,qq,rr]
                                         tensor->grad)),             // [m,p,qq,rr]
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_MUL_MAT_ID:
@@ -18360,7 +18553,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_scale_impl(ctx, tensor->grad, s, false),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SET:
@@ -18389,7 +18582,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             tensor->grad,
                             ggml_neg(ctx, tensor_grad_view),
                             nb1, nb2, nb3, offset, false),
-                        zero_table);
+                        zero_table, acc_table);
                 }
 
                 if (src1->grad) {
@@ -18399,7 +18592,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_reshape(ctx,
                                 ggml_cont(ctx, tensor_grad_view),
                                 src1->grad),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_CPY:
@@ -18410,7 +18603,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // tensor = src0 * 1 + src1 * 0
                 if (src0->grad) {
                     // dsrc0 = dtensor * 1
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
                 if (src1->grad) {
                     // dsrc1 = dtensor * 0 -> noop
@@ -18422,7 +18615,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     GGML_ASSERT(ggml_is_contiguous(src0->grad));
                     GGML_ASSERT(ggml_is_contiguous(tensor->grad));
-                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_RESHAPE:
@@ -18436,7 +18629,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     ? tensor->grad
                                     : ggml_cont(ctx, tensor->grad),
                                 src0->grad),
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_VIEW:
@@ -18465,7 +18658,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         nb3 = (nb3 / n0) * ng;
                     }
 
-                    src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
+                    src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
                 }
             } break;
         case GGML_OP_PERMUTE:
@@ -18490,7 +18683,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 axes_backward[1],
                                 axes_backward[2],
                                 axes_backward[3]),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_TRANSPOSE:
@@ -18500,7 +18693,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_or_set(ctx, src0->grad,
                             ggml_transpose(ctx, tensor->grad),
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_GET_ROWS:
@@ -18512,7 +18705,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             // last ggml_get_rows_back argument src0->grad is only
                             // necessary to setup correct output shape
                             ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
-                        zero_table);
+                        zero_table, acc_table);
                 }
                 if (src1->grad) {
                     // noop
@@ -18536,7 +18729,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             /* ggml_diag_mask_inf_impl() shouldn't be here */
                             /* ref:  https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
@@ -18547,7 +18740,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_or_set(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                        zero_table);
+                        zero_table, acc_table);
                 }
             } break;
         case GGML_OP_SOFT_MAX:
@@ -18557,7 +18750,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad =
                         ggml_add_or_set(ctx, src0->grad,
                             ggml_soft_max_back(ctx, tensor->grad, tensor),
-                        zero_table);
+                        zero_table, acc_table);
                 }
 
             } break;
@@ -18598,7 +18791,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 attn_factor,
                                 beta_fast,
                                 beta_slow),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_ROPE_BACK:
@@ -18634,7 +18827,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 beta_fast,
                                 beta_slow,
                                 false),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_CLAMP:
@@ -18659,7 +18852,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad = ggml_add_or_set(ctx,
                             src1->grad,
                             ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_IM2COL_BACK:
@@ -18688,7 +18881,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_POOL_2D_BACK:
@@ -18753,7 +18946,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             grad_q,
-                            zero_table);
+                            zero_table, acc_table);
                 }
                 if (src1->grad) {
                     struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
@@ -18761,7 +18954,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad = ggml_add_or_set(ctx,
                             src1->grad,
                             grad_k,
-                            zero_table);
+                            zero_table, acc_table);
                 }
                 if (src2->grad) {
                     struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
@@ -18769,7 +18962,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src2->grad = ggml_add_or_set(ctx,
                             src2->grad,
                             grad_v,
-                            zero_table);
+                            zero_table, acc_table);
                 }
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
@@ -18795,7 +18988,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                             ggml_mul(ctx,
                                                 ggml_sgn(ctx, src0),
                                                 tensor->grad),
-                                            zero_table);
+                                            zero_table, acc_table);
                             }
                         } break;
                     case GGML_UNARY_OP_SGN:
@@ -18807,7 +19000,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     case GGML_UNARY_OP_NEG:
                         {
                             if (src0->grad) {
-                                src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
+                                src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
                             }
                         } break;
                     case GGML_UNARY_OP_STEP:
@@ -18832,7 +19025,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         ggml_mul(ctx,
                                             ggml_step(ctx, src0),
                                             tensor->grad),
-                                        zero_table);
+                                        zero_table, acc_table);
                             }
                         } break;
                     case GGML_UNARY_OP_SIGMOID:
@@ -18854,7 +19047,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src0->grad = ggml_add_or_set(ctx,
                                         src0->grad,
                                         ggml_silu_back(ctx, src0, tensor->grad),
-                                        zero_table);
+                                        zero_table, acc_table);
                             }
                         } break;
                     case GGML_UNARY_OP_EXP:
@@ -18863,7 +19056,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src0->grad = ggml_add_or_set(ctx,
                                         src0->grad,
                                         ggml_mul(ctx, tensor, tensor->grad),
-                                        zero_table);
+                                        zero_table, acc_table);
                             }
                         } break;
                     default:
@@ -18893,13 +19086,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     src0,
                                     src1,
                                     tensor->grad),
-                                zero_table);
+                                zero_table, acc_table);
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             {
                 GGML_ABORT("fatal error"); // not supported
             }
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                GGML_ABORT("fatal error"); // not supported
+            }
         case GGML_OP_NONE:
             {
                 // nop
@@ -18989,7 +19186,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
     ggml_build_forward_impl(cgraph, tensor, true);
 }
 
-void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
+void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) {
     GGML_ASSERT(gf->n_nodes > 0);
     GGML_ASSERT(gf->grads);
 
@@ -19005,21 +19202,35 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
-    // remember original gradients which start with zero values
+    // keep tables of original gradients for replacement/accumulation logic
     struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
+    struct ggml_hash_set acc_table  = ggml_hash_set_new(gf->size);
     for (int i = 0; i < gf->n_nodes; i++) {
-        if (gf->grads[i]) {
-            ggml_hash_insert(&zero_table, gf->grads[i]);
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->grad) {
+            {
+                const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
+                GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+                GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            }
+
+            // only gradients of trainable parameters should be accumulated
+            if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+                const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
+                GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
+                GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
+            }
         }
     }
 
     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];
 
-        // inplace operations to add gradients are not created by ggml_compute_backward
+        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
         // use allocator to automatically make inplace operations
         if (node->grad) {
-            ggml_compute_backward(ctx, node, &zero_table);
+            ggml_compute_backward(ctx, node, &zero_table, &acc_table);
         }
     }
 
@@ -19033,8 +19244,30 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
     }
 
     ggml_hash_set_free(&zero_table);
+    ggml_hash_set_free(&acc_table);
+}
+
+void ggml_build_opt_adamw(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * gf,
+        struct ggml_cgraph  * gb,
+        float                 alpha,
+        float                 beta1,
+        float                 beta2,
+        float                 eps,
+        float                 wd) {
+    for (int i = 0; i < gf->n_nodes; i++) {
+        struct ggml_tensor * node = gf->nodes[i];
+
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+            GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
+            ggml_build_forward_expand(gb, opt_step);
+        }
+    }
 }
 
+
 static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
     void * ptr = *p;
     ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
@@ -19162,10 +19395,28 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     GGML_ASSERT(cgraph->grads != NULL);
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        struct ggml_tensor * grad = cgraph->grads[i];
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        // initial gradients of loss should be 1, 0 otherwise
+        if (node->grad) {
+            if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+                GGML_ASSERT(node->grad->buffer);
+                GGML_ASSERT(node->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_is_scalar(node));
+
+                const float onef = 1.0f;
+                ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
+            } else {
+                ggml_set_zero(node->grad);
+            }
+        }
 
-        if (grad) {
-            ggml_set_zero(grad);
+        GGML_ASSERT(node);
+        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
+            // set iteration to 1 and clear momenta
+            ggml_set_op_params_i32(node, 0, 1);
+            ggml_set_zero(node->src[2]);
+            ggml_set_zero(node->src[3]);
         }
     }
 }
@@ -19458,6 +19709,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
             {
                 n_tasks = n_threads;
             } break;
@@ -21851,7 +22103,7 @@ enum ggml_opt_result ggml_opt_resume(
     ggml_build_forward_expand(gf, f);
 
     struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
-    ggml_build_backward_expand(ctx, gf, gb, true);
+    ggml_build_backward_expand(ctx, gf, gb, false, true);
 
     return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }
index aa7896defdad0721752e09f45d6cccb761ac5d54..889a199448a6d1ce77cdd3f58fd7e376a6a58e36 100644 (file)
@@ -799,10 +799,11 @@ struct test_case {
             out = ggml_sum(ctx, out);
             ggml_set_name(out, "sum_of_out");
         }
+        ggml_set_loss(out);
 
         ggml_build_forward_expand(gf, out);
         ggml_graph_cpy(gf, gb);
-        ggml_build_backward_expand(ctx, gf, gb, false);
+        ggml_build_backward_expand(ctx, gf, gb, false, false);
         if (expect.size() != 1 || expect[0] != 0.0f) {
             GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
             for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
@@ -837,22 +838,11 @@ struct test_case {
             return false;
         }
 
-        // randomize tensors
-        initialize_tensors(ctx);
-
-        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
-            if (!t->grad) {
-                continue;
-            }
 
-            std::vector<float> tmp(ggml_nelements(t->grad));
-            ggml_backend_tensor_set(t->grad, tmp.data(), 0, ggml_nbytes(t->grad));
-        }
+        initialize_tensors(ctx); // Randomizes all tensors (including gradients).
+        ggml_graph_reset(gb);    // Sets gradients to 1 if loss, 0 otherwise.
 
-        // build graphs
-        const float onef = 1.0f;
         ggml_backend_graph_compute(backend, gf);
-        ggml_backend_tensor_set(out->grad, &onef, 0, ggml_nbytes(out->grad));
         ggml_backend_graph_compute(backend, gb);
 
         bool ok = true;
@@ -1681,6 +1671,50 @@ struct test_mul_mat_id : public test_case {
     }
 };
 
+// GGML_OP_OUT_PROD
+struct test_out_prod : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array<int64_t, 2> bs; // dims 3 and 4
+    const bool trans_b;
+
+    std::string vars() override {
+        return VARS_TO_STR7(type_a, type_b, m, n, k, bs, trans_b);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array<int64_t, 2> bs = {10, 10},
+            bool trans_b = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), trans_b(trans_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b;
+        if (trans_b) {
+            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0], bs[1]);
+            b = ggml_transpose(ctx, b);
+        } else {
+            b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0], bs[1]);
+        }
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_out_prod(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_SQR
 struct test_sqr : public test_case {
     const ggml_type type;
@@ -2666,6 +2700,51 @@ struct test_cross_entropy_loss : public test_case {
     }
 };
 
+// GGML_OP_OPT_STEP_ADAMW
+struct test_opt_step_adamw : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float alpha;
+    const float beta1;
+    const float beta2;
+    const float eps;
+    const float wd;
+
+    std::string vars() override {
+        return VARS_TO_STR7(type, ne, alpha, beta1, beta2, eps, wd);
+    }
+
+    test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
+            float alpha = 1e-3f,
+            float beta1 = 0.9f,
+            float beta2 = 0.999f,
+            float eps = 1e-8f,
+            float wd = 0.0f)
+        : type(type), ne(ne), alpha(alpha), beta1(beta1), beta2(beta2), eps(eps), wd(wd) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, alpha, beta1, beta2, eps, wd);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 1.0f); // grad_v needs non-negative values.
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -3159,14 +3238,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, 3}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, 3}, {1, 1, 1, 2}));
+    for (int ne3 : {1, 3}) { // CUDA backwards pass only supports ne3 == 1
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 2, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 2}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
+    }
 
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
@@ -3350,6 +3430,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    for (ggml_type type_a : base_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, { 1,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10}));
+
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1,  1}, true));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10,  1}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+            test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10}));
+        }
+    }
+
     test_cases.emplace_back(new test_sqr());
     test_cases.emplace_back(new test_sqrt());
     test_cases.emplace_back(new test_log());
@@ -3476,6 +3577,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     test_cases.emplace_back(new test_cross_entropy_loss());
+    for (float wd : {0.0f, 1e-2f}) {
+        test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}, 1.0f, 1e-3f, 0.9f, 0.999f, wd));
+    }
 
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
index 1834c11d894b4ce4279715598fe35c0c8c1c83a2..2ef606d2c3591cbdd85d3ee120430e740875b2fc 100644 (file)
@@ -240,7 +240,7 @@ static bool check_gradient(
     struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
     ggml_build_forward_expand(gf, f);
     ggml_graph_cpy(gf, gb);
-    ggml_build_backward_expand(ctx0, gf, gb, false);
+    ggml_build_backward_expand(ctx0, gf, gb, false, false);
 
     ggml_graph_compute_with_ctx(ctx0, gf, n_threads);