]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : refactor forward_dup for cpu backend (llama/16062)
authorXuan-Son Nguyen <redacted>
Fri, 19 Sep 2025 04:31:56 +0000 (11:31 +0700)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* ggml : refactor forward_dup for cpu backend

* clean up a bit

* add quant/dequant perf test

src/ggml-cpu/common.h
src/ggml-cpu/ops.cpp
tests/test-backend-ops.cpp

index 353563dc35c5d120aef3bedf9d41ee91a2617a57..6adca5437f8654cf1cb4735ecdf1c96c22a0232d 100644 (file)
@@ -28,6 +28,14 @@ static inline float bf16_to_f32(ggml_bf16_t x) {
     return GGML_BF16_TO_FP32(x);
 }
 
+static inline float i32_to_f32(int32_t x) {
+    return x;
+}
+
+static inline int32_t f32_to_i32(float x) {
+    return x;
+}
+
 static inline float f32_to_f32(float x) {
     return x;
 }
@@ -54,6 +62,12 @@ struct type_conversion_table<ggml_bf16_t> {
     static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16;
 };
 
+template <>
+struct type_conversion_table<int32_t> {
+    static constexpr float (*to_f32)(int32_t) = i32_to_f32;
+    static constexpr int32_t (*from_f32)(float) = f32_to_i32;
+};
+
 static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) {
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
index c4824d145a54d5fb9466b774421f200f9b6bf36d..763ab099e31a656b2735fa41ad68434b956a98e6 100644 (file)
@@ -41,13 +41,15 @@ static void ggml_compute_forward_dup_same_cont(
     }
 }
 
-static void ggml_compute_forward_dup_f16(
+template<typename src_t, typename dst_t>
+static void ggml_compute_forward_dup_flt(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -62,6 +64,7 @@ static void ggml_compute_forward_dup_f16(
     const int ir0 = dr * ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    // case: type & row size equal
     if (src0->type == dst->type &&
         ne00 == ne0 &&
         nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
@@ -80,11 +83,11 @@ static void ggml_compute_forward_dup_f16(
         return;
     }
 
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
+    // case: dst tensor is contiguous
     if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_fp16_t)) {
-            if (dst->type == GGML_TYPE_F16) {
+        if (nb00 == sizeof(src_t)) {
+            if constexpr (std::is_same_v<dst_t, src_t>) {
+                // same type
                 size_t id = 0;
                 const size_t rs = ne00 * nb00;
                 char * dst_ptr = (char *) dst->data;
@@ -100,91 +103,46 @@ static void ggml_compute_forward_dup_f16(
                         id += rs * (ne01 - ir1);
                     }
                 }
-            } else if (dst->type == GGML_TYPE_F32) {
+            } else {
+                // casting between non-quantized types
                 size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
+                dst_t * dst_ptr = (dst_t *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
                         id += ne00 * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                             for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
+                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
+                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
                                 id++;
                             }
                         }
                         id += ne00 * (ne01 - ir1);
                     }
                 }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_CPU_FP16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
             }
         } else {
             //printf("%s: this is not optimal - fix me\n", __func__);
 
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_CPU_FP16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+            size_t id = 0;
+            dst_t * dst_ptr = (dst_t *) dst->data;
 
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+            for (int i03 = 0; i03 < ne03; i03++) {
+                for (int i02 = 0; i02 < ne02; i02++) {
+                    id += ne00 * ir0;
+                    for (int i01 = ir0; i01 < ir1; i01++) {
+                        for (int i00 = 0; i00 < ne00; i00++) {
+                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
+                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
+                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
+                            id++;
                         }
-                        id += ne00 * (ne01 - ir1);
                     }
+                    id += ne00 * (ne01 - ir1);
                 }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
             }
         }
         return;
