]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
add some new ops, fix some operators and add batch operations to certain operators...
authorleejet <redacted>
Sun, 3 Mar 2024 12:23:52 +0000 (20:23 +0800)
committerGeorgi Gerganov <redacted>
Fri, 8 Mar 2024 09:38:31 +0000 (11:38 +0200)
* cuda: fix group_norm

* cuda: add batch inference support for ggml_pad/ggml_upscale

* add ggml_arrange

* add ggml_timestep_embedding

* update ggml_arange/ggml_timestep_embedding tests

* cuda: fix im2col

* add ggml_arange/ggml_timestep_embbeding support for metal backend

* fix some bugs

* fix some bugs

* Update ggml.h

Co-authored-by: Georgi Gerganov <redacted>
* Update ggml-cuda.cu

Co-authored-by: Georgi Gerganov <redacted>
* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <redacted>
* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <redacted>
* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <redacted>
* modify according to the review comments

* ggml : fix compile warnings + code style

* ggml : normalize compute_forward calls + fix seg fault in debug

* minor

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: slaren <redacted>
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h

index 0c6501e98a2a630e93dfd4cdece0e67e91d4b4c4..66d0577fcbe3254f36477d46cf754b433e391f98 100644 (file)
@@ -616,6 +616,8 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
 #define CUDA_UPSCALE_BLOCK_SIZE 256
 #define CUDA_CONCAT_BLOCK_SIZE 256
 #define CUDA_PAD_BLOCK_SIZE 256
+#define CUDA_ARANGE_BLOCK_SIZE 256
+#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
 #define CUDA_ACC_BLOCK_SIZE 256
 #define CUDA_IM2COL_BLOCK_SIZE 256
 #define CUDA_POOL2D_BLOCK_SIZE 256
@@ -990,17 +992,21 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
             nidx +
             blockIdx.y * ne0 +
             blockIdx.z * ne0 * gridDim.y;
-            dst[offset_dst] = x[offset_src];
+        dst[offset_dst] = x[offset_src];
     } else {
         int offset_src =
             nidx +
             blockIdx.y * ne0 +
             (blockIdx.z - ne02) * ne0 *  gridDim.y;
-            dst[offset_dst] = y[offset_src];
+        dst[offset_dst] = y[offset_src];
     }
 }
 
-static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
+static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
+    // blockIdx.z: idx of ne02*ne03
+    // blockIdx.y: idx of ne01*scale_factor, aka ne1
+    // blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
+    // ne00xne01: ne00 * ne01
     int ne0 = ne00 * scale_factor;
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
@@ -1012,7 +1018,7 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
     int offset_src =
         i00 +
         i01 * ne00 +
-        blockIdx.z * nb02;
+        blockIdx.z * ne00xne01;
     int offset_dst =
         nidx +
         blockIdx.y * ne0 +
@@ -1020,7 +1026,10 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
     dst[offset_dst] = x[offset_src];
 }
 
-static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
     int nidx = threadIdx.x + blockIdx.x * blockDim.x;
     if (nidx >= ne0) {
         return;
@@ -1031,19 +1040,53 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
         nidx +
         blockIdx.y * ne0 +
         blockIdx.z * ne0 * gridDim.y;
-    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
         int offset_src =
             nidx +
             blockIdx.y * ne00 +
             blockIdx.z * ne00 * ne01;
-            dst[offset_dst] = x[offset_src];
+        dst[offset_dst] = x[offset_src];
     } else {
         dst[offset_dst] = 0.0f;
     }
 }
 
+static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+    // blockIDx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    dst[nidx] = start + step * nidx;
+}
+
+static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+    // blockIDx.y: idx of timesteps->ne[0]
+    // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+    int i = blockIdx.y;
+    int j = threadIdx.x + blockIdx.x * blockDim.x;
+    float * embed_data = (float *)((char *)dst +  i*nb1);
+
+    if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+        embed_data[dim] = 0.f;
+    }
+
+    int half = dim / 2;
+    if (j >= half) {
+        return;
+    }
+
+    float timestep = timesteps[i];
+    float freq = (float)expf(-logf(max_period) * j / half);
+    float arg = timestep * freq;
+    embed_data[j] = cosf(arg);
+    embed_data[j + half] = sinf(arg);
+}
+
 template <int block_size>
 static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+    // blockIdx.x: num_groups idx
