]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml: new gpu kernels + extends ggml_leaky_relu + ggml_pad (#621)
authorSteward Garcia <redacted>
Wed, 13 Dec 2023 14:08:48 +0000 (09:08 -0500)
committerGitHub <redacted>
Wed, 13 Dec 2023 14:08:48 +0000 (09:08 -0500)
* add new cuda kernels and new op ggml_pad

* add ggml_tanh cuda kernel

* remove old broadcast impl

* restore some changes

* cuda: optimized im2col + group_norm kernels

* extent ggml_leaky -> ggml_leaky_relu

* fix some code issues

* cuda: concat support 4 dims

* cuda: fix ggml_acc + add backends ops test

* restore ggml_pad + add backend op test

* metal : implement GGML_OP_ACC

* ggml : fix bug in ggml_upscale

* metal : add ggml_upscale

* metal : add ggml_tanh

* metal : add ggml_gelu_quick

* ggml : make ggml_pad more general purpose

* metal : add ggml_pad

* ggml_leaky_relu as regular op + fix identation

* cuda: ggml_acc admit all op_parms

* negative_slope better pass param

* metal : add ggml_leaky_relu

* metal : add ggml_group_norm

* cuda : minor

* ggml : add GGML_OP_LEAKY_RELU to ggml_compute_backward

* metal : soft max, tanh, supports_op fixes

* test-backend-ops : add sentinels between tensors to detect overflows

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: slaren <redacted>
.gitignore
examples/yolo/yolov3-tiny.cpp
include/ggml/ggml.h
src/CMakeLists.txt
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml.c
tests/test-backend-ops.cpp
tests/test-conv1d.cpp
tests/test-conv2d.cpp

index ff756904b5c57b17fd344b8edff0b0055c7a3218..fbf3c634b1a0fda1915f0912a001f711911d21ea 100644 (file)
@@ -38,3 +38,4 @@ __pycache__/
 
 # Model files
 ggml-model-f16.bin
+*.bat
index ef848f3bc0754dae8ab2c10956ee91f2ed6c1288..2467bd89f33fa3d254701d30bceb43d30001da0f 100644 (file)
@@ -140,7 +140,7 @@ static ggml_tensor * apply_conv2d(ggml_context * ctx, ggml_tensor * input, const
     }
     result = ggml_add(ctx, result, ggml_repeat(ctx, layer.biases, result));
     if (layer.activate) {
-        result = ggml_leaky(ctx, result);
+        result = ggml_leaky_relu(ctx, result, 0.1f, true);
     }
     return result;
 }
index a8f10cbd5c1d8ead7963644753cd3d8bdd776f05..c0fd2230f86e31bcab150b56485b187b6d725a2c 100644 (file)
@@ -423,7 +423,9 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_PAD,
         GGML_OP_ARGSORT,
+        GGML_OP_LEAKY_RELU,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -463,7 +465,6 @@ extern "C" {
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
-        GGML_UNARY_OP_LEAKY,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -793,6 +794,9 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // dst = a
+    // view(dst, nb1, nb2, nb3, offset) += b
+    // return dst
     GGML_API struct ggml_tensor * ggml_acc(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -957,15 +961,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    GGML_API struct ggml_tensor * ggml_leaky(
+    GGML_API struct ggml_tensor * ggml_leaky_relu(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a, float negative_slope, bool inplace);
 
     GGML_API struct ggml_tensor * ggml_relu_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // TODO: double-check this computation is correct
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -1549,6 +1552,15 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   scale_factor);
 
+    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+    GGML_API struct ggml_tensor * ggml_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
     // sort rows
     enum ggml_sort_order {
         GGML_SORT_ASC,
index 94ee0ce1480d080edff1233ad76982c31d5afb8a..de7c2434c17605b6c67541af064ba7e65901d53e 100644 (file)
@@ -223,7 +223,7 @@ if (GGML_CUBLAS)
         endif()
 
         # required for dynamic parallelism
-        set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+        set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
 
         if (GGML_STATIC)
             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
index 85f7a293783bed95758dd729fb49fb6c97d27017..e8a8783d5c43703d38f0b2531b62423c27c11517 100644 (file)
@@ -437,6 +437,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
@@ -449,6 +450,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_UPSCALE_BLOCK_SIZE 256
+#define CUDA_CONCAT_BLOCK_SIZE 256
+#define CUDA_PAD_BLOCK_SIZE 256
+#define CUDA_ACC_BLOCK_SIZE 256
+#define CUDA_IM2COL_BLOCK_SIZE 256
 
 // dmmv = dequantize_mul_mat_vec
 #ifndef GGML_CUDA_DMMV_X
@@ -610,6 +616,24 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, int offset) {
+    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
+    int src1_idx = i - offset;
+    int oz = src1_idx / nb2;
+    int oy = (src1_idx - (oz * nb2)) / nb1;
+    int ox = src1_idx % nb1;
+    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -632,6 +656,23 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
+static __global__ void tanh_f32(const float *x, float *dst, int k) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = tanhf(x[i]);
+}
+
 static __global__ void relu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -641,6 +682,14 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
+static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
+}
+
 static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -686,6 +735,132 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     }
 }
 
