]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync new operators from stable-diffusion.cpp (#461)
authorleejet <redacted>
Tue, 22 Aug 2023 10:39:31 +0000 (18:39 +0800)
committerGitHub <redacted>
Tue, 22 Aug 2023 10:39:31 +0000 (13:39 +0300)
* ggml : add ggml_group_norm

* ggml : add ggml_upscale

* ggml : add ggml_concat

* ggml : match code style

---------

Co-authored-by: Georgi Gerganov <redacted>
include/ggml/ggml.h
src/ggml.c

index 1a34bd9be3bd78b8deb5046b6053cb55e503b2b4..3c48fd27fab39d2986f76f3fc4a3b1543a6cb915 100644 (file)
@@ -346,10 +346,12 @@ extern "C" {
         GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
+        GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
+        GGML_OP_GROUP_NORM,
 
         GGML_OP_MUL_MAT,
         GGML_OP_OUT_PROD,
@@ -379,6 +381,8 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
 
+        GGML_OP_UPSCALE, // nearest interpolate
+
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
@@ -809,6 +813,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // concat a and b on dim 2
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_concat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -917,6 +928,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 eps);
 
+    // group normalize along ne0*ne1*n_groups
+    // used in stable-diffusion
+    // TODO: eps is hardcoded to 1e-6 for now
+    GGML_API struct ggml_tensor * ggml_group_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
+    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
     // a - x
     // b - dy
     // TODO: update with configurable eps
@@ -1343,6 +1367,13 @@ extern "C" {
             int                   p0,
             int                   p1);
 
+    // nearest interpolate
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_upscale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   scale_factor);
+
     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
index b7831919efc4a8cbf3f5fb816fa9fbd758411641..af031cc2363755d49a01e2cb862bea82d50a3057 100644 (file)
@@ -3744,10 +3744,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGMAX",
     "REPEAT",
     "REPEAT_BACK",
+    "CONCAT",
     "SILU_BACK",
     "NORM",
     "RMS_NORM",
     "RMS_NORM_BACK",
+    "GROUP_NORM",
 
     "MUL_MAT",
     "OUT_PROD",
@@ -3776,6 +3778,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
+    "UPSCALE",
 
     "FLASH_ATTN",
     "FLASH_FF",
@@ -3802,7 +3805,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 65, "GGML_OP_COUNT != 65");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3823,10 +3826,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argmax(x)",
     "repeat(x)",
     "repeat_back(x)",
+    "concat(x, y)",
     "silu_back(x)",
     "norm(x)",
     "rms_norm(x)",
     "rms_norm_back(x)",
+    "group_norm(x)",
 
     "X*Y",
     "X*Y",
@@ -3855,6 +3860,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "conv_transpose_2d(x)",
     "pool_1d(x)",
     "pool_2d(x)",
+    "upscale(x)",
 
     "flash_attn(x)",
     "flash_ff(x)",
@@ -3881,7 +3887,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 65, "GGML_OP_COUNT != 65");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5591,6 +5597,30 @@ struct ggml_tensor * ggml_repeat_back(
     return result;
 }
 
+// ggml_concat
+
+struct ggml_tensor* ggml_concat(
+    struct ggml_context* ctx,
+    struct ggml_tensor* a,
+    struct ggml_tensor* b) {
+    GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+
+    result->op = GGML_OP_CONCAT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_abs
 
 struct ggml_tensor * ggml_abs(
@@ -5790,6 +5820,8 @@ struct ggml_tensor * ggml_norm_inplace(
     return ggml_norm_impl(ctx, a, true);
 }
 
+// ggml_rms_norm
+
 static struct ggml_tensor * ggml_rms_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -5826,6 +5858,8 @@ struct ggml_tensor * ggml_rms_norm_inplace(
     return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
+// ggml_rms_norm_back
+
 struct ggml_tensor * ggml_rms_norm_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -5847,6 +5881,44 @@ struct ggml_tensor * ggml_rms_norm_back(
     return result;
 }
 
+// ggml_group_norm
+
+static struct ggml_tensor * ggml_group_norm_impl(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int n_groups,
+    bool inplace) {
+
+    bool is_node = false;
+    if (!inplace && (a->grad)) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_GROUP_NORM;
+    result->op_params[0] = n_groups;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = NULL; // TODO: maybe store epsilon here?
+
+    return result;
+}
+
+struct ggml_tensor * ggml_group_norm(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int n_groups) {
+    return ggml_group_norm_impl(ctx, a, n_groups, false);
+}
+
+struct ggml_tensor * ggml_group_norm_inplace(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int n_groups) {
+    return ggml_group_norm_impl(ctx, a, n_groups, true);
+}
 
 // ggml_mul_mat
 
@@ -7111,6 +7183,40 @@ struct ggml_tensor * ggml_pool_2d(
     return result;
 }
 
+// ggml_upscale
+
+static struct ggml_tensor * ggml_upscale_impl(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int scale_factor) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] * scale_factor,
+            a->ne[1] * scale_factor,
+            a->ne[2], a->ne[3]);
+
+    result->op = GGML_OP_UPSCALE;
+    result->op_params[0] = scale_factor;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_upscale(
+    struct ggml_context * ctx,
+    struct ggml_tensor * a,
+    int scale_factor) {
+    return ggml_upscale_impl(ctx, a, scale_factor);
+}
+
 // ggml_flash_attn
 
 struct ggml_tensor * ggml_flash_attn(
@@ -9893,6 +9999,72 @@ static void ggml_compute_forward_repeat_back(
     }
 }
 
+// ggml_compute_forward_concat
+
+static void ggml_compute_forward_concat_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2++) {
+            if (i2 < ne02) { // src0
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
+
+                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
+                        *y = *x;
+                    }
+                }
+            } // src1
+            else {
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
+
+                        float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
+                        *y = *x;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_concat(
+    const struct ggml_compute_params* params,
+    const struct ggml_tensor* src0,
+    const struct ggml_tensor* src1,
+    struct ggml_tensor* dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_concat_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_abs
 
 static void ggml_compute_forward_abs_f32(
@@ -10496,6 +10668,8 @@ static void ggml_compute_forward_norm(
     }
 }
 
+// ggml_compute_forward_group_rms_norm
+
 static void ggml_compute_forward_rms_norm_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10560,7 +10734,6 @@ static void ggml_compute_forward_rms_norm(
     }
 }
 
-
 static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10734,6 +10907,96 @@ static void ggml_compute_forward_rms_norm_back(
     }
 }
 
