]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sync : ggml (SD ops, tests, kernels) (#4444)
authorGeorgi Gerganov <redacted>
Wed, 13 Dec 2023 19:54:54 +0000 (21:54 +0200)
committerGitHub <redacted>
Wed, 13 Dec 2023 19:54:54 +0000 (21:54 +0200)
* sync : ggml (SD ops, tests, kernels)

ggml-ci

* cuda : restore im2col

ggml-ci

* metal : fix accuracy of dequantization kernels

ggml-ci

* cuda : restore correct im2col

ggml-ci

* metal : try to fix moe test by reducing expert size

ggml-ci

* cuda : fix bin bcast when src1 and dst have different types

ggml-ci

---------

Co-authored-by: slaren <redacted>
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
tests/test-backend-ops.cpp

index 9e1acd3f19e5f6d08e2ea1ac8335186ca980fff2..019648bddc4d9e23d363d289b432f10dbc481853 100644 (file)
@@ -439,6 +439,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
@@ -451,6 +452,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_UPSCALE_BLOCK_SIZE 256
+#define CUDA_CONCAT_BLOCK_SIZE 256
+#define CUDA_PAD_BLOCK_SIZE 256
+#define CUDA_ACC_BLOCK_SIZE 256
+#define CUDA_IM2COL_BLOCK_SIZE 256
 
 // dmmv = dequantize_mul_mat_vec
 #ifndef GGML_CUDA_DMMV_X
@@ -612,6 +618,24 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, int offset) {
+    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
+    int src1_idx = i - offset;
+    int oz = src1_idx / nb2;
+    int oy = (src1_idx - (oz * nb2)) / nb1;
+    int ox = src1_idx % nb1;
+    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
 static __global__ void gelu_f32(const float * x, float * dst, const int k) {
     const float GELU_COEF_A    = 0.044715f;
     const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -634,6 +658,23 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
+static __global__ void tanh_f32(const float *x, float *dst, int k) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = tanhf(x[i]);
+}
+
 static __global__ void relu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -643,6 +684,14 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
+static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
+}
+
 static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -688,6 +737,132 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     }
 }
 
+static __global__ void concat_f32(const float  *x,const float  *y, float *dst, const int ne0, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (blockIdx.z < ne02) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * gridDim.y;
+            dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            (blockIdx.z - ne02) * ne0 *  gridDim.y;
+            dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void upscale_f32(const float  *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
+    int ne0 = ne00 * scale_factor;
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    // operation
+    int i00 = nidx / scale_factor;
+    int i01 = blockIdx.y / scale_factor;
+    int offset_src =
+        i00 +
+        i01 * ne00 +
+        blockIdx.z * nb02;
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    dst[offset_dst] = x[offset_src];
+}
+
+static __global__ void pad_f32(const float  *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+            dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
+template <int block_size>
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+    int start = blockIdx.x * group_size;
+    int end = start + group_size;
+
+    start += threadIdx.x;
+
+    if (end >= ne_elements) {
+        end = ne_elements;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += block_size) {
+        tmp += x[j];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float mean = tmp / group_size;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += block_size) {
+        float xi = x[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float variance = tmp / group_size;
+    float scale = rsqrtf(variance + eps);
+    for (int j = start; j < end; j += block_size) {
+        dst[j] *= scale;
+    }
+}
+
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -5071,19 +5246,30 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
 
 static  __global__ void im2col_f32_f16(
         const float * x, half * dst,
-        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
-    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
-    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+    const int i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= pelements) {
+        return;
+    }
+
+    const int ksize = OW * (KH > 1 ? KW : 1);
+    const int kx = i / ksize;
+    const int kd = kx * ksize;
+    const int ky = (i - kd) / OW;
+    const int ix = i % OW;
+
+    const int iiw = ix * s0 + kx * d0 - p0;
+    const int iih = blockIdx.y * s1 + ky * d1 - p1;
 
     const int offset_dst =
-        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
-        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+        (blockIdx.y * OW + ix) * CHW +
+        (blockIdx.z * (KW * KH) + ky * KW + kx);
 
     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
         dst[offset_dst] = __float2half(0.0f);
     } else {
-        const int offset_src =  threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        const int offset_src = blockIdx.z * offset_delta;
         dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
     }
 }
@@ -5220,10 +5406,10 @@ struct bin_bcast_cuda {
             size_t nb12 = cnb1[2];
             size_t nb13 = cnb1[3];
 
-            size_t s0 = nb0 / sizeof(src1_t);
-            size_t s1 = nb1 / sizeof(src1_t);
-            size_t s2 = nb2 / sizeof(src1_t);
-            size_t s3 = nb3 / sizeof(src1_t);
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
 
             size_t s10 = nb10 / sizeof(src1_t);
             size_t s11 = nb11 / sizeof(src1_t);
@@ -5269,6 +5455,13 @@ struct bin_bcast_cuda {
     }
 };
 
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
+    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -5279,11 +5472,26 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
+    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+}
+
 static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -5300,6 +5508,38 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     }
 }
 
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
+    static const float eps = 1e-6f;
+    if (group_size < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    }
+}
+
+static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2);
+    concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+}
+
+static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
+    int ne0 = (ne00 * scale_factor);
+    int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
+    upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
+}
+
+static void pad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02,
+    const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
+}
+
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
@@ -6262,13 +6502,14 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
-static void im2col_f32_f16_cuda(const float * x, half * dst,
-    int OH, int IW, int IH, int OW, int IC,
-    int KH, int KW, int N,  int ofs0, int ofs1,
-    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
-    dim3 block_nums(IC, OH, OW);
-    dim3 block_dims(N,  KH, KW);
-    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+static void im2col_f32_f16_cuda(const float* x, half* dst,
+    int IW, int IH, int OW, int OH, int KW, int KH, int IC,
+    int offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+    const int parallel_elements = OW * KW * KH;
+    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OH, IC);
+    im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }
 
 // buffer pool for cuda