+    // threadIdx.x: block_size idx
     int start = blockIdx.x * group_size;
     int end = start + group_size;
 
@@ -6449,7 +6492,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                    const int nb12, const int nb13) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
@@ -6457,17 +6500,17 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 
     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
 
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
@@ -6956,23 +6999,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
 
 template <typename T>
 static  __global__ void im2col_kernel(
-        const float * x, T * dst, int batch_offset,
-        int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
+        const float * x, T * dst, int64_t batch_offset,
+        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
-    const int i = threadIdx.x + blockIdx.x * blockDim.x;
+    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i >= pelements) {
         return;
     }
 
-    const int ksize = OW * (KH > 1 ? KW : 1);
-    const int kx = i / ksize;
-    const int kd = kx * ksize;
-    const int ky = (i - kd) / OW;
-    const int ix = i % OW;
+    const int64_t  ksize = OW * (KH > 1 ? KW : 1);
+    const int64_t  kx = i / ksize;
+    const int64_t  kd = kx * ksize;
+    const int64_t  ky = (i - kd) / OW;
+    const int64_t  ix = i % OW;
 
-    const int oh = blockIdx.y;
-    const int batch = blockIdx.z / IC;
-    const int ic = blockIdx.z % IC;
+    const int64_t  oh = blockIdx.y;
+    const int64_t  batch = blockIdx.z / IC;
+    const int64_t  ic = blockIdx.z % IC;
 
     const int64_t iiw = ix * s0 + kx * d0 - p0;
     const int64_t iih = oh * s1 + ky * d1 - p1;
@@ -7298,19 +7341,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const
     concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
 }
 
-static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
+static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
+                             const int scale_factor, cudaStream_t stream) {
     int ne0 = (ne00 * scale_factor);
     int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
+    dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
     upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
 }
 
 static void pad_f32_cuda(const float * x, float * dst,
-    const int ne00, const int ne01, const int ne02,
-    const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
     int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, ne1, ne2);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
+    arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start,  step);
+}
+
+static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
+                                        const int dim, const int max_period, cudaStream_t stream) {
+    int half_ceil = (dim + 1) / 2;
+    int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne00, 1);
+    timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
@@ -8443,8 +8500,8 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *
 
 template <typename T>
 static void im2col_cuda(const float* x, T* dst,
-    int IW, int IH, int OW, int OH, int KW, int KH, int IC,
-    int batch, int batch_offset, int offset_delta,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
     const int parallel_elements = OW * KW * KH;
     const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
@@ -9123,7 +9180,7 @@ static void ggml_cuda_op_group_norm(
 
     int num_groups = dst->op_params[0];
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
+    group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream);
 
     (void) src1;
     (void) dst;
@@ -9156,7 +9213,7 @@ static void ggml_cuda_op_upscale(
 
     const int scale_factor = dst->op_params[0];
 
-    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+    upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream);
 
     (void) src1;
     (void) dst;
@@ -9172,8 +9229,49 @@ static void ggml_cuda_op_pad(
     GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
     pad_f32_cuda(src0_dd, dst_dd,
-        src0->ne[0], src0->ne[1], src0->ne[2],
-        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_arange(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    float start;
+    float stop;
+    float step;
+    memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
+    memcpy(&stop,  (float *)dst->op_params + 1, sizeof(float));
+    memcpy(&step,  (float *)dst->op_params + 2, sizeof(float));
+
+    int64_t steps = (int64_t)ceil((stop - start) / step);
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream);
+
+    (void) src0;
+    (void) src1;
+    (void) src0_dd;
+    (void) src1_dd;
+}
+
+static void ggml_cuda_op_timestep_embedding(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int dim = dst->op_params[0];
+    const int max_period = dst->op_params[1];
+
+    timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream);
 
     (void) src1;
     (void) dst;
@@ -10458,6 +10556,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
 }
 
+static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *)  dst->extra;
+
+    const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
+
+    // dd = data device
+    float * src0_ddf = nullptr;
+    float * src1_ddf = nullptr;
+    float *  dst_ddf = nullptr;
+
+    cuda_pool_alloc<float>  dst_f;
+
+    ggml_cuda_set_device(g_main_device);
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    if (dst_on_device) {
+        dst_ddf = (float *) dst_extra->data_device[g_main_device];
+    } else {
+        dst_ddf = dst_f.alloc(ggml_nelements(dst));
+    }
+
+    // do the computation
+    ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    // copy dst to host if necessary
+    if (!dst_on_device) {
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
+    }
+
+    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+}
+
+static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding);
+}
+
 static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