+static __global__ void concat_f32(const float  *x,const float  *y, float *dst, const int ne0, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (blockIdx.z < ne02) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * gridDim.y;
+            dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            (blockIdx.z - ne02) * ne0 *  gridDim.y;
+            dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void upscale_f32(const float  *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
+    int ne0 = ne00 * scale_factor;
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    // operation
+    int i00 = nidx / scale_factor;
+    int i01 = blockIdx.y / scale_factor;
+    int offset_src =
+        i00 +
+        i01 * ne00 +
+        blockIdx.z * nb02;
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    dst[offset_dst] = x[offset_src];
+}
+
+static __global__ void pad_f32(const float  *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+            dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
+template <int block_size>
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+    int start = blockIdx.x * group_size;
+    int end = start + group_size;
+
+    start += threadIdx.x;
+
+    if (end >= ne_elements) {
+        end = ne_elements;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += block_size) {
+        tmp += x[j];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float mean = tmp / group_size;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += block_size) {
+        float xi = x[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float variance = tmp / group_size;
+    float scale = rsqrtf(variance + eps);
+    for (int j = start; j < end; j += block_size) {
+        dst[j] *= scale;
+    }
+}
+
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -5035,19 +5210,30 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
 
 static  __global__ void im2col_f32_f16(
         const float * x, half * dst,
-        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
-    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
-    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+    const int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= pelements) {
+        return;
+    }
+
+    const int ksize = OW * (KH > 1 ? KW : 1);
+    const int kx = i / ksize;
+    const int kd = kx * ksize;
+    const int ky = (i - kd) / OW;
+    const int ix = i % OW;
+
+    const int iiw = ix * s0 + kx * d0 - p0;
+    const int iih = blockIdx.y * s1 + ky * d1 - p1;
 
     const int offset_dst =
-        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
-        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+        (blockIdx.y * OW + ix) * CHW +
+        (blockIdx.z * (KW * KH) + ky * KW + kx);
 
     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
         dst[offset_dst] = __float2half(0.0f);
     } else {
-        const int offset_src =  threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        const int offset_src = blockIdx.z * offset_delta;
         dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
     }
 }
@@ -5174,6 +5360,13 @@ struct bin_bcast_cuda {
     }
 };
 
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
+    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -5184,11 +5377,26 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
+    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+}
+
 static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -5205,6 +5413,38 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     }
 }
 
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
+    static const float eps = 1e-6f;
+    if (group_size < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    }
+}
+
+static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2);
+    concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+}
+
+static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
+    int ne0 = (ne00 * scale_factor);
+    int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
+    upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
+}
+
+static void pad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02,
+    const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
+}
+
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -6167,13 +6407,14 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
-static void im2col_f32_f16_cuda(const float * x, half * dst,
-    int OH, int IW, int IH, int OW, int IC,
-    int KH, int KW, int N,  int ofs0, int ofs1,
-    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
-    dim3 block_nums(IC, OH, OW);
-    dim3 block_dims(N,  KH, KW);
-    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+static void im2col_f32_f16_cuda(const float* x, half* dst,
+    int IW, int IH, int OW, int OH, int KW, int KH, int IC,
+    int offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+    const int parallel_elements = OW * KW * KH;
+    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OH, IC);
+    im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }
 
 // buffer pool for cuda
