]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sync : ggml
authorGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:01:45 +0000 (22:01 +0300)
committerGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:41:27 +0000 (22:41 +0300)
21 files changed:
ggml/include/ggml-backend.h
ggml/include/ggml.h
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/binbcast.cu
ggml/src/ggml-cuda/binbcast.cuh
ggml/src/ggml-cuda/cross-entropy-loss.cu [new file with mode: 0644]
ggml/src/ggml-cuda/cross-entropy-loss.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/sumrows.cu
ggml/src/ggml-cuda/sumrows.cuh
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh
ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal
ggml/src/ggml-quants.c
ggml/src/ggml-vulkan.cpp
ggml/src/ggml.c
ggml/src/vulkan-shaders/cos.comp [new file with mode: 0644]
ggml/src/vulkan-shaders/sin.comp [new file with mode: 0644]
scripts/sync-ggml.last
tests/test-backend-ops.cpp
tests/test-grad0.cpp

index 5f3f1e286990e478f96fcf1cbc21b58887d9b22a..e73b9a7452feda5b8a2e0c041e5f2244f81bb8fd 100644 (file)
@@ -63,6 +63,7 @@ extern "C" {
     GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 
+    // "offset" refers to the offset of the tensor data for setting/getting data
     GGML_API GGML_CALL void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 
index a7e9dc9b2ff6346e9b3a65c926d7ebc9e8c6fb07..b11d047aeda7d057f74f9b606e8d2f837e44a16c 100644 (file)
 #include <stdio.h>
 
 #define GGML_FILE_MAGIC   0x67676d6c // "ggml"
-#define GGML_FILE_VERSION 1
+#define GGML_FILE_VERSION 2
 
 #define GGML_QNT_VERSION        2    // bump this on quantization format changes
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
@@ -453,6 +453,8 @@ extern "C" {
         GGML_OP_SQR,
         GGML_OP_SQRT,
         GGML_OP_LOG,
+        GGML_OP_SIN,
+        GGML_OP_COS,
         GGML_OP_SUM,
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
@@ -490,9 +492,11 @@ extern "C" {
         GGML_OP_CLAMP,
         GGML_OP_CONV_TRANSPOSE_1D,
         GGML_OP_IM2COL,
+        GGML_OP_IM2COL_BACK,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
+        GGML_OP_POOL_2D_BACK,
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_ARANGE,
@@ -969,6 +973,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sin(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sin_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_cos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_cos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return scalar
     GGML_API struct ggml_tensor * ggml_sum(
             struct ggml_context * ctx,
@@ -1566,34 +1586,49 @@ extern "C" {
             float                 min,
             float                 max);
 
+    // im2col
+    // converts data into a format that effectively results in a convolution when combined with matrix multiplication
     GGML_API struct ggml_tensor * ggml_im2col(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                  s0,
-            int                  s1,
-            int                  p0,
-            int                  p1,
-            int                  d0,
-            int                  d1,
-            bool                 is_2D,
-            enum ggml_type       dst_type);
+            struct ggml_tensor  * a,  // convolution kernel
+            struct ggml_tensor  * b,  // data
+            int                   s0, // stride dimension 0
+            int                   s1, // stride dimension 1
+            int                   p0, // padding dimension 0
+            int                   p1, // padding dimension 1
+            int                   d0, // dilation dimension 0
+            int                   d1, // dilation dimension 1
+            bool                  is_2D,
+            enum ggml_type        dst_type);
+
+    GGML_API struct ggml_tensor * ggml_im2col_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,  // convolution kernel
+        struct ggml_tensor  * b,  // gradient of im2col output
+        int64_t             * ne, // shape of im2col input
+        int                   s0, // stride dimension 0
+        int                   s1, // stride dimension 1
+        int                   p0, // padding dimension 0
+        int                   p1, // padding dimension 1
+        int                   d0, // dilation dimension 0
+        int                   d1, // dilation dimension 1
+        bool                  is_2D);
 
     GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                  s0,
-            int                  s1,
-            int                  p0,
-            int                  p1,
-            int                  d0,
-            int                  d1);
+            struct ggml_tensor  * a,  // convolution kernel
+            struct ggml_tensor  * b,  // data
+            int                  s0,  // stride dimension 0
+            int                  s1,  // stride dimension 1
+            int                  p0,  // padding dimension 0
+            int                  p1,  // padding dimension 1
+            int                  d0,  // dilation dimension 0
+            int                  d1); // dilation dimension 1
 
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
             int                   s0,  // stride
             int                   p0,  // padding
             int                   d0); // dilation
@@ -1602,29 +1637,29 @@ extern "C" {
     // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
     GGML_API struct ggml_tensor* ggml_conv_1d_ph(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   s,
-            int                   d);
+            struct ggml_tensor  * a,  // convolution kernel
+            struct ggml_tensor  * b,  // data
+            int                   s,  // stride
+            int                   d); // dilation
 
     GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   s0,
-            int                   p0,
-            int                   d0);
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   p0,  // padding
+            int                   d0); // dilation
 
     GGML_API struct ggml_tensor * ggml_conv_2d(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            int                   s0,
-            int                   s1,
-            int                   p0,
-            int                   p1,
-            int                   d0,
-            int                   d1);
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride dimension 0
+            int                   s1,  // stride dimension 1
+            int                   p0,  // padding dimension 0
+            int                   p1,  // padding dimension 1
+            int                   d0,  // dilation dimension 0
+            int                   d1); // dilation dimension 1
 
 
     // kernel size is a->ne[0] x a->ne[1]
@@ -1686,6 +1721,18 @@ extern "C" {
             float                 p0,
             float                 p1);
 
+    GGML_API struct ggml_tensor * ggml_pool_2d_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * af, // "a"/input used in forward pass
+            enum ggml_op_pool     op,
+            int                   k0,
+            int                   k1,
+            int                   s0,
+            int                   s1,
+            float                 p0,
+            float                 p1);
+
     // nearest interpolate
     // multiplies ne0 and ne1 by scale factor
     // used in stable-diffusion
index 682c30d45bcf4323a4e2506ed4b2ecacf7d63e99..8a844b02a27a5a451fe694cc324864b568c93c7b 100644 (file)
@@ -9,8 +9,10 @@
 #include "ggml-cuda/binbcast.cuh"
 #include "ggml-cuda/clamp.cuh"
 #include "ggml-cuda/concat.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/convert.cuh"
 #include "ggml-cuda/cpy.cuh"
