]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add ggml_cont() + optimize ggml_cpy() for contiguous dst
authorGeorgi Gerganov <redacted>
Mon, 10 Apr 2023 19:40:28 +0000 (22:40 +0300)
committerGeorgi Gerganov <redacted>
Mon, 10 Apr 2023 19:42:28 +0000 (22:42 +0300)
ggml.c
ggml.h

diff --git a/ggml.c b/ggml.c
index 6db6fde8deb5d9d1863d4d6102e79e262c9dfd8f..4f6420678403d53f65beb56ca0ee468ee3e2c0e7 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2609,6 +2609,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
     "SCALE",
     "CPY",
+    "CONT",
     "RESHAPE",
     "VIEW",
     "PERMUTE",
@@ -2624,7 +2625,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_FF",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2653,6 +2654,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "x*v",
     "x-\\>y",
+    "cont(x)",
     "reshape(x)",
     "view(x)",
     "permute(x)",
@@ -2668,7 +2670,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_ff(x)",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4301,6 +4303,41 @@ struct ggml_tensor * ggml_cpy_inplace(
     return ggml_cpy_impl(ctx, a, b, true);
 }
 
+// ggml_cont
+
+struct ggml_tensor * ggml_cont_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_CONT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cont(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_cont_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cont_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_cont_impl(ctx, a, true);
+}
+
 // ggml_reshape
 
 struct ggml_tensor * ggml_reshape(
@@ -4843,6 +4880,85 @@ static void ggml_compute_forward_dup_f16(
 
     // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
 
+    if (ggml_is_contiguous(dst)) {
+        if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+            if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                const size_t rs = ne00*nb00;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            char * dst_ptr = (char *) dst->data + id*rs;
+
+                            memcpy(dst_ptr, src0_ptr, rs);
+
+                            id++;
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        }
+        return;
+    }
+
     // dst counters
     int64_t i10 = 0;
     int64_t i11 = 0;
@@ -4937,6 +5053,105 @@ static void ggml_compute_forward_dup_f32(
         return;
     }
 
+    if (src0->type == dst->type &&
+        src0->ne[0] == dst->ne[0] &&
+        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        // TODO: simplify
+        if (src0->nb[0] == sizeof(float)) {
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                const size_t rs = ne00*nb00;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            char * dst_ptr = (char *) dst->data + id*rs;
+
+                            memcpy(dst_ptr, src0_ptr, rs);
+
+                            id++;
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        }
+
+        return;
+    }
+
     // dst counters
     int64_t i10 = 0;
     int64_t i11 = 0;
@@ -5057,14 +5272,18 @@ static void ggml_compute_forward_add_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        const int j0 = (n/nth)*ith;
-        const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
-
-        for (int j = j0; j < j1; j++) {
+        for (int j = ith; j < n; j += nth) {
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vadd(
+                    (float *) ((char *) src0->data + j*nb01), 1,
+                    (float *) ((char *) src1->data + j*nb11), 1,
+                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
+#else
             ggml_vec_add_f32(nc,
                     (float *) ((char *) dst->data  + j*nb1),
                     (float *) ((char *) src0->data + j*nb01),
                     (float *) ((char *) src1->data + j*nb11));
+#endif
         }
     } else {
         // src1 is not contiguous
@@ -6812,6 +7031,15 @@ static void ggml_compute_forward_cpy(
     ggml_compute_forward_dup(params, src0, dst);
 }
 
+// ggml_compute_forward_cont
+
+static void ggml_compute_forward_cont(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, src0, dst);
+}
+
 // ggml_compute_forward_reshape
 
 static void ggml_compute_forward_reshape(
@@ -8642,6 +8870,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_cpy(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_CONT:
+            {
+                ggml_compute_forward_cont(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_RESHAPE:
             {
                 ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8886,8 +9118,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad =
                         ggml_add_impl(ctx,
                                 src1->grad,
-                                // TODO: fix transpose, the node will break the graph connections
-                                ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+                                ggml_mul_mat(ctx,
+                                    ggml_cont(ctx, ggml_transpose(ctx, src0)),
+                                    tensor->grad),
                                 inplace);
                 }
             } break;
@@ -8899,6 +9132,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONT:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_RESHAPE:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -9353,6 +9590,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         node->n_tasks = n_threads;
                     } break;
                 case GGML_OP_CPY:
+                case GGML_OP_CONT:
                 case GGML_OP_RESHAPE:
                 case GGML_OP_VIEW:
                 case GGML_OP_PERMUTE:
diff --git a/ggml.h b/ggml.h
index af16c647c970888851d68f129a1e297508a4cc30..a5245a8ae6256c0bee449cff0d9112e24df152dc 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -236,6 +236,7 @@ enum ggml_op {
 
     GGML_OP_SCALE,
     GGML_OP_CPY,
+    GGML_OP_CONT,
     GGML_OP_RESHAPE,
     GGML_OP_VIEW,
     GGML_OP_PERMUTE,
@@ -525,6 +526,11 @@ struct ggml_tensor * ggml_cpy(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b);
 
+// make contiguous
+struct ggml_tensor * ggml_cont(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
+
 // return view(a), b specifies the new shape
 // TODO: when we start computing gradient, make a copy instead of view
 struct ggml_tensor * ggml_reshape(