@@ -6522,6 +6763,25 @@ inline void ggml_cuda_op_add(
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
+inline void ggml_cuda_op_acc(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // offset in bytes
+
+    acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
+
+    (void) dst;
+}
+
 inline void ggml_cuda_op_mul(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6564,6 +6824,34 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_gelu_quick(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_tanh(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_relu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6578,6 +6866,23 @@ inline void ggml_cuda_op_relu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_leaky_relu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_sqr(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6612,6 +6917,71 @@ inline void ggml_cuda_op_norm(
     (void) src1_dd;
 }
 
+
+inline void ggml_cuda_op_group_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int num_groups = dst->op_params[0];
+    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+    group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_concat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+        concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
+    }
+
+    (void) src1;
+    (void) dst;
+}
+
+inline void ggml_cuda_op_upscale(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    const int scale_factor = dst->op_params[0];
+
+    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+
+    (void) src1;
+    (void) dst;
+}
+
+inline void ggml_cuda_op_pad(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    pad_f32_cuda(src0_dd, dst_dd,
+        src0->ne[0], src0->ne[1], src0->ne[2],
+        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+
+    (void) src1;
+    (void) dst;
+}
+
 inline void ggml_cuda_op_rms_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7126,7 +7496,6 @@ inline void ggml_cuda_op_im2col(
 
     const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
 
-    const int64_t N  = src1->ne[is_2D ? 3 : 2];
     const int64_t IC = src1->ne[is_2D ? 2 : 1];
     const int64_t IH = is_2D ? src1->ne[1] : 1;
     const int64_t IW =         src1->ne[0];
@@ -7137,12 +7506,9 @@ inline void ggml_cuda_op_im2col(
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW =         dst->ne[1];
 
-    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
-    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
 
-    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
-        OH, IW, IH, OW, IC, KH, KW, N,
-        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
 
     (void) src0;
     (void) src0_dd;
@@ -7696,6 +8062,10 @@ static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }
 
+static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
+}
+
 static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
@@ -7712,10 +8082,22 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
+static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
+}
+
+static void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh);
+}
+
 static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }
 
+static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
+}
+
 static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
 }
@@ -7724,6 +8106,22 @@ static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
 
+static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm);
+}
+
+static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat);
+}
+
+static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale);
+}
+
+static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
+}
+
 static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
@@ -8683,6 +9081,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ADD:
             func = ggml_cuda_add;
             break;
+        case GGML_OP_ACC:
+            func = ggml_cuda_acc;
+            break;
         case GGML_OP_MUL:
             func = ggml_cuda_mul;
             break;
@@ -8697,6 +9098,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_SILU:
                     func = ggml_cuda_silu;
                     break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    func = ggml_cuda_gelu_quick;
+                    break;
+                case GGML_UNARY_OP_TANH:
+                    func = ggml_cuda_tanh;
+                    break;
                 case GGML_UNARY_OP_RELU:
                     func = ggml_cuda_relu;
                     break;
@@ -8707,6 +9114,21 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_NORM:
             func = ggml_cuda_norm;
             break;
+        case GGML_OP_GROUP_NORM:
+            func = ggml_cuda_group_norm;
+            break;
+        case GGML_OP_CONCAT:
+            func = ggml_cuda_concat;
+            break;
+        case GGML_OP_UPSCALE:
+            func = ggml_cuda_upscale;
+            break;
+        case GGML_OP_PAD:
+            func = ggml_cuda_pad;
+            break;
+        case GGML_OP_LEAKY_RELU:
+            func = ggml_cuda_leaky_relu;
+            break;
         case GGML_OP_RMS_NORM:
             func = ggml_cuda_rms_norm;
             break;
@@ -8729,9 +9151,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             func = ggml_cuda_sqr;
             break;
         case GGML_OP_CLAMP:
-            if (!any_on_device) {
-                return false;
-            }
             func = ggml_cuda_clamp;
             break;
         case GGML_OP_CPY:
@@ -8740,6 +9159,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_CONT:
             func = ggml_cuda_dup;
             break;
+        case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
@@ -9159,6 +9579,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_TANH:
                     return true;
                 default:
                     return false;
@@ -9206,6 +9628,12 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_IM2COL:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_LEAKY_RELU:
             return true;
         default:
             return false;