+#include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/dmmv.cuh"
 #include "ggml-cuda/fattn.cuh"
@@ -29,7 +31,6 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/conv-transpose-1d.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2181,6 +2182,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD:
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_SUB:
+            ggml_cuda_op_sub(ctx, dst);
+            break;
         case GGML_OP_ACC:
             ggml_cuda_op_acc(ctx, dst);
             break;
@@ -2267,6 +2271,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SQRT:
             ggml_cuda_op_sqrt(ctx, dst);
             break;
+        case GGML_OP_SIN:
+            ggml_cuda_op_sin(ctx, dst);
+            break;
+        case GGML_OP_COS:
+            ggml_cuda_op_cos(ctx, dst);
+            break;
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
@@ -2303,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cuda_flash_attn_ext(ctx, dst);
             break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            ggml_cuda_cross_entropy_loss(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2610,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
                 assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
                 for (int j = 0; j < GGML_MAX_SRC; j++) {
                     if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer);
                         assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
                     }
                 }
@@ -2853,12 +2867,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
@@ -2890,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            return true;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
index 34bc67acdd890c077ae15de763435bab09ff0f2c..e1390a0414559fca9bc0a7ae05fd394de3c65615 100644 (file)
@@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
     return a + b;
 }
 
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
 static __device__ __forceinline__ float op_mul(const float a, const float b) {
     return a * b;
 }
@@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
index 4f63d6372eb50e717f47f7fc4844d90719119f82..198c9ef6fd8ea73c3e9e85f5ef0e60676365ca91 100644 (file)
@@ -2,5 +2,6 @@
 
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu
new file mode 100644 (file)
index 0000000..a14043e
--- /dev/null
@@ -0,0 +1,106 @@
+#include "common.cuh"
+#include "cross-entropy-loss.cuh"
+#include "sumrows.cuh"
+
+#include <cmath>
+#include <cstdint>
+
+static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
+
+    const int ne_tmp = WARP_SIZE*nclasses;
+
+    extern __shared__ float tmp_all[];
+    float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
+    float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
+
+    // Each warp first loads ne_tmp logits/labels into shared memory:
+    for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
+        const int ig = i0*nclasses + i; // ig == i global
+
+        tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
+        tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
+    }
+
+    // Each thread in the warp then calculates the cross entropy loss for a single row.
+    // TODO: pad in order to avoid shared memory bank conflicts.
+
+    // Find maximum for softmax:
+    float max = -INFINITY;
+    for (int i = 0; i < nclasses; ++i) {
+        max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
+    }
+
+    // Calculate log(softmax(logits)) which is just logits - max:
+    float sum = 0.0f;
+    for (int i = 0; i < nclasses; ++i) {
+        float val = tmp_logits[lane_id*nclasses + i] - max;
+        sum += expf(val);
+        tmp_logits[lane_id*nclasses + i] = val;
+    }
+    sum = logf(sum);
+
+    // log(exp(logits - max) / sum) = (logits - max) - log(sum)
+    float loss = 0.0f;
+    for (int i = 0; i < nclasses; ++i) {
+        loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
+    }
+    loss = -warp_reduce_sum(loss) / (float)k;
+
+    __syncthreads();
+
+    if (lane_id == 0) {
+        tmp_all[warp_id] = loss;
+    }
+
+    __syncthreads();
+
+    if (warp_id != 0) {
+        return;
+    }
+
+    loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
+    loss = warp_reduce_sum(loss);
+
+    if (lane_id != 0) {
+        return;
+    }
+
+    dst[blockIdx.x] = loss;
+}
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
+    const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
+    const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
+
+    ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
+
+    cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+
+    // Combine results from individual blocks:
+    sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
+}
diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cuh b/ggml/src/ggml-cuda/cross-entropy-loss.cuh
new file mode 100644 (file)
index 0000000..9d7b8b0
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 82e8e875f9be3b58cb8f25cd54a513b61148c0d1..38dbf1b5e1fa9d9a484f55848ac9b21343d029e1 100644 (file)
@@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
     }
 }
 
-static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
     k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
@@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
index e7545f83c496bbfbebf1a5bed3ed697c775f2e95..191db1c13167e4d6e9ac3acfff5b919577c0957b 100644 (file)
@@ -1,3 +1,5 @@
 #include "common.cuh"
 
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index f9e208011e2a8f2b1387d2c2173e8056bf0ad068..89abfc21d8a56c85bb191fd75270cf713a4a7b64 100644 (file)
@@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
     dst[i] = sqrtf(x[i]);
 }
 
+static __global__ void sin_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sinf(x[i]);
+}
+
+static __global__ void cos_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = cosf(x[i]);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
+    sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
+    cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
index 4cfb0479e7169a83a8b433b21b9b381aeb9158d9..c610e996abeb62a2fdcc2ea424774513f7e5c443 100644 (file)
@@ -9,6 +9,8 @@
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_SIN_BLOCK_SIZE 256
+#define CUDA_COS_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 936751800518b279bed0c3072cf152f887a248c9..91b5e61b23eadfed8d349a5cd0114c39f07f5285 100644 (file)
@@ -31,6 +31,8 @@ struct ggml_metal_kernel {
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
     GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW,
     GGML_METAL_KERNEL_TYPE_MUL,
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
@@ -207,6 +209,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SQRT,
+    GGML_METAL_KERNEL_TYPE_SIN,
+    GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
 
     GGML_METAL_KERNEL_TYPE_COUNT
@@ -493,6 +498,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                           sub,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                       sub_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
@@ -669,6 +676,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
     }
 
@@ -769,15 +779,20 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
+            return true;
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
-            return true;
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
@@ -1057,6 +1072,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
                 case GGML_OP_ADD:
+                case GGML_OP_SUB:
                 case GGML_OP_MUL:
                 case GGML_OP_DIV:
                     {
@@ -1080,6 +1096,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             nb = ne00 / 4;
                             switch (dst->op) {
                                 case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                                case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
                                 case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
                                 case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
                                 default: GGML_ABORT("fatal error");
@@ -1089,6 +1106,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         } else {
                             switch (dst->op) {
                                 case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                                case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
                                 case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
                                 case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
                                 default: GGML_ABORT("fatal error");
@@ -1416,6 +1434,48 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_SQRT:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_SIN:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_COS:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_SUM_ROWS:
index 755970f31ce2969f16ad10eed71d02e747ebe9b7..f323ab5f447d5497259405f9e3eb5cb827f4aedd 100644 (file)
@@ -17,7 +17,7 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, multiplication and division of two tensors
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
@@ -70,6 +70,56 @@ kernel void kernel_add(
     }
 }
 
+kernel void kernel_sub(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        constant  int64_t & offs,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i10 = i0 % ne10;
+        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    }
+}
+
 kernel void kernel_mul(
         device const char * src0,
         device const char * src1,
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
+kernel void kernel_sub_row(
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   uint64_t & nb [[buffer(28)]],
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
 kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
@@ -358,6 +417,27 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
+kernel void kernel_sin(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = cos(src0[tpig]);
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,
index d5b91c2dbc0c17ce3a6691d7274c2f5dd794c26b..48b90f01b5a0a5533a0d203629d0764db8a17dac 100644 (file)
@@ -3644,7 +3644,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
     quantize_row_q8_K_ref(x, y, k);
 }
 
-//===================================== Dot ptoducts =================================
+//===================================== Dot products =================================
 
 //
 // Helper functions
index 32fda32a879ba9d45eb29bbe7e7f8000634e6b2c..ca4f44cf75615afdd710c6c077fdd89b758df9e2 100644 (file)
@@ -188,6 +188,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
+    vk_pipeline pipeline_sin_f32;
+    vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_repeat_f32;
@@ -1702,6 +1704,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
@@ -4023,6 +4027,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_sqr_f32;
         }
         return nullptr;
+    case GGML_OP_SIN:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_sin_f32;
+        }
+        return nullptr;
+    case GGML_OP_COS:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_cos_f32;
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -4171,6 +4185,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_REPEAT:
@@ -4381,6 +4397,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_REPEAT:
@@ -4598,6 +4616,32 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
+}
+
+static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
@@ -5658,6 +5702,8 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -5735,6 +5781,14 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SQR:
         ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SIN:
+        ggml_vk_sin(ctx, compute_ctx, src0, node);
+
+        break;
+    case GGML_OP_COS:
+        ggml_vk_cos(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
@@ -5851,6 +5905,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -6582,6 +6638,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_UPSCALE:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_PAD:
         case GGML_OP_CONT:
@@ -7024,6 +7082,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
+    } else if (tensor->op == GGML_OP_SIN) {
+        tensor_clone = ggml_sin(ggml_ctx, src0_clone);
+    } else if (tensor->op == GGML_OP_COS) {
+        tensor_clone = ggml_cos(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_CLAMP) {
         tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_PAD) {
index e52471ce3f861d8c6c986f8c8b4ac47a5d829342..9c105fd353de4c12ab189fc7f2601d7e3ae8238d 100644 (file)
@@ -2310,7 +2310,9 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
-inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);   }
+inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
+inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
+inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
@@ -2669,6 +2671,19 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
     return sum;
 }
 
+static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
+    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
+
+    int i = 0;
+    ggml_float sum = 0;
+    for (; i < n; ++i) {
+        float val = x[i] - max;
+        y[i] = val;
+        sum += (ggml_float)expf(val);
+    }
+    return sum = (ggml_float)logf(sum);
+}
+
 inline static float ggml_silu_backward_f32(float x, float dy) {
     const float s = 1.0f/(1.0f + expf(-x));
     return dy*s*(1.0f + x*(1.0f - s));
@@ -2760,6 +2775,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SQR",
     "SQRT",
     "LOG",
+    "SIN",
+    "COS",
     "SUM",
     "SUM_ROWS",
     "MEAN",
@@ -2797,9 +2814,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CLAMP",
     "CONV_TRANSPOSE_1D",
     "IM2COL",
+    "IM2COL_BACK",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
+    "POOL_2D_BACK",
     "UPSCALE",
     "PAD",
     "ARANGE",
@@ -2833,7 +2852,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2848,6 +2867,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "x^2",
     "√x",
     "log(x)",
+    "sin(x)",
+    "cos(x)",
     "Σx",
     "Σx_k",
     "Σx/n",
@@ -2885,9 +2906,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "clamp(x)",
     "conv_transpose_1d(x)",
     "im2col(x)",
+    "im2col_back(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
+    "pool_2d_back(x)",
     "upscale(x)",
     "pad(x)",
     "arange(start, stop, step)",
@@ -2921,7 +2944,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3767,6 +3790,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
     }
 
     struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
+    GGML_ASSERT(obj_new);
 
     // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
 