@@ -11358,6 +11495,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_PAD:
             func = ggml_cuda_pad;
             break;
+        case GGML_OP_ARANGE:
+            func = ggml_cuda_arange;
+            break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            func = ggml_cuda_timestep_embedding;
+            break;
         case GGML_OP_LEAKY_RELU:
             func = ggml_cuda_leaky_relu;
             break;
@@ -12253,6 +12396,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_GROUP_NORM:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
             return true;
         default:
index 71fcca5605914ee263996816c1084f040285e49f..6b5a8fdf541041c977f6ca56e3353e82c58dc9f3 100644 (file)
@@ -163,6 +163,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
+    GGML_METAL_KERNEL_TYPE_ARANGE_F32,
+    GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
     GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
@@ -569,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                im2col_f32,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,               upscale_f32,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                   pad_f32,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,    timestep_embedding_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                arange_f32,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,       argsort_f32_i32_asc,    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,      argsort_f32_i32_desc,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,            leaky_relu_f32,         true);
@@ -697,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             return false;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
             return true;
@@ -1091,7 +1097,8 @@ static bool ggml_metal_graph_compute(
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
 
-                        const float scale = *(const float *) dst->op_params;
+                        float scale;
+                        memcpy(&scale, dst->op_params, sizeof(scale));
 
                         int64_t n = ggml_nelements(dst);
 
@@ -1250,11 +1257,15 @@ static bool ggml_metal_graph_compute(
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
                         }
 
-                        const float scale    = ((float *) dst->op_params)[0];
-                        const float max_bias = ((float *) dst->op_params)[1];
+                        float scale;
+                        float max_bias;
+
+                        memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                        memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
 
                         const int64_t nrows_x = ggml_nrows(src0);
                         const int64_t nrows_y = src0->ne[1];
+
                         const uint32_t n_head_kv   = nrows_x/nrows_y;
                         const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 
@@ -2086,6 +2097,7 @@ static bool ggml_metal_graph_compute(
 
                         //const int n_past = ((int32_t *) dst->op_params)[0];
                         const int n_head = ((int32_t *) dst->op_params)[1];
+
                         float max_bias;
                         memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
@@ -2300,6 +2312,50 @@ static bool ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
+                case GGML_OP_ARANGE:
+                    {
+                        GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+                        float start;
+                        float step;
+
+                        memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
+                        memcpy(&step,  ((int32_t *) dst->op_params) + 2, sizeof(float));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
+                        [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
+                        [encoder setBytes:&start length:sizeof(start) atIndex:2];
+                        [encoder setBytes:&step  length:sizeof(step)  atIndex:3];
+
+                        const int nth = MIN(1024, ne0);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
+                case GGML_OP_TIMESTEP_EMBEDDING:
+                    {
+                        GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                        const int dim        = dst->op_params[0];
+                        const int max_period = dst->op_params[1];
+
+                        const int half = dim / 2;
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                        [encoder setBytes:&nb1   length:sizeof(nb1) atIndex:2];
+                        [encoder setBytes:&dim   length:sizeof(dim) atIndex:3];
+                        [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+
+                        const int nth = MIN(1024, half);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    } break;
                 case GGML_OP_ARGSORT:
                     {
                         GGML_ASSERT(src0->type == GGML_TYPE_F32);
index 74a5e0b039fc70ea4698f74b3bb89850345a1779..735f76738613a87ab111f2104310b307d67c81a2 100644 (file)
@@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
     }
 }
 
+kernel void kernel_arange_f32(
+    device        char * dst,
+    constant   int64_t & ne0,
+    constant   float   & start,
+    constant   float   & step,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    device float * dst_ptr = (device float *) dst;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = start + step * i0;
+    }
+}
+
+kernel void kernel_timestep_embedding_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant  uint64_t & nb1,
+    constant  int      & dim,
+    constant  int      & max_period,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    int i = tgpig.x;
+    device float * embed_data = (device float *)(dst +  i*nb1);
+
+    int half_ = dim / 2;
+    for (int j = tpitg.x; j < half_; j += ntg.x) {
+        float timestep = ((device float *)src0)[i];
+        float freq = (float)exp(-log((float)max_period) * j / half_);
+        float arg = timestep * freq;
+        embed_data[j        ] = cos(arg);
+        embed_data[j + half_] = sin(arg);
+    }
+
+    if (dim % 2 != 0 && tpitg.x == 0) {
+        embed_data[dim] = 0.f;
+    }
+}
+
 // bitonic sort implementation following the CUDA kernels as reference
 typedef void (argsort_t)(
         device const float * x,
diff --git a/ggml.c b/ggml.c
index f29b9f13fbcaf4afbf9a97713d1afc93f93a2347..870e41614add7e126ca0483d1455b74cc7a9bf56 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1822,6 +1822,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "POOL_2D",
     "UPSCALE",
     "PAD",
+    "ARANGE",
+    "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
 
@@ -1850,7 +1852,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1908,6 +1910,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "pool_2d(x)",
     "upscale(x)",
     "pad(x)",
+    "arange(start, stop, step)",
+    "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
 
@@ -1936,7 +1940,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
+static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2895,11 +2899,21 @@ static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_
     return ((const int32_t *)(tensor->op_params))[i];
 }
 
+static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+    return ((const float *)(tensor->op_params))[i];
+}
+
 static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
     assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
     ((int32_t *)(tensor->op_params))[i] = value;
 }
 