@@ -6615,6 +6856,25 @@ inline void ggml_cuda_op_add(
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
+inline void ggml_cuda_op_acc(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // offset in bytes
+
+    acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
+
+    (void) dst;
+}
+
 inline void ggml_cuda_op_mul(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6657,6 +6917,34 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_gelu_quick(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_tanh(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_relu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6671,6 +6959,23 @@ inline void ggml_cuda_op_relu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_leaky_relu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_sqr(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -6705,6 +7010,71 @@ inline void ggml_cuda_op_norm(
     (void) src1_dd;
 }
 
+
+inline void ggml_cuda_op_group_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int num_groups = dst->op_params[0];
+    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+    group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_concat(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+        concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
+    }
+
+    (void) src1;
+    (void) dst;
+}
+
+inline void ggml_cuda_op_upscale(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    const int scale_factor = dst->op_params[0];
+
+    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+
+    (void) src1;
+    (void) dst;
+}
+
+inline void ggml_cuda_op_pad(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    pad_f32_cuda(src0_dd, dst_dd,
+        src0->ne[0], src0->ne[1], src0->ne[2],
+        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+
+    (void) src1;
+    (void) dst;
+}
+
 inline void ggml_cuda_op_rms_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7219,7 +7589,6 @@ inline void ggml_cuda_op_im2col(
 
     const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
 
-    const int64_t N  = src1->ne[is_2D ? 3 : 2];
     const int64_t IC = src1->ne[is_2D ? 2 : 1];
     const int64_t IH = is_2D ? src1->ne[1] : 1;
     const int64_t IW =         src1->ne[0];
@@ -7230,17 +7599,15 @@ inline void ggml_cuda_op_im2col(
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW =         dst->ne[1];
 
-    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
-    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
 
-    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
-        OH, IW, IH, OW, IC, KH, KW, N,
-        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
 
     (void) src0;
     (void) src0_dd;
 }
 
+
 inline void ggml_cuda_op_sum_rows(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7789,6 +8156,10 @@ static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }
 
+static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
+}
+
 static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }
@@ -7805,10 +8176,22 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
+static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
+}
+
+static void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh);
+}
+
 static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
 }
 
+static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
+}
+
 static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
 }
@@ -7817,6 +8200,22 @@ static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
 
+static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm);
+}
+
+static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat);
+}
+
+static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale);
+}
+
+static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
+}
+
 static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
@@ -8809,6 +9208,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ADD:
             func = ggml_cuda_add;
             break;