index d7146ea61cfde65f450fc4d0c578efdecd29549c..6b538fd39cf973d26e30bf934a93ae86d1fd456f 100644 (file)
@@ -66,9 +66,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(div_row);
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale_4);
-    GGML_METAL_DECL_KERNEL(silu);
+    GGML_METAL_DECL_KERNEL(tanh);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
+    GGML_METAL_DECL_KERNEL(gelu_quick);
+    GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_K);
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(group_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -130,8 +133,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(upscale_f32);
+    GGML_METAL_DECL_KERNEL(pad_f32);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DECL_KERNEL(leaky_relu_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -318,9 +324,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(div_row);
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(scale_4);
-        GGML_METAL_ADD_KERNEL(silu);
+        GGML_METAL_ADD_KERNEL(tanh);
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
+        GGML_METAL_ADD_KERNEL(gelu_quick);
+        GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -338,6 +346,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q5_K);
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(group_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -384,8 +393,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(im2col_f16);
+        GGML_METAL_ADD_KERNEL(upscale_f32);
+        GGML_METAL_ADD_KERNEL(pad_f32);
         GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
         GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+        GGML_METAL_ADD_KERNEL(leaky_relu_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -418,9 +430,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(div_row);
     GGML_METAL_DEL_KERNEL(scale);
     GGML_METAL_DEL_KERNEL(scale_4);
-    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(tanh);
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(gelu_quick);
+    GGML_METAL_DEL_KERNEL(silu);
     GGML_METAL_DEL_KERNEL(soft_max);
     GGML_METAL_DEL_KERNEL(soft_max_4);
     GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -438,6 +452,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q5_K);
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(group_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -484,8 +499,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(upscale_f32);
+    GGML_METAL_DEL_KERNEL(pad_f32);
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DEL_KERNEL(leaky_relu_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -795,9 +813,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
                     return true;
                 default:
                     return false;
@@ -809,6 +829,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_SCALE:
@@ -816,11 +837,15 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_GROUP_NORM:
         case GGML_OP_NORM:
         case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
+        case GGML_OP_LEAKY_RELU:
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
@@ -830,8 +855,8 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->ne[0] % 4 == 0;
-            }
+                return op->ne[3] == 1;
+            } break;
         default:
             return false;
     }
@@ -906,7 +931,10 @@ void ggml_metal_graph_compute(
                         } break;
                 }
 
-                GGML_ASSERT(ggml_metal_supports_op(dst));
+                if (!ggml_metal_supports_op(dst)) {
+                    GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+                    GGML_ASSERT(!"unsupported op");
+                }
 
                 const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                 const int64_t  ne01 = src0 ? src0->ne[1] : 0;
@@ -1006,6 +1034,8 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ggml_is_contiguous(src0));
                             GGML_ASSERT(ggml_is_contiguous(src1));
 
+                            const size_t offs = 0;
+
                             bool bcast_row = false;
 
                             int64_t nb = ne00;