@@ -4486,8 +4510,6 @@ static struct ggml_tensor * ggml_add_impl(
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
-        // TODO: support backward pass for broadcasting
-        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
@@ -4661,11 +4683,13 @@ static struct ggml_tensor * ggml_sub_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
@@ -4880,6 +4904,72 @@ struct ggml_tensor * ggml_log_inplace(
     return ggml_log_impl(ctx, a, true);
 }
 
+// ggml_sin
+
+static struct ggml_tensor * ggml_sin_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SIN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sin(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sin_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sin_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sin_impl(ctx, a, true);
+}
+
+// ggml_cos
+
+static struct ggml_tensor * ggml_cos_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_COS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cos_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cos_impl(ctx, a, true);
+}
+
 // ggml_sum
 
 struct ggml_tensor * ggml_sum(
@@ -6727,17 +6817,20 @@ struct ggml_tensor * ggml_im2col(
         GGML_ASSERT(a->ne[2] == b->ne[2]);
     } else {
         GGML_ASSERT(a->ne[1] == b->ne[1]);
+        GGML_ASSERT(b->ne[3] == 1);
     }
     bool is_node = false;
 
-    if (a->grad || b->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
+    if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
         is_node = true;
     }
 
     const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
     const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
 
+    GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
+    GGML_ASSERT((OW > 0)           && "b too small compared to a");
+
     const int64_t ne[4] = {
         is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
         OW,
@@ -6757,6 +6850,37 @@ struct ggml_tensor * ggml_im2col(
     return result;
 }
 
+struct ggml_tensor * ggml_im2col_back(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    struct ggml_tensor  * b,
+    int64_t             * ne,
+    int                   s0,
+    int                   s1,
+    int                   p0,
+    int                   p1,
+    int                   d0,
+    int                   d1,
+    bool                  is_2D) {
+
+    bool is_node = false;
+
+    if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op = GGML_OP_IM2COL_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OC, OH, OW]
@@ -6770,7 +6894,7 @@ struct ggml_tensor * ggml_conv_2d(
         int                  p1,
         int                  d0,
         int                  d1) {
-    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
 
     struct ggml_tensor * result =
         ggml_mul_mat(ctx,
@@ -6896,17 +7020,17 @@ struct ggml_tensor * ggml_pool_2d(
     bool is_node = false;
 
     if (a->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
         is_node = true;
     }
 
     struct ggml_tensor * result;
-    const int64_t ne[3] = {
+    const int64_t ne[4] = {
         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
         ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
         a->ne[2],
+        a->ne[3],
     };
-    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
     ggml_set_op_params(result, params, sizeof(params));
@@ -6917,6 +7041,37 @@ struct ggml_tensor * ggml_pool_2d(
     return result;
 }
 
+struct ggml_tensor * ggml_pool_2d_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * af,
+        enum ggml_op_pool     op,
+        int                   k0,
+        int                   k1,
+        int                   s0,
+        int                   s1,
+        float                 p0,
+        float                 p1) {
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result;
+    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne);
+
+    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op = GGML_OP_POOL_2D_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = af;
+    return result;
+}
+
 // ggml_upscale
 
 static struct ggml_tensor * ggml_upscale_impl(
@@ -10098,11 +10253,10 @@ static void ggml_compute_forward_sub_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->ith != 0) {
-        return;
-    }
+    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    const int ith = params->ith;
+    const int nth = params->nth;
 
     const int nr  = ggml_nrows(src0);
 
@@ -10111,40 +10265,55 @@ static void ggml_compute_forward_sub_f32(
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
 
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (nb10 == sizeof(float)) {
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
 
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            vDSP_vsub(
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+                vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_sub_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+                ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-                // }
-            // }
+            }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
             }
@@ -10490,6 +10659,96 @@ static void ggml_compute_forward_log(
     }
 }
 
+// ggml_compute_forward_sin
+
+static void ggml_compute_forward_sin_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sin_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sin(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sin_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cos
+
+static void ggml_compute_forward_cos_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_cos_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_cos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cos_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_sum
 
 static void ggml_compute_forward_sum_f32(
@@ -14525,6 +14784,7 @@ static void ggml_compute_forward_conv_transpose_1d(
     }
 }
 
+// ggml_compute_forward_im2col_f32
 // src0: kernel [OC, IC, KH, KW]
 // src1: image [N, IC, IH, IW]
 // dst:  result [N, OH, OW, IC*KH*KW]
@@ -14535,7 +14795,6 @@ static void ggml_compute_forward_im2col_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -14566,7 +14825,6 @@ static void ggml_compute_forward_im2col_f32(
     int ofs0 = is_2D ? nb13 : nb12;
     int ofs1 = is_2D ? nb12 : nb11;
 
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
@@ -14602,6 +14860,7 @@ static void ggml_compute_forward_im2col_f32(
 }
 
 
+// ggml_compute_forward_im2col_f16
 // src0: kernel [OC, IC, KH, KW]
 // src1: image [N, IC, IH, IW]
 // dst:  result [N, OH, OW, IC*KH*KW]
@@ -14697,6 +14956,99 @@ static void ggml_compute_forward_im2col(
     }
 }
 
+// ggml_compute_forward_im2col_back_f32
+
+static void ggml_compute_forward_im2col_back_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne3 : ne2;
+    const int64_t IC = is_2D ? ne2 : ne1;
+    const int64_t IH = is_2D ? ne1 : 1;
+    const int64_t IW = ne0;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne12 : 1;
+    const int64_t OW = ne11;
+
+    int ofs0 = is_2D ? nb3 : nb2;
+    int ofs1 = is_2D ? nb2 : nb1;
+
+    GGML_ASSERT(nb0  == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iic = ith; iic < IC; iic += nth) {
+                for (int64_t iih = 0; iih < IH; iih++) {
+                    for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+                        // micro kernel
+                        float grad = 0.0f;
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                // For s0 > 1 some values were skipped over in the forward pass.
+                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+                                const int64_t tmpw = (iiw + p0 - ikw*d0);
+                                if (tmpw % s0 != 0) {
+                                    continue;
+                                }
+                                const int64_t iow = tmpw / s0;
+
+                                // Equivalent logic as above except for s1.
+                                int64_t ioh;
+                                if (is_2D) {
+                                    const int64_t tmph = iih + p1 - ikh*d1;
+
+                                    if (tmph % s1 != 0) {
+                                        continue;
+                                    }
+
+                                    ioh = tmph / s1;
+                                } else {
+                                    ioh = 0;
+                                }
+
+                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+                                    continue;
+                                }
+
+                                const float * const src_data = (const float *) src1->data
+                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                            }
+                        }
+                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+                        dst_data[iih*IW + iiw] = grad;
+                    }
+                }
+            }
+        }
+    }
+}
 
 // ggml_compute_forward_conv_transpose_2d
 
@@ -14939,6 +15291,128 @@ static void ggml_compute_forward_pool_2d(
     }
 }
 