+static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+    ((float *)(tensor->op_params))[i] = value;
+}
+
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
     memset(tensor->data, 0, ggml_nbytes(tensor));
     return tensor;
@@ -5898,6 +5912,55 @@ struct ggml_tensor * ggml_upscale(
     return ggml_upscale_impl(ctx, a, scale_factor);
 }
 
+struct ggml_tensor * ggml_arange(
+    struct ggml_context * ctx,
+    float start,
+    float stop,
+    float step) {
+
+    GGML_ASSERT(stop > start);
+
+    const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
+
+    result->op = GGML_OP_ARANGE;
+    ggml_set_op_params_f32(result, 0, start);
+    ggml_set_op_params_f32(result, 1, stop);
+    ggml_set_op_params_f32(result, 2, step);
+
+    return result;
+}
+
+struct ggml_tensor * ggml_timestep_embedding(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * timesteps,
+            int                   dim,
+            int                   max_period) {
+    bool is_node = false;
+
+    if (timesteps->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    int actual_dim = dim;
+    if (dim % 2 != 0) {
+        actual_dim = dim + 1;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
+
+    result->op = GGML_OP_TIMESTEP_EMBEDDING;
+    ggml_set_op_params_i32(result, 0, dim);
+    ggml_set_op_params_i32(result, 1, max_period);
+
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = timesteps;
+
+    return result;
+}
+
 // ggml_argsort
 
 struct ggml_tensor * ggml_argsort(
@@ -10231,7 +10294,7 @@ static void ggml_compute_forward_group_norm_f32(
     int n_channels = src0->ne[2];
     int n_groups = dst->op_params[0];
     int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
-    for (int i = ith; i < n_groups; i+=nth) {
+    for (int i = ith; i < n_groups; i += nth) {
         int start = i * n_channels_per_group;
         int end = start + n_channels_per_group;
         if (end > n_channels) {
@@ -10245,28 +10308,32 @@ static void ggml_compute_forward_group_norm_f32(
                 for (int64_t i01 = 0; i01 < ne01; i01++) {
                     const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
 
+                    ggml_float sumr = 0.0;
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        sum += (ggml_float)x[i00];
+                        sumr += (ggml_float)x[i00];
                     }
+                    sum += sumr;
                 }
             }
-            float mean = sum / (ne00 * ne01 * step);
-            ggml_float sum2 = 0.0;
+            const float mean = sum / (ne00 * ne01 * step);
 
+            ggml_float sum2 = 0.0;
             for (int64_t i02 = start; i02 < end; i02++) {
                 for (int64_t i01 = 0; i01 < ne01; i01++) {
                     const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
 
                     float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
 
+                    ggml_float sumr = 0.0;
                     for (int64_t i00 = 0; i00 < ne00; i00++) {
                         float v = x[i00] - mean;
                         y[i00] = v;
-                        sum2 += (ggml_float)(v * v);
+                        sumr += (ggml_float)(v * v);
                     }
+                    sum2 += sumr;
                 }
             }
-            float variance = sum2 / (ne00 * ne01 * step);
+            const float variance = sum2 / (ne00 * ne01 * step);
             const float scale = 1.0f / sqrtf(variance + eps);
 
             for (int64_t i02 = start; i02 < end; i02++) {
@@ -13547,6 +13614,106 @@ static void ggml_compute_forward_pad(
     }
 }
 
+
+// ggml_compute_forward_arange
+
+static void ggml_compute_forward_arange_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const float start = ggml_get_op_params_f32(dst, 0);
+    const float stop  = ggml_get_op_params_f32(dst, 1);
+    const float step  = ggml_get_op_params_f32(dst, 2);
+
+    const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    for (int64_t i = ith; i < steps; i+= nth) {
+        float value = start + step * i;
+        ((float *)dst->data)[i] = value;
+    }
+}
+
+static void ggml_compute_forward_arange(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_arange_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+static void ggml_compute_forward_timestep_embedding_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int dim = ggml_get_op_params_i32(dst, 0);
+    const int max_period = ggml_get_op_params_i32(dst, 1);
+
+    int half = dim / 2;
+
+    for (int64_t i = 0; i < ne00; i++) {
+        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
+        for (int64_t j = ith; j < half; j += nth) {
+            float timestep = ((float *)src0->data)[i];
+            float freq = (float)expf(-logf(max_period) * j / half);
+            float arg = timestep * freq;
+            embed_data[j] = cosf(arg);
+            embed_data[j + half] = sinf(arg);
+        }
+        if (dim % 2 != 0 && ith == 0) {
+            embed_data[dim] = 0.f;
+        }
+    }
+}
+
+static void ggml_compute_forward_timestep_embedding(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_timestep_embedding_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_argsort
 
 static void ggml_compute_forward_argsort_f32(
@@ -15615,6 +15782,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pad(params, tensor);
             } break;
+        case GGML_OP_ARANGE:
+            {
+                ggml_compute_forward_arange(params, tensor);
+            } break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            {
+                ggml_compute_forward_timestep_embedding(params, tensor);
+            } break;
         case GGML_OP_ARGSORT:
             {
                 ggml_compute_forward_argsort(params, tensor);
@@ -16617,6 +16792,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_ARANGE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ARGSORT:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -17368,6 +17551,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = n_threads;
             } break;
+        case GGML_OP_ARANGE:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_ARGSORT:
             {
                 n_tasks = n_threads;
diff --git a/ggml.h b/ggml.h
index 0a6d3c051fe72532d97c2b23c4f8ab749839b454..98cfc7bf8f67066e4b7e64a66315ca7202f26bb2 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -454,6 +454,8 @@ extern "C" {
         GGML_OP_POOL_2D,
         GGML_OP_UPSCALE, // nearest interpolate
         GGML_OP_PAD,
+        GGML_OP_ARANGE,
+        GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
 
@@ -1661,6 +1663,15 @@ extern "C" {
             int                  p2,
             int                  p3);
 
+    // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+    // timesteps: [N,]
+    // return: [N, dim]
+    GGML_API struct ggml_tensor * ggml_timestep_embedding(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * timesteps,
+            int                   dim,
+            int                   max_period);
+
     // sort rows
     enum ggml_sort_order {
         GGML_SORT_ORDER_ASC,
@@ -1672,6 +1683,12 @@ extern "C" {
             struct ggml_tensor  * a,
             enum ggml_sort_order  order);
 
+    GGML_API struct ggml_tensor * ggml_arange(
+            struct ggml_context * ctx,
+            float                 start,
+            float                 stop,
+            float                 step);
+
     // top k elements per row
     GGML_API struct ggml_tensor * ggml_top_k(
             struct ggml_context * ctx,