]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : optimize ggml_cpy() for contiguous dst
authorGeorgi Gerganov <redacted>
Mon, 10 Apr 2023 19:39:07 +0000 (22:39 +0300)
committerGeorgi Gerganov <redacted>
Mon, 10 Apr 2023 19:39:07 +0000 (22:39 +0300)
src/ggml.c

index a8f36e64e36661cae98f78f6951604cb19788b4a..3942379ecdd2685436a9a84526fed7d72dd44fa7 100644 (file)
@@ -4885,6 +4885,85 @@ static void ggml_compute_forward_dup_f16(
 
     // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
 
+    if (ggml_is_contiguous(dst)) {
+        if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+            if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                const size_t rs = ne00*nb00;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            char * dst_ptr = (char *) dst->data + id*rs;
+
+                            memcpy(dst_ptr, src0_ptr, rs);
+
+                            id++;
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        }
+        return;
+    }
+
     // dst counters
     int64_t i10 = 0;
     int64_t i11 = 0;
@@ -4979,6 +5058,105 @@ static void ggml_compute_forward_dup_f32(
         return;
     }
 
+    if (src0->type == dst->type &&
+        src0->ne[0] == dst->ne[0] &&
+        src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        // TODO: simplify
+        if (src0->nb[0] == sizeof(float)) {
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                const size_t rs = ne00*nb00;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            char * dst_ptr = (char *) dst->data + id*rs;
+
+                            memcpy(dst_ptr, src0_ptr, rs);
+
+                            id++;
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        for (int i01 = 0; i01 < ne01; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                    }
+                }
+            } else {
+                GGML_ASSERT(false); // TODO: implement
+            }
+        }
+
+        return;
+    }
+
     // dst counters
     int64_t i10 = 0;
     int64_t i11 = 0;
@@ -5099,14 +5277,18 @@ static void ggml_compute_forward_add_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        const int j0 = (n/nth)*ith;
-        const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
-
-        for (int j = j0; j < j1; j++) {
+        for (int j = ith; j < n; j += nth) {
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vadd(
+                    (float *) ((char *) src0->data + j*nb01), 1,
+                    (float *) ((char *) src1->data + j*nb11), 1,
+                    (float *) ((char *) dst->data  + j*nb1),  1, nc);
+#else
             ggml_vec_add_f32(nc,
                     (float *) ((char *) dst->data  + j*nb1),
                     (float *) ((char *) src0->data + j*nb01),
                     (float *) ((char *) src1->data + j*nb11));
+#endif
         }
     } else {
         // src1 is not contiguous