@@ -196,7 +154,7 @@ static void ggml_compute_forward_dup_f16(
     int64_t i12 = 0;
     int64_t i13 = 0;
 
-    if (dst->type == GGML_TYPE_F16) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 i10 += ne00 * ir0;
@@ -217,7 +175,7 @@ static void ggml_compute_forward_dup_f16(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
 
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
 
                         if (++i10 == ne00) {
                             i10 = 0;
@@ -248,7 +206,8 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         }
-    } else if (dst->type == GGML_TYPE_F32) {
+
+    } else {
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 i10 += ne00 * ir0;
@@ -269,7 +228,8 @@ static void ggml_compute_forward_dup_f16(
                         const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                               char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
 
-                        *(float *) dst_ptr = GGML_CPU_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
+                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
 
                         if (++i10 == ne0) {
                             i10 = 0;
@@ -300,18 +260,19 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
     }
 }
 
-static void ggml_compute_forward_dup_bf16(
+
+template<typename src_t>
+static void ggml_compute_forward_dup_to_q(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(!ggml_is_quantized(src0->type));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -326,785 +287,36 @@ static void ggml_compute_forward_dup_bf16(
     const int ir0 = dr * ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
-
-    if (ggml_is_contiguous(dst)) {
-        if (nb00 == sizeof(ggml_bf16_t)) {
-            if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
-                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
-                            }
-
-                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-        return;
-    }
-
-    // dst counters
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
+    if (ggml_is_contiguous(dst) &&
+            nb00 == sizeof(src_t) &&
+            ggml_get_type_traits_cpu(dst->type)->from_float) {
+        // casting non-quantized types --> intermediate f32 --> quantized
+        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
-    if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+        size_t id = 0;
+        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+        char * dst_ptr = (char *) dst->data;
 
-                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                id += rs * ir0;
+                for (int i01 = ir0; i01 < ir1; i01++) {
+                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                        if (++i10 == ne00) {
-                            i10 = 0;
-                            if (++i11 == ne01) {
-                                i11 = 0;
-                                if (++i12 == ne02) {
-                                    i12 = 0;
-                                    if (++i13 == ne03) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
                     }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
 
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-static void ggml_compute_forward_dup_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    if (src0->type == dst->type &&
-        ne00 == ne0 &&
-        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
-        // copy by rows
-        const size_t rs = ne00*nb00;
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    memcpy(
-                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
-                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
-                        rs);
-                }
-            }
-        }
-        return;
-    }
-
-    if (ggml_is_contiguous(dst)) {
-        // TODO: simplify
-        if (nb00 == sizeof(float)) {
-            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
-
-                size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            from_float(src0_ptr, dst_ptr + id, ne00);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        } else {
-            //printf("%s: this is not optimal - fix me\n", __func__);
-
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                float * dst_ptr = (float *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_CPU_FP32_TO_FP16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_BF16) {
-                size_t id = 0;
-                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (dst->type == GGML_TYPE_I32) {
-                size_t id = 0;
-                int32_t * dst_ptr = (int32_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = *src0_ptr;
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else {
-                GGML_ABORT("fatal error"); // TODO: implement
-            }
-        }
-
-        return;
-    }
-
-    // dst counters
-
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        memcpy(dst_ptr, src0_ptr, sizeof(float));
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_F16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_fp16_t *) dst_ptr = GGML_CPU_FP32_TO_FP16(*(const float *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_BF16) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else if (dst->type == GGML_TYPE_I32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(int32_t *) dst_ptr = *(const float *) src0_ptr;
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    } else {
-        GGML_ABORT("fatal error"); // TODO: implement
-    }
-}
-
-static void ggml_compute_forward_dup_i32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    const int ith = params->ith; // thread index
-    const int nth = params->nth; // number of threads
-
-    // parallelize by rows
-    const int nr = ne01;
-    // number of rows per thread
-    const int dr = (nr + nth - 1) / nth;
-    // row range for this thread
-    const int ir0 = dr * ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // dst counters
-
-    int64_t i10 = 0;
-    int64_t i11 = 0;
-    int64_t i12 = 0;
-    int64_t i13 = 0;
-
-    // TODO: not optimal, but works
-    if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                i10 += ne00 * ir0;
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
-                }
-                for (int64_t i01 = ir0; i01 < ir1; i01++) {
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
-
-                        *(float *) dst_ptr = *(const int32_t *) src0_ptr;
-
-                        if (++i10 == ne0) {
-                            i10 = 0;
-                            if (++i11 == ne1) {
-                                i11 = 0;
-                                if (++i12 == ne2) {
-                                    i12 = 0;
-                                    if (++i13 == ne3) {
-                                        i13 = 0;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-                i10 += ne00 * (ne01 - ir1);
-                while (i10 >= ne0) {
-                    i10 -= ne0;
-                    if (++i11 == ne1) {
-                        i11 = 0;
-                        if (++i12 == ne2) {
-                            i12 = 0;
-                            if (++i13 == ne3) {
-                                i13 = 0;
-                            }
-                        }
-                    }
+                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                    id += rs;
                 }
+                id += rs * (ne01 - ir1);
             }
         }
     } else {
-        GGML_ABORT("fatal error"); // TODO: implement
+        // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
+        GGML_ABORT("not implemented");
     }
 }
 
@@ -1258,7 +470,7 @@ static void ggml_compute_forward_dup_bytes(
     }
 }
 
-static void ggml_compute_forward_dup_q(
+static void ggml_compute_forward_dup_from_q(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
 
@@ -1323,24 +535,35 @@ void ggml_compute_forward_dup(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_dup_f16(params, dst);
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);
+                else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
             } break;
         case GGML_TYPE_BF16:
             {
-                ggml_compute_forward_dup_bf16(params, dst);
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);
+                else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_dup_f32(params, dst);
+                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
+                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);
+                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);
+                else ggml_compute_forward_dup_to_q<float>(params, dst);
             } break;
         case GGML_TYPE_I32:
             {
-                ggml_compute_forward_dup_i32(params, dst);
+                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
+                else GGML_ABORT("not implemented");
             } break;
         default:
             {
                 if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
-                    ggml_compute_forward_dup_q(params, dst);
+                    ggml_compute_forward_dup_from_q(params, dst);
                     break;
                 }
                 GGML_ABORT("fatal error");
index 01cbf8753303ad0182d707f1556c5bfdcc13e6e1..507b691dc96e204a4975e47a9f7a029724f77675 100644 (file)
@@ -6629,9 +6629,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1,   1, 1, 1}));
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
 
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F16,  {512, 3072, 1, 1}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F32,  {8192, 512, 2, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_F32,  {3072, 512, 2, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32,  {8192, 512, 2, 1}));
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));