+        case GGML_OP_ACC:
+            func = ggml_cuda_acc;
+            break;
         case GGML_OP_MUL:
             func = ggml_cuda_mul;
             break;
@@ -8823,6 +9225,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_SILU:
                     func = ggml_cuda_silu;
                     break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    func = ggml_cuda_gelu_quick;
+                    break;
+                case GGML_UNARY_OP_TANH:
+                    func = ggml_cuda_tanh;
+                    break;
                 case GGML_UNARY_OP_RELU:
                     func = ggml_cuda_relu;
                     break;
@@ -8833,6 +9241,21 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_NORM:
             func = ggml_cuda_norm;
             break;
+        case GGML_OP_GROUP_NORM:
+            func = ggml_cuda_group_norm;
+            break;
+        case GGML_OP_CONCAT:
+            func = ggml_cuda_concat;
+            break;
+        case GGML_OP_UPSCALE:
+            func = ggml_cuda_upscale;
+            break;
+        case GGML_OP_PAD:
+            func = ggml_cuda_pad;
+            break;
+        case GGML_OP_LEAKY_RELU:
+            func = ggml_cuda_leaky_relu;
+            break;
         case GGML_OP_RMS_NORM:
             func = ggml_cuda_rms_norm;
             break;
@@ -8855,9 +9278,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             func = ggml_cuda_sqr;
             break;
         case GGML_OP_CLAMP:
-            if (!any_on_device) {
-                return false;
-            }
             func = ggml_cuda_clamp;
             break;
         case GGML_OP_CPY:
@@ -8866,6 +9286,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_CONT:
             func = ggml_cuda_dup;
             break;
+        case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
@@ -9285,6 +9706,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_TANH:
                     return true;
                 default:
                     return false;
@@ -9369,6 +9792,12 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
         case GGML_OP_IM2COL:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_LEAKY_RELU:
             return true;
         default:
             return false;
