]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda/cpu: Increase support for fp16 unary operations (ggml/1125)
authorcmdr2 <redacted>
Fri, 28 Feb 2025 07:04:39 +0000 (12:34 +0530)
committerGeorgi Gerganov <redacted>
Mon, 3 Mar 2025 16:18:11 +0000 (18:18 +0200)
* Support fp16 unary operations in the CUDA backend

* cpu: increase fp16 support for unary operators in the CPU backend

* cuda: increase fp16 support for unary operators in the CUDA backend

* Add test cases for fp16 unary operators

* metal: update supports_op for unary operators that don't support fp16, to prevent test-backend-ops from failing

* metal: fix PR comments for unary op support after fp16 unary tests

ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cuda/clamp.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh
ggml/src/ggml-metal/ggml-metal.m
tests/test-backend-ops.cpp

index 33ab5e9c6e766c0c2589eda46ed726e5b8903b63..9ab24522ca769a74f330ebe9f5ca0faeeb0b8bc0 100644 (file)
@@ -1432,6 +1432,12 @@ inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp
 inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
 inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
 inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
+inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(-GGML_FP16_TO_FP32(x[i]));
+    }
+}
+
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
 inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
     for (int i = 0; i < n; ++i) {
@@ -1830,22 +1836,107 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
 
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
+inline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16(v*v);
+    }
+}
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
+inline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(sqrtf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
 inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
+inline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(logf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
 inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
+inline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(sinf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
 inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
+inline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(cosf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(fabsf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f));
+    }
+}
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16((GGML_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f);
+    }
+}
 inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
+inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(tanhf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
+inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(expm1f(GGML_FP16_TO_FP32(x[i])));
+    }
+}
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v : 0.f);
+    }
+}
 inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
+inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f));
+    }
+}
 inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
+inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(1.f / (1.f + expf(-GGML_FP16_TO_FP32(x[i]))));
+    }
+}
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f)));
+    }
+}
 inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f)));
+    }
+}
 inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
+inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(expf(GGML_FP16_TO_FP32(x[i])));
+    }
+}
 
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
@@ -1913,10 +2004,21 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float *
 }
 #endif
 
+inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v))));
+    }
+}
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
 }
+inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
+    float v = GGML_FP16_TO_FP32(x);
+    return GGML_FP32_TO_FP16(v/(1.0f + expf(-v)));
+}
 
 #if __FINITE_MATH_ONLY__
 #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
@@ -2140,6 +2242,12 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }
 
+inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_silu_f16(x[i]);
+    }
+}
+
 static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
@@ -2211,12 +2319,24 @@ inline static float ggml_silu_backward_f32(float x, float dy) {
     return dy*s*(1.0f + x*(1.0f - s));
 }
 
+inline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) {
+    const float v = GGML_FP16_TO_FP32(x);
+    const float s = 1.0f/(1.0f + expf(-v));
+    return GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s)));
+}
+
 inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
     for (int i = 0; i < n; ++i) {
         dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
     }
 }
 
+inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) {
+    for (int i = 0; i < n; ++i) {
+        dx[i] = ggml_silu_backward_f16(x[i], dy[i]);
+    }
+}
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
@@ -5623,6 +5743,31 @@ static void ggml_compute_forward_sqr_f32(
     }
 }
 