+// ggml_compute_forward_pool_2d_back
+
+static void ggml_compute_forward_pool_2d_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src  = dst->src[0];
+    const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
+
+    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    char       * cdata  = (char       *) dst->data;
+    const char * cdataf = (const char *) dstf->data;
+    const char * const data_end = cdata + ggml_nbytes(dst);
+
+    GGML_ASSERT(params->ith == 0);
+    memset(cdata, 0, ggml_nbytes(dst));
+
+    const int64_t px = src->ne[0];
+    const int64_t py = src->ne[1];
+    const int64_t pa = px * py;
+
+    const float * splane = (const float *) src->data;
+
+    const int ka = k0 * k1;
+    const int offset0 = -p0;
+    const int offset1 = -p1;
+
+    while (cdata < data_end) {
+        for (int oy = 0; oy < py; ++oy) {
+            const float * const srow = splane + oy * px;
+            for (int ox = 0; ox < px; ++ox) {
+                const float grad0 = srow[ox];
+
+                const int ix = offset0 + ox * s0;
+                const int iy = offset1 + oy * s1;
+
+                if (op == GGML_OP_POOL_MAX) {
+                    float maxval = -FLT_MAX;
+                    int kxmax = -1;
+                    int kymax = -1;
+
+                    for (int ky = 0; ky < k1; ++ky) {
+                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
+                            continue;
+                        }
+                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
+                        for (int kx = 0; kx < k0; ++kx) {
+                            int j = ix + kx;
+                            if (j < 0 || j >= dst->ne[0]) {
+                                continue;
+                            }
+
+                            const float val = dst->type == GGML_TYPE_F32 ?
+                                ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
+                            if (val <= maxval) {
+                                continue;
+                            }
+
+                            maxval = val;
+                            kxmax = kx;
+                            kymax = ky;
+                        }
+                    }
+
+                    if (kxmax == -1 || kymax == -1) {
+                        continue;
+                    }
+
+                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
+                    const int j = ix + kxmax;
+                    if (dst->type == GGML_TYPE_F32) {
+                        ((float *) drow)[j] += grad0;
+                    } else {
+                        ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
+                    }
+                } else if (op == GGML_OP_POOL_AVG) {
+                    const float grad = grad0 / ka;
+
+                    for (int ky = 0; ky < k1; ++ky) {
+                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
+                            continue;
+                        }
+                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
+                        for (int kx = 0; kx < k0; ++kx) {
+                            int j = ix + kx;
+                            if (j < 0 || j >= dst->ne[0]) {
+                                continue;
+                            }
+
+                            if (dst->type == GGML_TYPE_F32) {
+                                ((float *) drow)[j] += grad;
+                            } else {
+                                ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
+                            }
+                        }
+                    }
+                } else {
+                    GGML_ASSERT(false);
+                }
+            }
+        }
+
+        cdata  += dst->nb[2];
+        cdataf += dst->nb[2];
+        splane += pa;
+    }
+}
+
 // ggml_compute_forward_upscale
 
 static void ggml_compute_forward_upscale_f32(
@@ -16481,8 +16955,6 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     }
     ggml_barrier(params->shared);
 
-    const double eps = 1e-9;
-
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -16503,20 +16975,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
         }
 #endif
 
-        // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
-        assert(sum > 0.0);
-        sum = (1.0 - eps) / sum;
+        ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum >= 0.0);
 
-        // avoid log(0) by rescaling from [0..1] to [eps..1]
-        ggml_vec_scale_f32(nc, st, sum);
-        ggml_vec_add1_f32(nc, st, st, eps);
-        ggml_vec_log_f32(nc, st, st);
+        ggml_vec_add1_f32(nc, st, st, -sum);
         ggml_vec_mul_f32(nc, st, st, s1);
 
-        float st_sum = 0;
+        float st_sum = 0.0f;
         ggml_vec_sum_f32(nc, &st_sum, st);
         sums[ith] += st_sum;
 
@@ -16573,8 +17040,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
-    const double eps = 1e-9;
-
     // TODO: handle transposed/permuted matrices
     const int64_t nc = src0->ne[0];
     const int64_t nr = ggml_nrows(src0);
@@ -16606,11 +17071,9 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         ggml_vec_max_f32(nc, &max, s0);
         ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
-        sum = (1.0 - eps) / sum;
+        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
         // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
-        ggml_vec_scale_f32(nc, ds0, sum);
-        ggml_vec_add1_f32(nc, ds0, ds0, eps);
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
 
@@ -16691,6 +17154,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_log(params, tensor);
             } break;
+        case GGML_OP_SIN:
+            {
+                ggml_compute_forward_sin(params, tensor);
+            } break;
+        case GGML_OP_COS:
+            {
+                ggml_compute_forward_cos(params, tensor);
+            } break;
         case GGML_OP_SUM:
             {
                 ggml_compute_forward_sum(params, tensor);
@@ -16831,6 +17302,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_im2col(params, tensor);
             } break;
+        case GGML_OP_IM2COL_BACK:
+            {
+                ggml_compute_forward_im2col_back_f32(params, tensor);
+            } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
                 ggml_compute_forward_conv_transpose_2d(params, tensor);
@@ -16843,6 +17318,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pool_2d(params, tensor);
             } break;
+        case GGML_OP_POOL_2D_BACK:
+            {
+                ggml_compute_forward_pool_2d_back(params, tensor);
+            } break;
         case GGML_OP_UPSCALE:
             {
                 ggml_compute_forward_upscale(params, tensor);
@@ -17211,7 +17690,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
+                    if (ggml_are_same_shape(src0, src1)) {
+                        src1->grad = ggml_add_or_set(ctx, src1->grad,                       tensor->grad,        zero_table);
+                    } else {
+                        src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
+                    }
                 }
             } break;
         case GGML_OP_ADD1:
@@ -17337,6 +17820,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 zero_table);
                 }
             } break;