+// ggml_compute_forward_group_norm
+
+static void ggml_compute_forward_group_norm_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    const float eps = 1e-6f; // TODO: make this a parameter
+
+    // TODO: optimize
+
+    int n_channels = src0->ne[2];
+    int n_groups = dst->op_params[0];
+    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
+    for (int i = ith; i < n_groups; i+=nth) {
+        int start = i * n_channels_per_group;
+        int end = start + n_channels_per_group;
+        if (end > n_channels) {
+            end = n_channels;
+        }
+        int step = end - start;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            ggml_float sum = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        sum += (ggml_float)x[i00];
+                    }
+                }
+            }
+            float mean = sum / (ne00 * ne01 * step);
+            ggml_float sum2 = 0.0;
+
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        float v = x[i00] - mean;
+                        y[i00] = v;
+                        sum2 += (ggml_float)(v * v);
+                    }
+                }
+            }
+            float variance = sum2 / (ne00 * ne01 * step);
+            const float scale = 1.0f / sqrtf(variance + eps);
+
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+                    ggml_vec_scale_f32(ne00, y, scale);
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_group_norm(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_group_norm_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_mul_mat
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -13488,6 +13751,61 @@ static void ggml_compute_forward_pool_2d(
     ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
 }
 
+// ggml_compute_forward_upscale
+
+static void ggml_compute_forward_upscale_f32(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    const int scale_factor = dst->op_params[0];
+
+    // TODO: optimize
+
+    for (int i03 = 0; i03 < ne03; i03++) {
+        for (int i02 = ith; i02 < ne02; i02++) {
+            for (int m = 0; m < dst->ne[1]; m++) {
+                int i01 = m / scale_factor;
+                for (int n = 0; n < dst->ne[0]; n++) {
+                    int i00 = n / scale_factor;
+
+                    const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_upscale(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+    struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_upscale_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_flash_attn
 
 static void ggml_compute_forward_flash_attn_f32(
@@ -15295,6 +15613,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_CONCAT:
+            {
+                ggml_compute_forward_concat(params, tensor->src[0], tensor->src[1], tensor);
+            } break;
         case GGML_OP_SILU_BACK:
             {
                 ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
@@ -15311,6 +15633,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor);
             } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                ggml_compute_forward_group_norm(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
@@ -15415,6 +15741,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_pool_2d(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_UPSCALE:
+            {
+                ggml_compute_forward_upscale(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 const int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -15716,6 +16046,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             inplace);
                 }
             } break;
+        case GGML_OP_CONCAT:
+            {
+                GGML_ASSERT(false); // TODO: implement
+            } break;
         case GGML_OP_SILU_BACK:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -15738,6 +16072,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 // https://cs231n.github.io/optimization-2/#staged
@@ -16096,6 +16434,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_UPSCALE:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_FLASH_ATTN:
             {
                 struct ggml_tensor * flash_grad = NULL;
@@ -16915,9 +17257,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_NORM:
             case GGML_OP_RMS_NORM:
             case GGML_OP_RMS_NORM_BACK:
+            case GGML_OP_GROUP_NORM:
                 {
                     n_tasks = n_threads;
                 } break;
+            case GGML_OP_CONCAT:
             case GGML_OP_MUL_MAT:
             case GGML_OP_OUT_PROD:
                 {
@@ -17086,6 +17430,10 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                 {
                     n_tasks = 1;
                 } break;
+            case GGML_OP_UPSCALE:
+                {
+                    n_tasks = n_threads;
+                } break;
             case GGML_OP_FLASH_ATTN:
                 {
                     n_tasks = n_threads;