@@ -1058,7 +1088,8 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
                             [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
                             [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                            [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
+                            [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                            [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
 
                             if (bcast_row) {
                                 const int64_t n = ggml_nelements(dst)/4;
@@ -1070,6 +1101,86 @@ void ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
                         } break;
+                    case GGML_OP_ACC:
+                        {
+                            GGML_ASSERT(src0t == GGML_TYPE_F32);
+                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                            GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+                            GGML_ASSERT(ggml_is_contiguous(src1));
+
+                            const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+                            const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+                            const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+                            const size_t offs = ((int32_t *) dst->op_params)[3];
+
+                            const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                            if (!inplace) {
+                                // run a separete kernel to cpy src->dst
+                                // not sure how to avoid this
+                                // TODO: make a simpler cpy_bytes kernel
+
+                                const int nth = MIN(1024, ne00);
+
+                                [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+                                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                                [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                                [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                                [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                                [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                                [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                                [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                                [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                                [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                                [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                                [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                                [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                                [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                                [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                                [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                            [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+                            [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+                            [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                            [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+                            [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+                            [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+                            [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_SCALE:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1093,16 +1204,15 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_UNARY:
                         switch (ggml_get_unary_op(gf->nodes[i])) {
-                            case GGML_UNARY_OP_SILU:
+                            case GGML_UNARY_OP_TANH:
                                 {
-                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setComputePipelineState:ctx->pipeline_tanh];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
                                     const int64_t n = ggml_nelements(dst);
-                                    GGML_ASSERT(n % 4 == 0);
 
-                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             case GGML_UNARY_OP_RELU:
                                 {
@@ -1123,6 +1233,28 @@ void ggml_metal_graph_compute(
                                     const int64_t n = ggml_nelements(dst);
                                     GGML_ASSERT(n % 4 == 0);
 
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_GELU_QUICK:
+                                {
+                                    [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
+
                                     [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             default:
@@ -1197,6 +1329,8 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                             if (id_src1) {
                                 [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                            } else {
+                                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
                             }
                             [encoder setBuffer:id_dst  offset:offs_dst    atIndex:2];
                             [encoder setBytes:&ne00  length:sizeof(ne00)  atIndex:3];
@@ -1599,6 +1733,38 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
+                    case GGML_OP_GROUP_NORM:
+                        {
+                            GGML_ASSERT(ne00 % 4 == 0);
+
+                            //float eps;
+                            //memcpy(&eps, dst->op_params, sizeof(float));
+
+                            const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+                            const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+                            int nth = 32; // SIMD width
+
+                            //while (nth < ne00/4 && nth < 1024) {
+                            //    nth *= 2;
+                            //}
+
+                            [encoder setComputePipelineState:ctx->pipeline_group_norm];
+                            [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
+                            [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+                            [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_NORM:
                         {
                             float eps;
@@ -1768,6 +1934,65 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                         } break;
+                    case GGML_OP_UPSCALE:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            const int sf = dst->op_params[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                            [encoder setBytes:&sf   length:sizeof(sf)   atIndex:18];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_PAD:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_ARGSORT:
                         {
                             GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -1789,6 +2014,22 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
                         } break;
+                    case GGML_OP_LEAKY_RELU:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            float slope;
+                            memcpy(&slope, dst->op_params, sizeof(float));
+
+                            [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
+                            [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
index 2f8ea22d66226040f9b9b1de4b70466994f6ff94..54b3b8a16200c0f30768987b8a59b8e833cabdde 100644 (file)
@@ -79,6 +79,7 @@ kernel void kernel_add(
         constant  int64_t & nb1,
         constant  int64_t & nb2,
         constant  int64_t & nb3,
+        constant  int64_t & offs,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         const int i10 = i0 % ne10;
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb [[buffer(27)]],
+        constant    int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb  [[buffer(27)]],
+        constant    int64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb  [[buffer(27)]],
+        constant    int64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }
 
-kernel void kernel_silu(
-        device const float4 * src0,
-        device       float4 * dst,
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
         uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
+    dst[tpig] = max(0.0f, src0[tpig]);
 }
 
-kernel void kernel_relu(
+kernel void kernel_tanh(
         device const float * src0,
         device       float * dst,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
 }
 
 kernel void kernel_sqr(
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
-constant float GELU_COEF_A    = 0.044715f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
 kernel void kernel_soft_max(
         device const float * src0,
         device const float * src1,
@@ -348,7 +367,7 @@ kernel void kernel_soft_max(
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 ? src1                                      + i01*ne00 : nullptr;
+    device const float * pmask = src1 != src0 ? src1                              + i01*ne00 : nullptr;
     device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
@@ -385,6 +404,7 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     float sum = simd_sum(lsum);
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
@@ -428,9 +448,9 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 =        (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
-    device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 =                (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
+    device       float4 * pdst4 =                (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
     float4 lmax4 = -INFINITY;
@@ -468,6 +488,7 @@ kernel void kernel_soft_max_4(
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     float sum = simd_sum(lsum);
     if (ntg > N_SIMDWIDTH) {
         if (sgitg == 0) {
@@ -639,6 +660,94 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_group_norm(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int32_t & n_groups,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end   = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
 // il indicates where the q4 quants begin (0 or QK4_0/4)
 // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -1548,6 +1657,97 @@ kernel void kernel_im2col_f16(
     }
 }
 
+kernel void kernel_upscale_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant   int32_t & sf,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
+
+kernel void kernel_pad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
         device const float * x,
@@ -1600,6 +1800,14 @@ kernel void kernel_argsort_f32_i32(
 template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
 
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device       half * dst,
@@ -1917,9 +2125,9 @@ kernel void kernel_cpy_f32_q4_1(
 }
 
 kernel void kernel_concat(
-    device const char * src0,
-    device const char * src1,
-    device       char * dst,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
@@ -1956,7 +2164,7 @@ kernel void kernel_concat(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
     device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
 
index ca56f063c3a87440353e3efce428c18e003517fa..8652446c739352c8d2c373f0e0ff328c65a706a2 100644 (file)
@@ -1395,7 +1395,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
-inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
 
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
@@ -1623,7 +1623,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
+    "PAD",
     "ARGSORT",
+    "LEAKY_RELU",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -1650,7 +1652,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1707,7 +1709,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
+    "pad(x)",
     "argsort(x)",
+    "leaky_relu(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -1734,7 +1738,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1750,10 +1754,9 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "GELU",
     "GELU_QUICK",
     "SILU",
-    "LEAKY",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
+static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -3830,12 +3833,25 @@ struct ggml_tensor * ggml_relu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }
 
-// ggml_leaky
+// ggml_leaky_relu
 
-struct ggml_tensor * ggml_leaky(
+struct ggml_tensor * ggml_leaky_relu(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
+        struct ggml_tensor  * a, float negative_slope, bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
+
+    result->op   = GGML_OP_LEAKY_RELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
 }
 
 // ggml_gelu
@@ -4022,8 +4038,9 @@ static struct ggml_tensor * ggml_group_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op = GGML_OP_GROUP_NORM;
     result->op_params[0] = n_groups;
+
+    result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = NULL; // TODO: maybe store epsilon here?
@@ -5519,6 +5536,30 @@ static struct ggml_tensor * ggml_upscale_impl(
     return result;
 }
 
+struct ggml_tensor * ggml_pad(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int p0, int p1, int p2, int p3) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] + p0,
+            a->ne[1] + p1,
+            a->ne[2] + p2,
+            a->ne[3] + p3);
+
+    result->op = GGML_OP_PAD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 struct ggml_tensor * ggml_upscale(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
@@ -7714,8 +7755,10 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+// TODO: OpenCL kernel support broadcast
 #ifdef GGML_USE_CLBLAST
     if (src1->backend == GGML_BACKEND_GPU) {
+        GGML_ASSERT(ggml_are_same_shape(src0, src1));
         if (ith == 0) {
             ggml_cl_mul(src0, src1, dst);
         }
@@ -8981,10 +9024,9 @@ static void ggml_compute_forward_silu(
             } break;
     }
 }
+// ggml_compute_forward_leaky_relu
 
-// ggml_compute_forward_leaky
-
-static void ggml_compute_forward_leaky_f32(
+static void ggml_compute_forward_leaky_relu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -8998,24 +9040,27 @@ static void ggml_compute_forward_leaky_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
     assert(dst->nb[0]  == sizeof(float));
     assert(src0->nb[0] == sizeof(float));
 
     for (int i = 0; i < n; i++) {
-        ggml_vec_leaky_f32(nc,
+        ggml_vec_leaky_relu_f32(nc,
                 (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
+                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
     }
 }
 
-static void ggml_compute_forward_leaky(
+static void ggml_compute_forward_leaky_relu(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_leaky_f32(params, src0, dst);
+                ggml_compute_forward_leaky_relu_f32(params, src0, dst);
             } break;
         default:
             {
@@ -12114,6 +12159,7 @@ static void ggml_compute_forward_upscale_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
+    const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -12121,16 +12167,17 @@ static void ggml_compute_forward_upscale_f32(
 
     // TODO: optimize
 
-    for (int i03 = 0; i03 < ne03; i03++) {
-        for (int i02 = ith; i02 < ne02; i02++) {
-            for (int m = 0; m < dst->ne[1]; m++) {
-                int i01 = m / scale_factor;
-                for (int n = 0; n < dst->ne[0]; n++) {
-                    int i00 = n / scale_factor;
-
-                    const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        const int64_t i03 = i3;
+        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+            const int64_t i02 = i2;
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                const int64_t i01 = i1 / scale_factor;
+                for (int64_t i0 = 0; i0 < ne0; i0++) {
+                    const int64_t i00 = i0 / scale_factor;
 
-                    float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
+                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
 
                     *y = *x;
                 }
@@ -12155,6 +12202,64 @@ static void ggml_compute_forward_upscale(
     }
 }
 
+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+          struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    } else {
+                        dst_ptr[dst_idx] = 0;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_pad(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_pad_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_argsort
 
 static void ggml_compute_forward_argsort_f32(
@@ -13362,10 +13467,6 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_silu(params, src0, dst);
             } break;
-        case GGML_UNARY_OP_LEAKY:
-            {
-                ggml_compute_forward_leaky(params, src0, dst);
-            } break;
         default:
             {
                 GGML_ASSERT(false);
@@ -14147,10 +14248,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_upscale(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_PAD:
+            {
+                ggml_compute_forward_pad(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_ARGSORT:
             {
                 ggml_compute_forward_argsort(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -15143,10 +15252,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_PAD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ARGSORT:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -15752,6 +15869,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ARGMAX:
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT_BACK:
+        case GGML_OP_LEAKY_RELU:
             {
                 n_tasks = 1;
             } break;
@@ -15764,7 +15882,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_LEAKY:
                     {
                         n_tasks = 1;
                     } break;
@@ -15883,6 +16000,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_PAD:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_ARGSORT:
             {
                 n_tasks = n_threads;
index e0155ac1c89139367a6a09e77b2f8a1cc3501ab0..937521ef9b2caf44812ffd2d9000939646307a83 100644 (file)
@@ -230,6 +230,11 @@ static bool ggml_is_view_op(enum ggml_op op) {
     return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
 }
 
+enum test_mode {
+    MODE_TEST,
+    MODE_PERF,
+};
+
 struct test_case {
     virtual ~test_case() {}
 
@@ -260,7 +265,58 @@ struct test_case {
         return size;
     }
 
+    ggml_cgraph * gf = nullptr;
+
+    static const int sentinel_size = 1024;
+
+    test_mode mode;
+
+    std::vector<ggml_tensor *> sentinels;
+
+    void add_sentinel(ggml_context * ctx) {
+        if (mode == MODE_PERF) {
+            return;
+        }
+        ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
+        ggml_format_name(sentinel, "sent_%zu", sentinels.size());
+        sentinels.push_back(sentinel);
+    }
+
+    // hijack ggml_new_tensor to add sentinels after each tensor to check for overflows in the backend
+
+    ggml_tensor * ggml_new_tensor(ggml_context * ctx, ggml_type type, int n_dims, const int64_t * ne) {
+        ggml_tensor * t = ::ggml_new_tensor(ctx, type, n_dims, ne);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_1d(ggml_context * ctx, ggml_type type, int64_t ne0) {
+        ggml_tensor * t = ::ggml_new_tensor_1d(ctx, type, ne0);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_2d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1) {
+        ggml_tensor * t = ::ggml_new_tensor_2d(ctx, type, ne0, ne1);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_3d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) {
+        ggml_tensor * t = ::ggml_new_tensor_3d(ctx, type, ne0, ne1, ne2);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_4d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+        ggml_tensor * t = ::ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+        add_sentinel(ctx);
+        return t;
+    }
+
     bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) {
+        mode = MODE_TEST;
+
         ggml_init_params params = {
             /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
             /* .mem_base = */ NULL,
@@ -268,6 +324,11 @@ struct test_case {
         };
         ggml_context * ctx = ggml_init(params);
 
+        gf = ggml_new_graph(ctx);
+
+        // pre-graph sentinel
+        add_sentinel(ctx);
+
         ggml_tensor * out = build_graph(ctx);
 
         if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
@@ -288,13 +349,20 @@ struct test_case {
             }
         }
 
+        // post-graph sentinel
+        add_sentinel(ctx);
+
         // allocate
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
 
         // build graph
-        ggml_cgraph * gf = ggml_new_graph(ctx);
         ggml_build_forward_expand(gf, out);
 
+        // add sentinels as graph nodes so that they are checked in the callback
+        for (ggml_tensor * sentinel : sentinels) {
+            gf->nodes[gf->n_nodes++] = sentinel;
+        }
+
         // randomize tensors
         initialize_tensors(ctx);
 
@@ -310,9 +378,24 @@ struct test_case {
         };
 
         auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
+            callback_userdata * ud = (callback_userdata *) user_data;
+
+            if (t1->op == GGML_OP_NONE) {
+                // sentinels must be unchanged
+                std::vector<uint8_t> t1_data(ggml_nbytes(t1));
+                std::vector<uint8_t> t2_data(ggml_nbytes(t2));
+                ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));
+                ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));
+
+                if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
+                    printf("sentinel mismatch: %s ", t1->name);
+                    ud->ok = false;
+                    return true;
+                }
+            }
+
             std::vector<float> f1 = tensor_to_float(t1);
             std::vector<float> f2 = tensor_to_float(t2);
-            callback_userdata * ud = (callback_userdata *) user_data;
 
             for (size_t i = 0; i < f1.size(); i++) {
                 // check for nans
@@ -361,6 +444,8 @@ struct test_case {
     }
 
     bool eval_perf(ggml_backend_t backend, const char * op_name) {
+        mode = MODE_PERF;
+
         static const size_t graph_nodes = 8192;
 
         ggml_init_params params = {
@@ -1109,9 +1194,116 @@ struct test_sum_rows : public test_case {
     }
 };
 
-enum test_mode {
-    MODE_TEST,
-    MODE_PERF,
+// GGML_OP_UPSCALE
+struct test_upscale : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int32_t scale_factor;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, scale_factor);
+    }
+
+    test_upscale(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {512, 512, 3, 1},
+            int32_t scale_factor = 2)
+        : type(type), ne(ne), scale_factor(scale_factor) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        return out;
+    }
+};
+
+// GGML_OP_GROUP_NORM
+struct test_group_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int32_t num_groups;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, num_groups);
+    }
+
+    test_group_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 64, 320, 1},
+            int32_t num_groups = 32)
+        : type(type), ne(ne), num_groups(num_groups) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
+        return out;
+    }
+};
+
+// GGML_OP_ACC
+struct test_acc : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const std::array<int64_t, 4> ne_b;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, ne_b);
+    }
+
+    test_acc(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {1024, 577, 1, 1},
+            std::array<int64_t, 4> ne_b = {1024, 576, 1, 1})
+        : type(type), ne_a(ne_a), ne_b(ne_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
+        return out;
+    }
+};
+
+// GGML_OP_PAD
+struct test_pad : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const int pad_0;
+    const int pad_1;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+    }
+
+    test_pad(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
+            int pad_0 = 1, int pad_1 = 1)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
+        return out;
+    }
+};
+
+// GGML_OP_LEAKY_RELU
+struct test_leaky_relu : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const float negative_slope;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, negative_slope);
+    }
+
+    test_leaky_relu(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
+            float negative_slope = 0.1f)
+        : type(type), ne_a(ne_a), negative_slope(negative_slope)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);
+        return out;
+    }
 };
 
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
@@ -1251,6 +1443,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_upscale());
+    test_cases.emplace_back(new test_group_norm());
+    test_cases.emplace_back(new test_acc());
+    test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_leaky_relu());
 
     // run tests
     if (mode == MODE_TEST) {
index 5b74156853ed58566f9c850a1f2ab4b7077b45e6..37522fc0b3f72cb281741991fcd13373c9e9d2c1 100644 (file)
@@ -196,7 +196,6 @@ struct ggml_cgraph* compute_graph(const test_model & model, struct ggml_allocr *
 
     //ggml_graph_print(gf);
 
-    // in this case, the output tensor is the last one in the graph
     return gf;
 }
 
index f50a53afa244f3c55828fd0a714cee054f1d35ac..be318c43c4c08ddadb4670f2b0c808b8bb699d1b 100644 (file)
@@ -171,7 +171,6 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
     ggml_set_name(conv2d_res, "conv2d_res");
     ggml_build_forward_expand(gf, conv2d_res);
 
-    // delete the temporally context used to build the graph
     ggml_free(ctx0);
     return gf;
 }
@@ -200,7 +199,6 @@ struct ggml_cgraph * compute_graph(const test_model & model, struct ggml_allocr
 
     //ggml_graph_print(gf);
 
-    // in this case, the output tensor is the last one in the graph
     return gf;
 }