+        case GGML_OP_SIN:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_or_set(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_cos(ctx, src0)),
+                                zero_table);
+                }
+            } break;
+        case GGML_OP_COS:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_sub_or_set(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_sin(ctx, src0)),
+                                zero_table);
+                }
+            } break;
         case GGML_OP_SUM:
             {
                 if (src0->grad) {
@@ -17784,6 +18291,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
         case GGML_OP_IM2COL:
+            {
+                if (src1->grad) {
+                    const int32_t s0    = ggml_get_op_params_i32(tensor, 0);
+                    const int32_t s1    = ggml_get_op_params_i32(tensor, 1);
+                    const int32_t p0    = ggml_get_op_params_i32(tensor, 2);
+                    const int32_t p1    = ggml_get_op_params_i32(tensor, 3);
+                    const int32_t d0    = ggml_get_op_params_i32(tensor, 4);
+                    const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
+                    const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
+
+                    src1->grad = ggml_add_or_set(ctx,
+                            src1->grad,
+                            ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
+                            zero_table);
+                }
+            } break;
+        case GGML_OP_IM2COL_BACK:
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
@@ -17796,6 +18320,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
         case GGML_OP_POOL_2D:
+            {
+                if (src0->grad) {
+                    const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
+                    const      int32_t      k0 = ggml_get_op_params_i32(tensor, 1);
+                    const      int32_t      k1 = ggml_get_op_params_i32(tensor, 2);
+                    const      int32_t      s0 = ggml_get_op_params_i32(tensor, 3);
+                    const      int32_t      s1 = ggml_get_op_params_i32(tensor, 4);
+                    const      int32_t      p0 = ggml_get_op_params_i32(tensor, 5);
+                    const      int32_t      p1 = ggml_get_op_params_i32(tensor, 6);
+
+                    src0->grad = ggml_add_or_set(ctx,
+                            src0->grad,
+                            ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
+                            zero_table);
+                }
+            } break;
+        case GGML_OP_POOL_2D_BACK:
             {
                 GGML_ABORT("fatal error"); // TODO: not implemented
             }
@@ -18085,6 +18626,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
 
 void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
     GGML_ASSERT(gf->n_nodes > 0);
+    GGML_ASSERT(gf->grads);
 
     // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
     if (keep) {
@@ -18424,6 +18966,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_LOG:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
@@ -18510,6 +19054,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
             } break;
         case GGML_OP_IM2COL:
+        case GGML_OP_IM2COL_BACK:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
@@ -18517,6 +19062,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_POOL_1D:
         case GGML_OP_POOL_2D:
+        case GGML_OP_POOL_2D_BACK:
             {
                 n_tasks = 1;
             } break;
@@ -19030,9 +19576,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
 
                 const uint32_t type   = tensor->type;
                 const uint32_t op     = tensor->op;
+                const int32_t  flags  = tensor->flags;
 
                 fwrite(&type,   sizeof(uint32_t), 1, fout);
                 fwrite(&op,     sizeof(uint32_t), 1, fout);
+                fwrite(&flags,  sizeof(int32_t),  1, fout);
 
                 for (int j = 0; j < GGML_MAX_DIMS; ++j) {
                     const uint64_t ne = tensor->ne[j];
@@ -19062,9 +19610,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
 
                 const uint32_t type   = tensor->type;
                 const uint32_t op     = tensor->op;
+                const int32_t  flags  = tensor->flags;
 
                 fwrite(&type,   sizeof(uint32_t), 1, fout);
                 fwrite(&op,     sizeof(uint32_t), 1, fout);
+                fwrite(&flags,  sizeof(int32_t),  1, fout);
 
                 for (int j = 0; j < GGML_MAX_DIMS; ++j) {
                     const uint64_t ne = tensor->ne[j];
@@ -19123,6 +19673,14 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
                         }
                     }
                 }
+
+                // dump the data
+                // TODO: pad this to 32 byte boundary
+                if ((flags & GGML_TENSOR_FLAG_PARAM)) {
+                    const size_t size = ggml_nbytes(tensor);
+
+                    fwrite(tensor->data, sizeof(char), size, fout);
+                }
             }
         }
 
@@ -19236,10 +19794,12 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context *
         {
             uint32_t type;
             uint32_t op;
+            int32_t  flags;
 
             for (uint32_t i = 0; i < n_leafs; ++i) {
                 type   = *(const uint32_t *) ptr; ptr += sizeof(type);
                 op     = *(const uint32_t *) ptr; ptr += sizeof(op);
+                flags  = *(const int32_t  *) ptr; ptr += sizeof(flags);
 
                 int64_t ne[GGML_MAX_DIMS];
                 size_t  nb[GGML_MAX_DIMS];
@@ -19257,20 +19817,19 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context *
 
                 struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne);
 
-                tensor->op = (enum ggml_op) op;
+                tensor->op    = (enum ggml_op) op;
+                tensor->flags = flags;
 
                 memcpy(tensor->name,      ptr, GGML_MAX_NAME);      ptr += GGML_MAX_NAME;
                 memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS;
 
-                tensor->data = (void *) ptr;
-
                 for (int j = 0; j < GGML_MAX_DIMS; ++j) {
                     tensor->nb[j] = nb[j];
                 }
 
-                result->leafs[i] = tensor;
+                tensor->data = (void *) ptr; ptr += ggml_nbytes(tensor);
 
-                ptr += ggml_nbytes(tensor);
+                result->leafs[i] = tensor;
 
                 fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor));
             }
@@ -19282,10 +19841,12 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context *
         {
             uint32_t type;
             uint32_t op;
+            int32_t  flags;
 
             for (uint32_t i = 0; i < n_nodes; ++i) {
                 type   = *(const uint32_t *) ptr; ptr += sizeof(type);
                 op     = *(const uint32_t *) ptr; ptr += sizeof(op);
+                flags  = *(const int32_t  *) ptr; ptr += sizeof(flags);
 
                 enum ggml_op eop = (enum ggml_op) op;
 
@@ -19375,6 +19936,11 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context *
 
                 result->nodes[i] = tensor;
 
+                // TODO tensor data is be duplicated due to ggml_new_tensor call above
+                if (flags & GGML_TENSOR_FLAG_PARAM) {
+                    tensor->data = (void *) ptr; ptr += ggml_nbytes(tensor);
+                }
+
                 fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor));
             }
         }
@@ -19643,6 +20209,7 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_opt_callback callback,
         void * callback_data) {
     GGML_ASSERT(ggml_is_scalar(f));
+    GGML_ASSERT(f->type == GGML_TYPE_F32);
 
     // these will store the parameters we want to optimize
     struct ggml_tensor * ps[GGML_MAX_PARAMS];
@@ -20409,6 +20976,8 @@ enum ggml_opt_result ggml_opt(
         struct ggml_context * ctx,
         struct ggml_opt_params params,
         struct ggml_tensor * f) {
+    GGML_ASSERT(f->grad && "ggml_set_param called for at least one parent tensor.");
+
     bool free_ctx = false;
     if (ctx == NULL) {
         struct ggml_init_params params_ctx = {
@@ -20463,6 +21032,8 @@ enum ggml_opt_result ggml_opt_resume_g(
         ggml_opt_callback callback,
         void * callback_data) {
 
+    GGML_ASSERT(f->grad && "ggml_set_param must be called for at least one ancestor");
+
     // build forward + backward compute graphs
     enum ggml_opt_result result = GGML_OPT_RESULT_OK;
 
@@ -21550,6 +22121,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
 void gguf_add_tensor(
              struct gguf_context * ctx,
         const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
     if (gguf_find_tensor(ctx, tensor->name) != -1) {
         GGML_ABORT("duplicated tensor name");
     }
diff --git a/ggml/src/vulkan-shaders/cos.comp b/ggml/src/vulkan-shaders/cos.comp
new file mode 100644 (file)
index 0000000..f9a858c
--- /dev/null
@@ -0,0 +1,15 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
+}
diff --git a/ggml/src/vulkan-shaders/sin.comp b/ggml/src/vulkan-shaders/sin.comp
new file mode 100644 (file)
index 0000000..7faf9be
--- /dev/null
@@ -0,0 +1,15 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
+}
index eef6768b149db44000d577bfdb3080b67ee606ff..1e6db754fe68a7f8a63f12d84af386c39797ba7e 100644 (file)
@@ -1 +1 @@
-797faa25af14126eb30134d4033139ae3c5428ed
+28b7633d733bbeef0026570fbc61c79c5e9aa5ae
index 3955ef3323f5eff688243690bd3cf2f009bf9d5c..c832bc9569bbf160f29bb7cadf27da4d22b3a062 100644 (file)
@@ -1160,6 +1160,58 @@ struct test_sqrt : public test_case {
     }
 };
 
+// GGML_OP_SIN
+struct test_sin : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sin(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sin(ctx, a);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -100.0f, 100.0f);
+        }
+    }
+};
+
+// GGML_OP_COS
+struct test_cos : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cos(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_cos(ctx, a);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -100.0f, 100.0f);
+        }
+    }
+};
+
 // GGML_OP_CLAMP
 struct test_clamp : public test_case {
     const ggml_type type;
@@ -1731,6 +1783,27 @@ struct test_flash_attn_ext : public test_case {
     }
 };
 
