ggml: add support for float16 input tensors in pooling operations (ggml/895)
author Ivan Filipov <redacted>
Mon, 22 Jul 2024 11:32:02 +0000 (14:32 +0300)
committer Georgi Gerganov <redacted>
Sat, 27 Jul 2024 14:43:44 +0000 (17:43 +0300)
* Add support for float16 input tensors in 1d pooling operations

* Add support for float16 input tensors in 2d pooling operations

* code cleanup

remove unnecessary casting during srow ptr initialization

---------

Co-authored-by: vanaka11 <redacted>
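
For context, here is a minimal standalone sketch (not part of the commit) of how the new F16 path can be exercised through the public ggml API. The buffer size and input values are illustrative; it uses the public ggml.h helpers ggml_fp32_to_fp16() and ggml_graph_compute_with_ctx() from this era of the library:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024, // illustrative scratch size
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // before this commit, pooling asserted src->type == GGML_TYPE_F32,
        // so an F16 input like this one would trip the assert
        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 8);
        ggml_fp16_t * xd = (ggml_fp16_t *) x->data;
        for (int i = 0; i < 8; ++i) {
            xd[i] = ggml_fp32_to_fp16((float) i);
        }

        // k0 = 2, s0 = 2, p0 = 0 hits the sk_p0 fast path patched below
        struct ggml_tensor * y = ggml_pool_1d(ctx, x, GGML_OP_POOL_AVG, 2, 2, 0);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        const float * yd = (const float *) y->data; // pooling output stays F32
        for (int i = 0; i < 4; ++i) {
            printf("y[%d] = %f\n", i, yd[i]); // expect 0.5, 2.5, 4.5, 6.5
        }

        ggml_free(ctx);
        return 0;
    }
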
ggml/src/ggml.c

index a14d0d1dbdf0eb613ded5aed8320ac7adce6feb3..c76d00a39ed00a3882cd7474f52414f41406b1f2 100644
@@ -14746,7 +14746,7 @@ static void ggml_compute_forward_pool_1d_sk_p0(
 
     const struct ggml_tensor * src = dst->src[0];
 
-    assert(src->type == GGML_TYPE_F32);
+    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
 
     if (params->ith != 0) {
         return;
@@ -14759,10 +14759,8 @@ static void ggml_compute_forward_pool_1d_sk_p0(
     const int64_t rs = dst->ne[0];
 
     while (cdata < data_end) {
-        const float * const srow = (const float *)cdata;
-
+        const void * srow = (const void *)cdata;
         int j = 0;
-
         for (int64_t i = 0; i < rs; ++i) {
             switch (op) {
                 case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
@@ -14770,10 +14768,11 @@ static void ggml_compute_forward_pool_1d_sk_p0(
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
             for (int ki = 0; ki < k; ++ki) {
+                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
                 switch (op) {
-                    case GGML_OP_POOL_AVG:                          drow[i] += srow[j]; break;
-                    case GGML_OP_POOL_MAX:   if (srow[j] > drow[i]) drow[i]  = srow[j]; break;
-                    case GGML_OP_POOL_COUNT:                        GGML_ABORT("fatal error");
+                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
+                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
+                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
                 }
                 ++j;
             }
@@ -14814,7 +14813,7 @@ static void ggml_compute_forward_pool_2d(
 
     const struct ggml_tensor * src = dst->src[0];
 
-    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
 
     if (params->ith != 0) {
         return;
@@ -14857,14 +14856,15 @@ static void ggml_compute_forward_pool_2d(
 
                 for (int ky = 0; ky < k1; ++ky) {
                     if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
-                    const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
+                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
                     for (int kx = 0; kx < k0; ++kx) {
                         int j = ix + kx;
                         if (j < 0 || j >= src->ne[0]) continue;
+                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
                         switch (op) {
-                            case GGML_OP_POOL_AVG:                     *out += srow[j]; break;
-                            case GGML_OP_POOL_MAX: if (srow[j] > *out) *out  = srow[j]; break;
-                            case GGML_OP_POOL_COUNT:                GGML_ABORT("fatal error");
+                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
+                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
+                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
                         }
                     }
                 }
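
The conversion pattern added in both hunks could equally be factored into a helper; a sketch follows (the helper name is hypothetical, the committed code keeps the ternary inline; GGML_FP16_TO_FP32 is the internal macro from ggml-impl.h already used throughout ggml.c):

    // hypothetical helper equivalent to the inline ternary above: load
    // element j of an untyped source row as F32, widening from F16 if needed
    static inline float pool_srow_get_f32(enum ggml_type type, const void * srow, int64_t j) {
        return type == GGML_TYPE_F32
            ? ((const float *) srow)[j]
            : GGML_FP16_TO_FP32(((const ggml_fp16_t *) srow)[j]);
    }

Since src->type is loop-invariant, the per-element branch is fully predictable and compilers will typically hoist it out of the inner loop, so the existing F32 path should keep its old performance.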