+static void ggml_compute_forward_sqr_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n     = ggml_nrows(src0);
+    const int nc    = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(ggml_fp16_t));
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqr_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_sqr(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5634,6 +5779,10 @@ static void ggml_compute_forward_sqr(
             {
                 ggml_compute_forward_sqr_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sqr_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -5668,6 +5817,31 @@ static void ggml_compute_forward_sqrt_f32(
     }
 }
 
+static void ggml_compute_forward_sqrt_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(ggml_fp16_t));
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqrt_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_sqrt(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5679,6 +5853,10 @@ static void ggml_compute_forward_sqrt(
             {
                 ggml_compute_forward_sqrt_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sqrt_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -5713,6 +5891,31 @@ static void ggml_compute_forward_log_f32(
     }
 }
 
+static void ggml_compute_forward_log_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t));
+    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_log_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_log(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5724,6 +5927,10 @@ static void ggml_compute_forward_log(
             {
                 ggml_compute_forward_log_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_log_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -5758,6 +5965,31 @@ static void ggml_compute_forward_sin_f32(
     }
 }
 
+static void ggml_compute_forward_sin_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t));
+    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sin_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_sin(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5769,6 +6001,10 @@ static void ggml_compute_forward_sin(
             {
                 ggml_compute_forward_sin_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sin_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -5803,6 +6039,31 @@ static void ggml_compute_forward_cos_f32(
     }
 }
 
+static void ggml_compute_forward_cos_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(ggml_fp16_t));
+    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_cos_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_cos(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -5814,6 +6075,10 @@ static void ggml_compute_forward_cos(
             {
                 ggml_compute_forward_cos_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_cos_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6471,6 +6736,30 @@ static void ggml_compute_forward_abs_f32(
     }
 }
 
+static void ggml_compute_forward_abs_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_abs_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_abs(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6482,6 +6771,10 @@ static void ggml_compute_forward_abs(
             {
                 ggml_compute_forward_abs_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_abs_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6515,6 +6808,30 @@ static void ggml_compute_forward_sgn_f32(
     }
 }
 
+static void ggml_compute_forward_sgn_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sgn_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_sgn(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6526,6 +6843,10 @@ static void ggml_compute_forward_sgn(
             {
                 ggml_compute_forward_sgn_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sgn_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6559,6 +6880,30 @@ static void ggml_compute_forward_neg_f32(
     }
 }
 
+static void ggml_compute_forward_neg_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_neg_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_neg(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6570,6 +6915,10 @@ static void ggml_compute_forward_neg(
             {
                 ggml_compute_forward_neg_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_neg_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6603,9 +6952,33 @@ static void ggml_compute_forward_step_f32(
     }
 }
 
-static void ggml_compute_forward_step(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
+static void ggml_compute_forward_step_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_step_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_step(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
 
     const struct ggml_tensor * src0 = dst->src[0];
 
@@ -6614,6 +6987,10 @@ static void ggml_compute_forward_step(
             {
                 ggml_compute_forward_step_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_step_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6647,6 +7024,30 @@ static void ggml_compute_forward_tanh_f32(
     }
 }
 
+static void ggml_compute_forward_tanh_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_tanh_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_tanh(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6658,6 +7059,10 @@ static void ggml_compute_forward_tanh(
             {
                 ggml_compute_forward_tanh_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_tanh_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6691,6 +7096,30 @@ static void ggml_compute_forward_elu_f32(
     }
 }
 
+static void ggml_compute_forward_elu_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_elu_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_elu(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6702,6 +7131,10 @@ static void ggml_compute_forward_elu(
             {
                 ggml_compute_forward_elu_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_elu_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6735,6 +7168,30 @@ static void ggml_compute_forward_relu_f32(
     }
 }
 
+static void ggml_compute_forward_relu_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_relu_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_relu(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6746,6 +7203,10 @@ static void ggml_compute_forward_relu(
             {
                 ggml_compute_forward_relu_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_relu_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6779,6 +7240,30 @@ static void ggml_compute_forward_sigmoid_f32(
     }
 }
 
+static void ggml_compute_forward_sigmoid_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_sigmoid(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6790,6 +7275,10 @@ static void ggml_compute_forward_sigmoid(
             {
                 ggml_compute_forward_sigmoid_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sigmoid_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6838,6 +7327,46 @@ static void ggml_compute_forward_gelu_f32(
     }
 }
 
+static void ggml_compute_forward_gelu_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
 static void ggml_compute_forward_gelu(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6849,6 +7378,10 @@ static void ggml_compute_forward_gelu(
             {
                 ggml_compute_forward_gelu_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_gelu_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6897,6 +7430,46 @@ static void ggml_compute_forward_gelu_quick_f32(
     }
 }
 
+static void ggml_compute_forward_gelu_quick_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
 static void ggml_compute_forward_gelu_quick(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6908,6 +7481,10 @@ static void ggml_compute_forward_gelu_quick(
             {
                 ggml_compute_forward_gelu_quick_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_gelu_quick_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -6956,6 +7533,46 @@ static void ggml_compute_forward_silu_f32(
     }
 }
 
+static void ggml_compute_forward_silu_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
 static void ggml_compute_forward_silu(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -6967,6 +7584,10 @@ static void ggml_compute_forward_silu(
             {
                 ggml_compute_forward_silu_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_silu_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -7005,6 +7626,36 @@ static void ggml_compute_forward_leaky_relu_f32(
     }
 }
 
+static void ggml_compute_forward_leaky_relu_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    assert(dst->nb[0]  == sizeof(ggml_fp16_t));
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_leaky_relu_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+    }
+}
+
 static void ggml_compute_forward_leaky_relu(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -7016,6 +7667,10 @@ static void ggml_compute_forward_leaky_relu(
             {
                 ggml_compute_forward_leaky_relu_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_leaky_relu_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -7068,6 +7723,50 @@ static void ggml_compute_forward_silu_back_f32(
     }
 }
 
+static void ggml_compute_forward_silu_back_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous_1(grad));
+    assert(ggml_is_contiguous_1(src1));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
+                (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
+
+    #ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+    #endif
+    }
+}
+
 static void ggml_compute_forward_silu_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -7079,6 +7778,10 @@ static void ggml_compute_forward_silu_back(
             {
                 ggml_compute_forward_silu_back_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_silu_back_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -7086,7 +7789,6 @@ static void ggml_compute_forward_silu_back(
     }
 }
 
-
 static void ggml_compute_forward_hardswish_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -7110,6 +7812,31 @@ static void ggml_compute_forward_hardswish_f32(
                 (float *) ((char *) src0->data + i*(src0->nb[1])));
     }
 }
+
+static void ggml_compute_forward_hardswish_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_hardswish_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_hardswish(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -7121,6 +7848,10 @@ static void ggml_compute_forward_hardswish(
             {
                 ggml_compute_forward_hardswish_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_hardswish_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -7152,6 +7883,30 @@ static void ggml_compute_forward_hardsigmoid_f32(
     }
 }
 
+static void ggml_compute_forward_hardsigmoid_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_hardsigmoid_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_hardsigmoid(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -7163,6 +7918,10 @@ static void ggml_compute_forward_hardsigmoid(
             {
                 ggml_compute_forward_hardsigmoid_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_hardsigmoid_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -7194,6 +7953,30 @@ static void ggml_compute_forward_exp_f32(
     }
 }
 
+static void ggml_compute_forward_exp_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
 static void ggml_compute_forward_exp(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -7205,6 +7988,10 @@ static void ggml_compute_forward_exp(
             {
                 ggml_compute_forward_exp_f32(params, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_exp_f16(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -9489,6 +10276,43 @@ static void ggml_compute_forward_clamp_f32(
     }
 }
 
+static void ggml_compute_forward_clamp_f16(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    float min;
+    float max;
+    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *)  dst->data + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+
+        for (int i = 0; i < nc; i++) {
+            float v = GGML_FP16_TO_FP32(src0_ptr[i]);
+            dst_ptr[i] = GGML_FP32_TO_FP16(MAX(MIN(v, max), min));
+        }
+    }
+}
+
 static void ggml_compute_forward_clamp(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -9501,6 +10325,9 @@ static void ggml_compute_forward_clamp(
                 ggml_compute_forward_clamp_f32(params, dst);
             } break;
         case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_clamp_f16(params, dst);
+            } break;
         case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
index 8009a3e3d8607cfc8b682ac0affee9a84c11e71d..611db88644e443b6b574f7e31b1d0d346eb7c6f6 100644 (file)
@@ -1,6 +1,7 @@
 #include "clamp.cuh"
 
-static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+template <class T>
+static __global__ void op_clamp(const T * x, T * dst, const T min, const T max, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -10,25 +11,31 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
-static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+template <class T>
+static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
-    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+    op_clamp<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
 }
 
 
 void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
     float min;
     float max;
     memcpy(&min, dst->op_params, sizeof(float));
     memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
-    clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream);
+    } else {
+        clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream);
+    }
 }
index d23686d161773ef765b2d1746a39eaf2972d5177..a28b7037bd3a05348db9a5fca67475b254e13166 100644 (file)
@@ -2147,6 +2147,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(dst)) {
+                case GGML_UNARY_OP_ABS:
+                    ggml_cuda_op_abs(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SGN:
+                    ggml_cuda_op_sgn(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_NEG:
                     ggml_cuda_op_neg(ctx, dst);
                     break;
@@ -2244,6 +2250,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
+        case GGML_OP_LOG:
+            ggml_cuda_op_log(ctx, dst);
+            break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -2962,6 +2971,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
@@ -3168,6 +3179,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LOG:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
index 6b21f407d80490086d08aa50c68299e4fd13316a..9b0eaaccd05202158ce2f2463e9544673b985437 100644 (file)
@@ -1,6 +1,29 @@
 #include "unary.cuh"
 
-static __global__ void neg_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_abs(const T * x, T * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = fabsf(x[i]);
+}
+
+template <class T>
+static __global__ void op_sgn(const T * x, T * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = (T)(x[i] > (T)0.f ? 1.f : ((x[i] < (T)0.f ? -1.f : 0.f)));
+}
+
+template <class T>
+static __global__ void op_neg(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -10,61 +33,67 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) {
     dst[i] = -x[i];
 }
 
-static __global__ void step_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_step(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    dst[i] = x[i] > 0.0f;
+    dst[i] = x[i] > (T)0.0f;
 }
 
-static __global__ void gelu_f32(const float * x, float * dst, const int k) {
-    const float GELU_COEF_A    = 0.044715f;
-    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+template <class T>
+static __global__ void op_gelu(const T * x, T * dst, const int k) {
+    const T GELU_COEF_A    = 0.044715f;
+    const T SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    float xi = x[i];
-    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+    T xi = x[i];
+    dst[i] = (T)0.5f*xi*((T)1.0f + (T)tanhf(SQRT_2_OVER_PI*xi*((T)1.0f + GELU_COEF_A*xi*xi)));
 }
 
-static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
-    const float GELU_QUICK_COEF = -1.702f;
+template <class T>
+static __global__ void op_gelu_quick(const T * x, T * dst, int k) {
+    const T GELU_QUICK_COEF = -1.702f;
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
     }
-    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+    dst[i] = x[i] * ((T)1.0f / ((T)1.0f + (T)expf(GELU_QUICK_COEF * x[i])));
 }
 
-static __global__ void silu_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_silu(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
-    dst[i] = x[i] / (1.0f + expf(-x[i]));
+    dst[i] = x[i] / ((T)1.0f + (T)expf(-x[i]));
 }
 
-static __global__ void silu_back_f32(
-        const float * grad, const float * xf, float * dst, const int k) {
+template <class T>
+static __global__ void op_silu_back(
+        const T * grad, const T * xf, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    const float xfi = xf[i];
-    const float s = 1.0f / (1.0f + expf(-xfi));
-    dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
+    const T xfi = xf[i];
+    const T s = (T)1.0f / ((T)1.0f + (T)expf(-xfi));
+    dst[i] = grad[i] * s * ((T)1.0f + xfi * ((T)1.0f - s));
 }
 
-static __global__ void tanh_f32(const float * x, float * dst, int k) {
+template <class T>
+static __global__ void op_tanh(const T * x, T * dst, int k) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
@@ -72,7 +101,8 @@ static __global__ void tanh_f32(const float * x, float * dst, int k) {
     dst[i] = tanhf(x[i]);
 }
 
-static __global__ void relu_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_relu(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -81,34 +111,38 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
-static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_sigmoid(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
-    dst[i] = 1.0f / (1.0f + expf(-x[i]));
+    dst[i] = (T)1.0f / ((T)1.0f + (T)expf(-x[i]));
 }
 
-static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_hardsigmoid(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
-    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + (T)3.0f) / (T)6.0f));
 }
 
-static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_hardswish(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
-    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+    dst[i] = x[i] * (T)fminf(1.0f, fmaxf(0.0f, (x[i] + (T)3.0f) / (T)6.0f));
 }
 
-static __global__ void exp_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_exp(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -117,15 +151,17 @@ static __global__ void exp_f32(const float * x, float * dst, const int k) {
     dst[i] = expf(x[i]);
 }
 
-static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
+template <class T>
+static __global__ void op_leaky_relu(const T * x, T * dst, const int k, const float negative_slope) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
         return;
     }
-    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
+    dst[i] = (T)fmaxf(x[i], 0) + (T)fminf(x[i], 0.0f) * (T)negative_slope;
 }
 
-static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_sqr(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -134,7 +170,8 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_sqrt(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -143,7 +180,8 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
     dst[i] = sqrtf(x[i]);
 }
 
-static __global__ void sin_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_sin(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -152,7 +190,8 @@ static __global__ void sin_f32(const float * x, float * dst, const int k) {
     dst[i] = sinf(x[i]);
 }
 
-static __global__ void cos_f32(const float * x, float * dst, const int k) {
+template <class T>
+static __global__ void op_cos(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -161,145 +200,248 @@ static __global__ void cos_f32(const float * x, float * dst, const int k) {
     dst[i] = cosf(x[i]);
 }
 
-static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static __global__ void op_log(const T * x, T * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = logf(x[i]);
+}
+
+template <class T>
+static void abs_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+    op_abs<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+template <class T>
+static void sgn_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+    op_sgn<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+template <class T>
+static void neg_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
-    neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_neg<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void step_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
-    step_f32<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_step<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void gelu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
-    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_gelu<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void gelu_quick_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
-    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_gelu_quick<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void silu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
-    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_silu<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
-    silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
+    op_silu_back<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
 }
 
-static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void tanh_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
-    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_tanh<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void relu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
-    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_relu<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void sigmoid_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
-    sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_sigmoid<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void hardsigmoid_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
-    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_hardsigmoid<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void hardswish_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
-    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_hardswish<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void exp_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
-    exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_exp<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
+template <class T>
+static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
-    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+    op_leaky_relu<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
 }
 
-static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void sqr_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
-    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_sqr<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void sqrt_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
-    sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_sqrt<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void sin_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
-    sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_sin<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
-static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+template <class T>
+static void cos_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
-    cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+    op_cos<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+template <class T>
+static void log_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
+    op_log<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (src0->type == GGML_TYPE_F16) {
+        abs_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        abs_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
+}
+
+void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (src0->type == GGML_TYPE_F16) {
+        sgn_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        sgn_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        neg_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        neg_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        step_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        step_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        gelu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        gelu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        silu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        silu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -314,179 +456,263 @@ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        silu_back_cuda((const float*)src0_d, (const float*)src1_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        gelu_quick_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        gelu_quick_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        tanh_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        tanh_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        sigmoid_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        sigmoid_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        hardsigmoid_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        hardsigmoid_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        hardswish_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        hardswish_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        exp_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        exp_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
     float negative_slope;
     memcpy(&negative_slope, dst->op_params, sizeof(float));
 
-    leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
+    if (src0->type == GGML_TYPE_F16) {
+        leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream);
+    } else {
+        leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
+    }
 }
 
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        sqr_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        sqr_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        sqrt_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        sqrt_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        sin_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        sin_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
 
 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
 
-    cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+    if (src0->type == GGML_TYPE_F16) {
+        cos_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        cos_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
+}
+
+void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    if (src0->type == GGML_TYPE_F16) {
+        log_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+    } else {
+        log_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+    }
 }
index e7f62643a2afd65441b014732690761f3c6f9eeb..940a1feed9a9c411e2a2cedecf6e419f663890e5 100644 (file)
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
 
+void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -49,3 +53,5 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index c550142a7d07c635c3e048973826f3899a2c1fb2..9eb9a0db68f82b09ed4ee76399cde6f6c7605bba 100644 (file)
@@ -1200,7 +1200,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
-                    return ggml_is_contiguous(op->src[0]);
+                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
             }
@@ -1210,21 +1210,26 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_TRANSPOSE:
         case GGML_OP_PERMUTE:
         case GGML_OP_CONCAT:
+            return true;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
-        case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
-        case GGML_OP_CLAMP:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
+        case GGML_OP_CLAMP:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
-            return ggml_is_contiguous(op->src[0]);
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_LOG:
+            return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
@@ -1254,10 +1259,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_PAD_REFLECT_1D:
-        case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             if (op->src[1]->type != op->src[2]->type) {
index e2412bf31165a4f46ce06b6fac63cbf1e55fb27d..c393bf16b8a1f6e29a479b2fcf050d6a9c392ca8 100644 (file)
@@ -3771,10 +3771,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     std::default_random_engine rng(0);
 
     // unary ops
-    for (int v : {0, 1}) {
-        for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
-            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 2, 2, 2 }, v));
-            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 5, 7, 11, 13 }, v));
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        for (int v : {0, 1}) {
+            for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
+                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 128, 2, 2, 2 }, v));
+                test_cases.emplace_back(new test_unary((ggml_unary_op) op, type, { 5, 7, 11, 13 }, v));
+            }
         }
     }
 
@@ -3996,7 +3998,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
-    test_cases.emplace_back(new test_silu_back());
+
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        test_cases.emplace_back(new test_silu_back());
+    }
 
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
         for (bool v : {false, true}) {
@@ -4156,12 +4161,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
-    test_cases.emplace_back(new test_sqr());
-    test_cases.emplace_back(new test_sqrt());
-    test_cases.emplace_back(new test_log());
-    test_cases.emplace_back(new test_sin());
-    test_cases.emplace_back(new test_cos());
-    test_cases.emplace_back(new test_clamp());
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        test_cases.emplace_back(new test_sqr(type));
+        test_cases.emplace_back(new test_sqrt(type));
+        test_cases.emplace_back(new test_log(type));
+        test_cases.emplace_back(new test_sin(type));
+        test_cases.emplace_back(new test_cos(type));
+        test_cases.emplace_back(new test_clamp(type));
+    }
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 3, 1}, 5));