+// GGML_OP_CROSS_ENTROPY_LOSS
+struct test_cross_entropy_loss : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels);
+        return out;
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -2393,6 +2466,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_sqr());
     test_cases.emplace_back(new test_sqrt());
+    test_cases.emplace_back(new test_sin());
+    test_cases.emplace_back(new test_cos());
     test_cases.emplace_back(new test_clamp());
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10,  1,  1}, 5));
@@ -2512,6 +2587,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    test_cases.emplace_back(new test_cross_entropy_loss());
+
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
     test_cases.emplace_back(new test_llama(1));
index a353276459b2d144abd609b318d7e60558fbf4c7..1834c11d894b4ce4279715598fe35c0c8c1c83a2 100644 (file)
@@ -1,10 +1,14 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #include "ggml.h"
 
+#include <cfloat>
 #include <cmath>
+#include <cstdint>
 #include <cstdio>
 #include <cstdlib>
 #include <cassert>
+#include <initializer_list>
+#include <vector>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -217,7 +221,8 @@ static bool check_gradient(
         int nargs,
         float eps,
         float max_error_abs,
-        float max_error_rel) {
+        float max_error_rel,
+        std::vector<double> expected_vals) {
 
     static int n_threads = -1;
     if (n_threads < 0) {
@@ -248,9 +253,10 @@ static bool check_gradient(
     // ggml_graph_dump_dot(gb, gf,  "test-grad0-backward.dot");
 
     for (int i = 0; i < nargs; ++i) {
+        bool all_g0_bad = true;
         const int nelements = ggml_nelements(x[i]);
         for (int k = 0; k < nelements; ++k) {
-            // compute gradient using finite differences
+            // Calculate gradient numerically:
             const float x0 = ggml_get_f32_1d(x[i], k);
             const float xm = x0 - eps;
             const float xp = x0 + eps;
@@ -267,6 +273,28 @@ static bool check_gradient(
             const double f1 = ggml_get_f32_1d(f, 0);
             const double g0 = (f0 - f1)/(2.0*(double) eps);
 
+            // The numerical calculation of the gradient fails around noncontinuities (e.g. 0 for ReLU).
+            // In such cases, provide a vector of expected values and skip the comparison for failed calculations.
+            if (!expected_vals.empty()) {
+                bool matches_any = false;
+                for (const double & ev : expected_vals) {
+                    const double error_abs = std::fabs(g0 - ev);
+                    if (error_abs > max_error_abs) {
+                        continue;
+                    }
+                    const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0;
+                    if (error_rel > max_error_rel) {
+                        continue;
+                    }
+                    matches_any = true;
+                    break;
+                }
+                if (!matches_any) {
+                    continue;
+                }
+            }
+            all_g0_bad = false;
+
             ggml_set_f32_1d(x[i], k, x0);
 
             // compute gradient using backward graph
@@ -278,7 +306,7 @@ static bool check_gradient(
             const double g1 = ggml_get_f32_1d(x[i]->grad, k);
 
             const double error_abs = fabs(g0 - g1);
-            const double error_rel = g0 != 0 ? fabs(g0 - g1)/fabs(g0) : 0;
+            const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0;
 
             if (error_abs > max_error_abs || error_rel > max_error_rel) {
                 printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
@@ -287,6 +315,10 @@ static bool check_gradient(
                 return false;
             }
         }
+        if (all_g0_bad) {
+            printf("%s: numerical calculation of the gradient failed for all values\n", op_name);
+            return false;
+        }
     }
 
     return true;
@@ -404,7 +436,7 @@ int main(int argc, const char ** argv) {
         seed_iter = rand();
         unsigned seed = rand();
 
-        printf("test-grad0: iter:%d/%d\n", iter, niter);
+        printf("test-grad0: iter:%d/%d\n", (iter+1), niter);
         struct ggml_context * ctx0 = ggml_init(params);
 
         get_random_dims(ne, 4);
@@ -424,7 +456,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
 
-                check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
+                check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {});
             }
         }
 
@@ -441,7 +473,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
 
-                check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
+                check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {});
             }
         }
 
@@ -458,7 +490,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
 
-                check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
             }
         }
 
@@ -475,7 +507,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
 
-                check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -492,7 +524,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
 
-                check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f);
+                check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {});
             }
         }
 
@@ -509,7 +541,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
 
-                check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -526,7 +558,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
 
-                check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f);
+                check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {});
             }
         }
 
@@ -543,7 +575,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0]));
 
-                check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
+                check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {});
             }
         }
 
@@ -560,7 +592,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
 
-                check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
             }
         }
 
@@ -578,7 +610,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0])));
 
-                check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+                check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
             }
         }
 
@@ -596,7 +628,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
 
-                check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
             }
         }
 
@@ -614,7 +646,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
 
-                check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
             }
         }
 
@@ -637,7 +669,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
 
-                check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+                check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
             }
         }
 
@@ -660,25 +692,25 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
 
-                check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+                check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
             }
         }
 
