]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : implement ggml_compute_forward_dup_f16() special cases
authorGeorgi Gerganov <redacted>
Fri, 16 Dec 2022 19:50:41 +0000 (21:50 +0200)
committerGeorgi Gerganov <redacted>
Fri, 16 Dec 2022 19:50:41 +0000 (21:50 +0200)
ggml.c

diff --git a/ggml.c b/ggml.c
index bab5cc71d0193b16afe6da291d607501451158be..c5780ed2579a59363314985d1ea03af4ec486dde 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3178,22 +3178,96 @@ void ggml_compute_forward_dup_f16(
         return;
     }
 
-    //const int ne00 = src0->ne[0];
-    //const int ne01 = src0->ne[1];
-    //const int ne02 = src0->ne[2];
-    //const int ne03 = src0->ne[3];
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
 
-    //const size_t nb00 = src0->nb[0];
-    //const size_t nb01 = src0->nb[1];
-    //const size_t nb02 = src0->nb[2];
-    //const size_t nb03 = src0->nb[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
 
     if (ggml_is_contiguous(src0) && src0->type == dst->type) {
         memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
         return;
     }
 
-    GGML_ASSERT(false); // TODO: implement
+    if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+        if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            const size_t rs = ne00*nb00;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        char * dst_ptr = (char *) dst->data + id*rs;
+
+                        memcpy(dst_ptr, src0_ptr, rs);
+
+                        id++;
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            float * dst_ptr = (float *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ASSERT(false); // TODO: implement
+        }
+    } else {
+        //printf("%s: this is not optimal - fix me\n", __func__);
+
+        if (dst->type == GGML_TYPE_F32) {
+            int id = 0;
+            float * dst_ptr = (float *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else if (dst->type == GGML_TYPE_F16) {
+            int id = 0;
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    for (int i01 = 0; i01 < ne01; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            dst_ptr[id] = *src0_ptr;
+                            id++;
+                        }
+                    }
+                }
+            }
+        } else {
+            GGML_ASSERT(false); // TODO: implement
+        }
+    }
 }
 
 void ggml_compute_forward_dup_f32(