index 1dcfa6eddbfa570dbbec056c6370ef0140f20681..465679a6bb0c8702f4f46a1ccf81161cad38b626 100644 (file)
@@ -66,9 +66,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(div_row);
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale_4);
-    GGML_METAL_DECL_KERNEL(silu);
+    GGML_METAL_DECL_KERNEL(tanh);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
+    GGML_METAL_DECL_KERNEL(gelu_quick);
+    GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_K);
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(group_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -145,8 +148,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(upscale_f32);
+    GGML_METAL_DECL_KERNEL(pad_f32);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DECL_KERNEL(leaky_relu_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -334,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(div_row);
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(scale_4);
-        GGML_METAL_ADD_KERNEL(silu);
+        GGML_METAL_ADD_KERNEL(tanh);
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
+        GGML_METAL_ADD_KERNEL(gelu_quick);
+        GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -354,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q5_K);
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(group_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -415,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(im2col_f16);
+        GGML_METAL_ADD_KERNEL(upscale_f32);
+        GGML_METAL_ADD_KERNEL(pad_f32);
         GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
         GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+        GGML_METAL_ADD_KERNEL(leaky_relu_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -450,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(div_row);
     GGML_METAL_DEL_KERNEL(scale);
     GGML_METAL_DEL_KERNEL(scale_4);
-    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(tanh);
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(gelu_quick);
+    GGML_METAL_DEL_KERNEL(silu);
     GGML_METAL_DEL_KERNEL(soft_max);
     GGML_METAL_DEL_KERNEL(soft_max_4);
     GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -470,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q5_K);
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(group_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -531,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(upscale_f32);
+    GGML_METAL_DEL_KERNEL(pad_f32);
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DEL_KERNEL(leaky_relu_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -843,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
                     return true;
                 default:
                     return false;
@@ -853,11 +873,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-        case GGML_OP_GET_ROWS:
+        case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
         case GGML_OP_ADD:
+        case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_SCALE:
@@ -865,11 +885,15 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_GROUP_NORM:
         case GGML_OP_NORM:
         case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
+        case GGML_OP_LEAKY_RELU:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return true;
@@ -902,8 +926,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
                 };
             }
         case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_GET_ROWS:
             {
-                return op->ne[0] % 4 == 0;
+                return op->ne[3] == 1;
             }
         default:
             return false;
@@ -979,7 +1004,10 @@ void ggml_metal_graph_compute(
                         } break;
                 }
 
-                GGML_ASSERT(ggml_metal_supports_op(dst));
+                if (!ggml_metal_supports_op(dst)) {
+                    GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+                    GGML_ASSERT(!"unsupported op");
+                }
 
                 const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                 const int64_t  ne01 = src0 ? src0->ne[1] : 0;
@@ -1076,6 +1104,8 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                     case GGML_OP_DIV:
                         {
+                            const size_t offs = 0;
+
                             bool bcast_row = false;
 
                             int64_t nb = ne00;
@@ -1134,7 +1164,8 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
                             [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
                             [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                            [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
+                            [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                            [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
 
                             if (bcast_row) {
                                 const int64_t n = ggml_nelements(dst)/4;
@@ -1146,6 +1177,86 @@ void ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
                         } break;
+                    case GGML_OP_ACC:
+                        {
+                            GGML_ASSERT(src0t == GGML_TYPE_F32);
+                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+                            GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+                            GGML_ASSERT(ggml_is_contiguous(src1));
+
+                            const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+                            const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+                            const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+                            const size_t offs = ((int32_t *) dst->op_params)[3];
+
+                            const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                            if (!inplace) {
+                                // run a separete kernel to cpy src->dst
+                                // not sure how to avoid this
+                                // TODO: make a simpler cpy_bytes kernel
+
+                                const int nth = MIN(1024, ne00);
+
+                                [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+                                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                                [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                                [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                                [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                                [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                                [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                                [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                                [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                                [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                                [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                                [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                                [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                                [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                                [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                                [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                            }
+
+                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                            [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+                            [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+                            [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                            [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+                            [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+                            [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+                            [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_SCALE:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1169,16 +1280,15 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_UNARY:
                         switch (ggml_get_unary_op(gf->nodes[i])) {
-                            case GGML_UNARY_OP_SILU:
+                            case GGML_UNARY_OP_TANH:
                                 {
-                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setComputePipelineState:ctx->pipeline_tanh];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
                                     const int64_t n = ggml_nelements(dst);
-                                    GGML_ASSERT(n % 4 == 0);
 
-                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             case GGML_UNARY_OP_RELU:
                                 {
@@ -1199,6 +1309,28 @@ void ggml_metal_graph_compute(
                                     const int64_t n = ggml_nelements(dst);
                                     GGML_ASSERT(n % 4 == 0);
 
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_GELU_QUICK:
+                                {
+                                    [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
+
                                     [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             default:
@@ -1837,6 +1969,38 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
+                    case GGML_OP_GROUP_NORM:
+                        {
+                            GGML_ASSERT(ne00 % 4 == 0);
+
+                            //float eps;
+                            //memcpy(&eps, dst->op_params, sizeof(float));
+
+                            const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+                            const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+                            int nth = 32; // SIMD width
+
+                            //while (nth < ne00/4 && nth < 1024) {
+                            //    nth *= 2;
+                            //}
+
+                            [encoder setComputePipelineState:ctx->pipeline_group_norm];
+                            [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
+                            [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+                            [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_NORM:
                         {
                             float eps;
@@ -2006,6 +2170,65 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                         } break;
+                    case GGML_OP_UPSCALE:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            const int sf = dst->op_params[0];
+
+                            [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                            [encoder setBytes:&sf   length:sizeof(sf)   atIndex:18];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
+                    case GGML_OP_PAD:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                            const int nth = MIN(1024, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_ARGSORT:
                         {
                             GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -2027,6 +2250,22 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
                         } break;
+                    case GGML_OP_LEAKY_RELU:
+                        {
+                            GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                            float slope;
+                            memcpy(&slope, dst->op_params, sizeof(float));
+
+                            [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
+                            [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+                            const int64_t n = ggml_nelements(dst);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
                     case GGML_OP_DUP:
                     case GGML_OP_CPY:
                     case GGML_OP_CONT:
index 773fac124b0c486f21e328fc9f0cb2ea99289e47..fe0ada445a2d45c4f295021565a8e9247c9b2dfc 100644 (file)
@@ -79,6 +79,7 @@ kernel void kernel_add(
         constant  int64_t & nb1,
         constant  int64_t & nb2,
         constant  int64_t & nb3,
+        constant  int64_t & offs,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         const int i10 = i0 % ne10;
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb [[buffer(27)]],
+        constant    int64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb  [[buffer(27)]],
+        constant    int64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant    int64_t & nb  [[buffer(27)]],
+        constant    int64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
     dst[tpig] = src0[tpig] * scale;
 }
 
-kernel void kernel_silu(
-        device const float4 * src0,
-        device       float4 * dst,
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
         uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
+    dst[tpig] = max(0.0f, src0[tpig]);
 }
 
-kernel void kernel_relu(
+kernel void kernel_tanh(
         device const float * src0,
         device       float * dst,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
 }
 
 kernel void kernel_sqr(
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
-constant float GELU_COEF_A    = 0.044715f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
 kernel void kernel_soft_max(
         device const float * src0,
         device const float * src1,
@@ -650,6 +669,94 @@ kernel void kernel_rms_norm(
     }
 }
 
+kernel void kernel_group_norm(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int32_t & n_groups,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end   = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
 // il indicates where the q4 quants begin (0 or QK4_0/4)
 // we assume that the yl's have been multiplied with the appropriate scale factor
@@ -1656,6 +1763,97 @@ kernel void kernel_im2col_f16(
     }
 }
 
+kernel void kernel_upscale_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant   int32_t & sf,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1/sf;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = src0_ptr[i0/sf];
+    }
+}
+
+kernel void kernel_pad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
         device const float * x,
@@ -1708,6 +1906,14 @@ kernel void kernel_argsort_f32_i32(
 template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
 
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
 kernel void kernel_cpy_f16_f16(
         device  const half * src0,
         device        half * dst,
@@ -2066,9 +2272,9 @@ kernel void kernel_cpy_f32_q4_1(
 }
 
 kernel void kernel_concat(
-    device const char * src0,
-    device const char * src1,
-    device       char * dst,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
@@ -2105,7 +2311,7 @@ kernel void kernel_concat(
     const int64_t i12 = i02 % ne12;
     const int64_t i11 = i01 % ne11;
 
-    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
     device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
 
@@ -3315,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
-    const half d = xb->d;
-    const half min = xb->dmin;
+    const float d = xb->d;
+    const float min = xb->dmin;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
-    half dl, ml;
+    float dl, ml;
     uint8_t sc = xb->scales[il];
 
 #if QK_K == 256
@@ -3388,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     q = q + (il/4) * 32 + 16 * (il&1);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const half d   = il < 2 ? xb->d : xb->d / 16.h;
-    const half min = xb->dmin;
-    const half dl = d * sc[0];
-    const half ml = min * sc[1];
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
@@ -3418,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     uint8_t ul = 1 << (il/2);
     il = il & 3;
     const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
-    const half d = il < 2 ? xb->d : xb->d / 16.h;
-    const half min = xb->dmin;
-    const half dl = d * sc[0];
-    const half ml = min * sc[1];
+    const float d = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
 
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    const half qh_val = il<2 ? 16.h : 256.h;
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
diff --git a/ggml.c b/ggml.c
index 66658ff4b24cef452df166f68aeac763106460ba..29e18a24c76a85f7babfcfe1652c32f0746aeadd 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1395,7 +1395,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
-inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
 
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
@@ -1623,7 +1623,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_1D",
     "POOL_2D",
     "UPSCALE",
+    "PAD",
     "ARGSORT",
+    "LEAKY_RELU",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -1650,7 +1652,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1707,7 +1709,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_1d(x)",
     "pool_2d(x)",
     "upscale(x)",
+    "pad(x)",
     "argsort(x)",
+    "leaky_relu(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -1734,7 +1738,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1750,10 +1754,9 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "GELU",
     "GELU_QUICK",
     "SILU",
-    "LEAKY",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
+static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -3830,12 +3833,25 @@ struct ggml_tensor * ggml_relu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }
 
-// ggml_leaky
+// ggml_leaky_relu
 
-struct ggml_tensor * ggml_leaky(
+struct ggml_tensor * ggml_leaky_relu(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
+        struct ggml_tensor  * a, float negative_slope, bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
+
+    result->op   = GGML_OP_LEAKY_RELU;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
 }
 
 // ggml_gelu
@@ -4022,8 +4038,9 @@ static struct ggml_tensor * ggml_group_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    result->op = GGML_OP_GROUP_NORM;
     result->op_params[0] = n_groups;
+
+    result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = NULL; // TODO: maybe store epsilon here?
@@ -5523,6 +5540,30 @@ static struct ggml_tensor * ggml_upscale_impl(
     return result;
 }
 
+struct ggml_tensor * ggml_pad(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int p0, int p1, int p2, int p3) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] + p0,
+            a->ne[1] + p1,
+            a->ne[2] + p2,
+            a->ne[3] + p3);
+
+    result->op = GGML_OP_PAD;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
 struct ggml_tensor * ggml_upscale(
     struct ggml_context * ctx,
     struct ggml_tensor * a,
@@ -7718,8 +7759,10 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+// TODO: OpenCL kernel support broadcast
 #ifdef GGML_USE_CLBLAST
     if (src1->backend == GGML_BACKEND_GPU) {
+        GGML_ASSERT(ggml_are_same_shape(src0, src1));
         if (ith == 0) {
             ggml_cl_mul(src0, src1, dst);
         }
@@ -8985,10 +9028,9 @@ static void ggml_compute_forward_silu(
             } break;
     }
 }
+// ggml_compute_forward_leaky_relu
 
-// ggml_compute_forward_leaky
-
-static void ggml_compute_forward_leaky_f32(
+static void ggml_compute_forward_leaky_relu_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
@@ -9002,24 +9044,27 @@ static void ggml_compute_forward_leaky_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
     assert(dst->nb[0]  == sizeof(float));
     assert(src0->nb[0] == sizeof(float));
 
     for (int i = 0; i < n; i++) {
-        ggml_vec_leaky_f32(nc,
+        ggml_vec_leaky_relu_f32(nc,
                 (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
+                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
     }
 }
 
-static void ggml_compute_forward_leaky(
+static void ggml_compute_forward_leaky_relu(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_leaky_f32(params, src0, dst);
+                ggml_compute_forward_leaky_relu_f32(params, src0, dst);
             } break;
         default:
             {
@@ -12158,6 +12203,7 @@ static void ggml_compute_forward_upscale_f32(
     GGML_ASSERT(src0->nb[0] == sizeof(float));
 
     const int ith = params->ith;
+    const int nth = params->nth;
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -12165,16 +12211,17 @@ static void ggml_compute_forward_upscale_f32(
 
     // TODO: optimize
 
-    for (int i03 = 0; i03 < ne03; i03++) {
-        for (int i02 = ith; i02 < ne02; i02++) {
-            for (int m = 0; m < dst->ne[1]; m++) {
-                int i01 = m / scale_factor;
-                for (int n = 0; n < dst->ne[0]; n++) {
-                    int i00 = n / scale_factor;
-
-                    const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        const int64_t i03 = i3;
+        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+            const int64_t i02 = i2;
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                const int64_t i01 = i1 / scale_factor;
+                for (int64_t i0 = 0; i0 < ne0; i0++) {
+                    const int64_t i00 = i0 / scale_factor;
 
-                    float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
+                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
 
                     *y = *x;
                 }
@@ -12199,6 +12246,64 @@ static void ggml_compute_forward_upscale(
     }
 }
 
+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+          struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    } else {
+                        dst_ptr[dst_idx] = 0;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_pad(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_pad_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_argsort
 
 static void ggml_compute_forward_argsort_f32(
@@ -13406,10 +13511,6 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_silu(params, src0, dst);
             } break;
-        case GGML_UNARY_OP_LEAKY:
-            {
-                ggml_compute_forward_leaky(params, src0, dst);
-            } break;
         default:
             {
                 GGML_ASSERT(false);
@@ -14191,10 +14292,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_upscale(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_PAD:
+            {
+                ggml_compute_forward_pad(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_ARGSORT:
             {
                 ggml_compute_forward_argsort(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -15187,10 +15296,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_PAD:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ARGSORT:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -15796,6 +15913,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ARGMAX:
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT_BACK:
+        case GGML_OP_LEAKY_RELU:
             {
                 n_tasks = 1;
             } break;
@@ -15808,7 +15926,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_LEAKY:
                     {
                         n_tasks = 1;
                     } break;
@@ -15927,6 +16044,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_PAD:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_ARGSORT:
             {
                 n_tasks = n_threads;
diff --git a/ggml.h b/ggml.h
index 32f256481615ce3703d08997bfff022a74283cff..1447646b13c439a94026cf9af1cc4b42aa617fef 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -423,7 +423,9 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_PAD,
         GGML_OP_ARGSORT,
+        GGML_OP_LEAKY_RELU,
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
@@ -463,7 +465,6 @@ extern "C" {
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
-        GGML_UNARY_OP_LEAKY,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -793,6 +794,9 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // dst = a
+    // view(dst, nb1, nb2, nb3, offset) += b
+    // return dst
     GGML_API struct ggml_tensor * ggml_acc(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -957,15 +961,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    GGML_API struct ggml_tensor * ggml_leaky(
+    GGML_API struct ggml_tensor * ggml_leaky_relu(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a, float negative_slope, bool inplace);
 
     GGML_API struct ggml_tensor * ggml_relu_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // TODO: double-check this computation is correct
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -1551,6 +1554,15 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   scale_factor);
 
+    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+    GGML_API struct ggml_tensor * ggml_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
     // sort rows
     enum ggml_sort_order {
         GGML_SORT_ASC,
index 44830b4d4da301a23bcf5b0fdeae89540a1a44f2..afca85143fb30e1f1f85698a815c75b99b24bf42 100644 (file)
@@ -234,6 +234,11 @@ static bool ggml_is_view_op(enum ggml_op op) {
     return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
 }
 
+enum test_mode {
+    MODE_TEST,
+    MODE_PERF,
+};
+
 struct test_case {
     virtual ~test_case() {}
 
@@ -268,7 +273,58 @@ struct test_case {
         return size;
     }
 
+    ggml_cgraph * gf = nullptr;
+
+    static const int sentinel_size = 1024;
+
+    test_mode mode;
+
+    std::vector<ggml_tensor *> sentinels;
+
+    void add_sentinel(ggml_context * ctx) {
+        if (mode == MODE_PERF) {
+            return;
+        }
+        ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
+        ggml_format_name(sentinel, "sent_%zu", sentinels.size());
+        sentinels.push_back(sentinel);
+    }
+
+    // hijack ggml_new_tensor to add sentinels after each tensor to check for overflows in the backend
+
+    ggml_tensor * ggml_new_tensor(ggml_context * ctx, ggml_type type, int n_dims, const int64_t * ne) {
+        ggml_tensor * t = ::ggml_new_tensor(ctx, type, n_dims, ne);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_1d(ggml_context * ctx, ggml_type type, int64_t ne0) {
+        ggml_tensor * t = ::ggml_new_tensor_1d(ctx, type, ne0);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_2d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1) {
+        ggml_tensor * t = ::ggml_new_tensor_2d(ctx, type, ne0, ne1);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_3d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) {
+        ggml_tensor * t = ::ggml_new_tensor_3d(ctx, type, ne0, ne1, ne2);
+        add_sentinel(ctx);
+        return t;
+    }
+
+    ggml_tensor * ggml_new_tensor_4d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+        ggml_tensor * t = ::ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+        add_sentinel(ctx);
+        return t;
+    }
+
     bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) {
+        mode = MODE_TEST;
+
         ggml_init_params params = {
             /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
             /* .mem_base = */ NULL,
@@ -276,6 +332,11 @@ struct test_case {
         };
         ggml_context * ctx = ggml_init(params);
 
+        gf = ggml_new_graph(ctx);
+
+        // pre-graph sentinel
+        add_sentinel(ctx);
+
         ggml_tensor * out = build_graph(ctx);
 
         if (op_name != nullptr && op_desc(out) != op_name) {
@@ -296,13 +357,20 @@ struct test_case {
             }
         }
 
+        // post-graph sentinel
+        add_sentinel(ctx);
+
         // allocate
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
 
         // build graph
-        ggml_cgraph * gf = ggml_new_graph(ctx);
         ggml_build_forward_expand(gf, out);
 
+        // add sentinels as graph nodes so that they are checked in the callback
+        for (ggml_tensor * sentinel : sentinels) {
+            gf->nodes[gf->n_nodes++] = sentinel;
+        }
+
         // randomize tensors
         initialize_tensors(ctx);
 
@@ -318,9 +386,24 @@ struct test_case {
         };
 
         auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
+            callback_userdata * ud = (callback_userdata *) user_data;
+
+            if (t1->op == GGML_OP_NONE) {
+                // sentinels must be unchanged
+                std::vector<uint8_t> t1_data(ggml_nbytes(t1));
+                std::vector<uint8_t> t2_data(ggml_nbytes(t2));
+                ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));
+                ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));
+
+                if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
+                    printf("sentinel mismatch: %s ", t1->name);
+                    ud->ok = false;
+                    return true;
+                }
+            }
+
             std::vector<float> f1 = tensor_to_float(t1);
             std::vector<float> f2 = tensor_to_float(t2);
-            callback_userdata * ud = (callback_userdata *) user_data;
 
             for (size_t i = 0; i < f1.size(); i++) {
                 // check for nans
@@ -349,9 +432,10 @@ struct test_case {
             if (err > ud->max_err) {
                 printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
                 //for (int i = 0; i < f1.size(); i++) {
-                //    printf("(%f, %f) ", f1[i], f2[i]);
+                //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
                 //}
                 //printf("\n");
+                //exit(1);
                 ud->ok = false;
             }
             return true;
@@ -375,6 +459,8 @@ struct test_case {
     }
 
     bool eval_perf(ggml_backend_t backend, const char * op_name) {
+        mode = MODE_PERF;
+
         static const size_t graph_nodes = 8192;
 
         ggml_init_params params = {
@@ -1135,6 +1221,118 @@ struct test_sum_rows : public test_case {
     }
 };
 
+// GGML_OP_UPSCALE
+struct test_upscale : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int32_t scale_factor;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, scale_factor);
+    }
+
+    test_upscale(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {512, 512, 3, 1},
+            int32_t scale_factor = 2)
+        : type(type), ne(ne), scale_factor(scale_factor) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        return out;
+    }
+};
+
+// GGML_OP_GROUP_NORM
+struct test_group_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const int32_t num_groups;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, num_groups);
+    }
+
+    test_group_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 64, 320, 1},
+            int32_t num_groups = 32)
+        : type(type), ne(ne), num_groups(num_groups) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
+        return out;
+    }
+};
+
+// GGML_OP_ACC
+struct test_acc : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const std::array<int64_t, 4> ne_b;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, ne_b);
+    }
+
+    test_acc(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {1024, 577, 1, 1},
+            std::array<int64_t, 4> ne_b = {1024, 576, 1, 1})
+        : type(type), ne_a(ne_a), ne_b(ne_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
+        return out;
+    }
+};
+
+// GGML_OP_PAD
+struct test_pad : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const int pad_0;
+    const int pad_1;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+    }
+
+    test_pad(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
+            int pad_0 = 1, int pad_1 = 1)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
+        return out;
+    }
+};
+
+// GGML_OP_LEAKY_RELU
+struct test_leaky_relu : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const float negative_slope;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, negative_slope);
+    }
+
+    test_leaky_relu(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
+            float negative_slope = 0.1f)
+        : type(type), ne_a(ne_a), negative_slope(negative_slope)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);
+        return out;
+    }
+};
+
 // Mixtral MOE
 struct test_moe : public test_case {
     const int n_experts;
@@ -1219,11 +1417,6 @@ struct test_moe : public test_case {
     }
 };
 
-enum test_mode {
-    MODE_TEST,
-    MODE_PERF,
-};
-
 static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     std::vector<std::unique_ptr<test_case>> test_cases;
 
@@ -1372,12 +1565,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
     }
 
-    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {10, 10, 10, 10}));
-    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {2, 1, 1, 1}));
+    test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_upscale());
+    test_cases.emplace_back(new test_group_norm());
+    test_cases.emplace_back(new test_acc());
+    test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_leaky_relu());
 
 #if !defined(__SANITIZE_THREAD__)
     // FIXME: these tests use too much memory with thread sanitizer
-    test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 14336));
+    test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
     //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
 #endif