-        // abs (finite differences do not work)
-        //{
-        //    const int nargs = 1;
+        // abs
+        {
+           const int nargs = 1;
 
-        //    for (int ndims = 1; ndims <= 2; ++ndims) {
-        //        for (int i = 0; i < nargs; ++i) {
-        //            x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-        //            ggml_set_param(ctx0, x[i]);
-        //        }
+           for (int ndims = 1; ndims <= 4; ++ndims) {
+               for (int i = 0; i < nargs; ++i) {
+                   x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                   ggml_set_param(ctx0, x[i]);
+               }
 
-        //        struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
+               struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
 
-        //        check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f);
-        //    }
-        //}
+               check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0});
+           }
+        }
 
         // sgn
         {
@@ -693,7 +725,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
 
-                check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
             }
         }
 
@@ -710,7 +742,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
 
-                check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
             }
         }
 
@@ -727,7 +759,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
 
-                check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
             }
         }
 
@@ -745,7 +777,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
 
-                check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
             }
         }
 
@@ -776,7 +808,7 @@ int main(int argc, const char ** argv) {
 
                         GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
 
-                        check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                        check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
                         if (ndims == 2) {
                             // check_mat_mul does not support ndims > 2
                             check_mat_mul(m, x[1], x[0]);
@@ -800,7 +832,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
 
-                check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
             }
         }
 
@@ -817,7 +849,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
 
-                check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0});
             }
         }
 
@@ -835,7 +867,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
 
-                check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+                check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
             }
         }
 
@@ -854,9 +886,9 @@ int main(int argc, const char ** argv) {
 
 #ifdef GGML_SILU_FP16
                 // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
-                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY);
+                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {});
 #else
-                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
 #endif
             }
         }
@@ -874,7 +906,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
 
-                check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
+                check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {});
             }
         }
 
@@ -892,7 +924,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s));
 
-                check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -910,7 +942,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
 
-                check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -928,7 +960,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
 
-                check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
+                check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
             }
         }
 
@@ -952,7 +984,7 @@ int main(int argc, const char ** argv) {
 
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -976,7 +1008,7 @@ int main(int argc, const char ** argv) {
 
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1004,7 +1036,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
 
-                check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1037,7 +1069,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
 
-                check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1072,7 +1104,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
 
-                check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1109,7 +1141,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
 
-                check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1137,7 +1169,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset));
 
-                check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1170,7 +1202,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset));
 
-                check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1194,7 +1226,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset));
 
-                check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1225,7 +1257,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset));
 
-                check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1257,7 +1289,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset));
 
-                check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1291,7 +1323,7 @@ int main(int argc, const char ** argv) {
                 // sum requires contiguous tensor rows
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3)));
 
-                check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1319,7 +1351,7 @@ int main(int argc, const char ** argv) {
                 // sum requires contiguous tensor rows
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0])));
 
-                check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1337,7 +1369,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
 
-            check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
 
         // diag_mask_inf
@@ -1353,7 +1385,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past));
 
-            check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
 
         // diag_mask_zero
@@ -1369,7 +1401,7 @@ int main(int argc, const char ** argv) {
 
             struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past));
 
-            check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
         }
 
         // softmax
@@ -1395,7 +1427,7 @@ int main(int argc, const char ** argv) {
                                                         1.0f - eps),
                                                     ggml_new_f32(ctx0, eps))));
 
-                check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
+                check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {});
                 // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
                 // this may result in different gradients too finite differences.
                 // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
@@ -1412,7 +1444,7 @@ int main(int argc, const char ** argv) {
             get_random_dims(ne2, 4);
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
                 x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
                 // the second argument to cross_entropy_loss must sum up to 1 for each row
                 int nr = ggml_nrows(x[1]);
@@ -1430,7 +1462,7 @@ int main(int argc, const char ** argv) {
 
                 struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]);
 
-                check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY);
+                check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
             }
         }
 
@@ -1468,7 +1500,7 @@ int main(int argc, const char ** argv) {
                         struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
 
                         GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
+                        check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
                     }
                 }
             }
@@ -1508,12 +1540,93 @@ int main(int argc, const char ** argv) {
                         struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
 
                         GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
+                        check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
                     }
                 }
             }
         }
 
+        // im2col f32
+        {
+            srand(seed);
+            const int nargs = 1;
+            const int ndims = 4;
+
+            for (const bool is_2D : {false, true}) {
+                int64_t ne0[ndims];
+                int64_t ne1[ndims];
+                get_random_dims(ne0, ndims);
+                get_random_dims(ne1, ndims);
+
+                // // Ensure that the output is not zero-sized:
+                ne1[0] += 8;
+                ne1[1] += 8;
+
+                if (is_2D) {
+                    ne1[2] = ne0[2];
+                } else {
+                    ne1[1] = ne0[1];
+                    ne0[3] = 1;
+                    ne1[3] = 1;
+                }
+
+                // The order of arguments is swapped because the first tensor is only used for its shape.
+                x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int s0 =         1 + irand(2);
+                const int s1 = is_2D ? 1 + irand(2) : 0;
+                const int p0 =         0 + irand(2);
+                const int p1 = is_2D ? 0 + irand(2) : 0;
+                const int d0 =         1 + irand(2);
+                const int d1 = is_2D ? 1 + irand(2) : 0;
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32));
+
+                GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1);
+                check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
+            }
+        }
+
+        // pool_2d f32
+        {
+            srand(seed);
+            const int nargs = 1;
+            const int ndims = 4;
+
+            for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
+                int64_t ne0[ndims];
+                get_random_dims(ne0, ndims);
+
+                ne0[0] += 8;
+                ne0[1] += 8;
+
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f);
+
+                ggml_set_param(ctx0, x[0]);
+
+                const int k0 = 2 + irand(2);
+                const int k1 = 2 + irand(2);
+                const int s0 = 2 + irand(2);
+                const int s1 = 2 + irand(2);
+                const int p0 = 0 + irand(2);
+                const int p1 = 0 + irand(2);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1));
+
+                GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n",
+                                 op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1);
+                std::vector<double> expected_vals;
+                if (op == GGML_OP_POOL_MAX) {
+                    expected_vals.push_back(0.0);
+                    expected_vals.push_back(1.0);
+                }
+                check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals);
+            }
+        }
+
         // flash_attn f32
         // TODO: adapt to ggml_flash_attn_ext() changes
         //{
@@ -1553,7 +1666,7 @@ int main(int argc, const char ** argv) {
 
         //                struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
 
-        //                check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
+        //                check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {});
         //            }
         //        }
         //    }