]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : sync (unary ops refactor, static-correctness) (#2370)
authorGeorgi Gerganov <redacted>
Mon, 24 Jul 2023 11:46:21 +0000 (14:46 +0300)
committerGitHub <redacted>
Mon, 24 Jul 2023 11:46:21 +0000 (14:46 +0300)
* ggml : sync (unary ops, tests)

ggml-ci

* tests : remove unnecessary funcs

ggml-cuda.cu
ggml-metal.m
ggml.c
ggml.h
tests/test-grad0.c
tests/test-opt.c

index 6823adc6cc958172e16e728d5d75a6ab15dfb06f..b8c98354da192ae5a476995e90b122a0c904d582 100644 (file)
@@ -3962,18 +3962,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
-        case GGML_OP_GELU:
-            if (!any_on_device) {
-                return false;
-            }
-            func = ggml_cuda_gelu;
-            break;
-        case GGML_OP_SILU:
-            if (!any_on_device) {
-                return false;
-            }
-            func = ggml_cuda_silu;
-            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cuda_gelu;
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cuda_silu;
+                    break;
+                default:
+                    return false;
+            } break;
         case GGML_OP_NORM:
             if (!any_on_device) {
                 return false;
index bf3f68fe45726ca93a10b9293200d11c717e7789..1fd6e857ffe6138dc2a5be702bebcd57eb5c6fc5 100644 (file)
@@ -519,48 +519,56 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
-                    case GGML_OP_SILU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_silu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
-                    case GGML_OP_RELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_relu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    case GGML_OP_UNARY:
+                        switch (ggml_get_unary_op(gf->nodes[i])) {
+                            case GGML_UNARY_OP_SILU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoder];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_RELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoder];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_relu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            case GGML_UNARY_OP_GELU:
+                                {
+                                    if (encoder == nil) {
+                                        encoder = [command_buffer computeCommandEncoder];
+                                    }
+
+                                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                    const int64_t n = ggml_nelements(dst);
+
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                } break;
+                            default:
+                                {
+                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                    GGML_ASSERT(false);
+                                }
                         } break;
-                    case GGML_OP_GELU:
-                    {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_gelu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                    } break;
                     case GGML_OP_SOFT_MAX:
                         {
                             if (encoder == nil) {
@@ -979,8 +987,10 @@ void ggml_metal_graph_compute(
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     default:
-                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                        GGML_ASSERT(false);
+                        {
+                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            GGML_ASSERT(false);
+                        }
                 }
             }
 
diff --git a/ggml.c b/ggml.c
index 9ee4a8d7f687bae9c572e58f724636bfa31cbf28..960b8057709a987e8aa83395fbe51a5c597fe21f 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3440,7 +3440,9 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
 
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
-#if defined(GGML_SIMD)
+#if defined(GGML_USE_ACCELERATE)
+    vDSP_vsmul(y, 1, &v, y, 1, n);
+#elif defined(GGML_SIMD)
     const int np = (n & ~(GGML_F32_STEP - 1));
 
     GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
@@ -3603,7 +3605,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #endif
 }
 
-inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
+inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
     ggml_float sum = 0.0;
     for (int i = 0; i < n; ++i) {
         sum += (ggml_float)x[i];
@@ -3611,6 +3613,14 @@ inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x
     *s = sum;
 }
 
+inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
+    float sum = 0.0f;
+    for (int i = 0; i < n; ++i) {
+        sum += GGML_FP16_TO_FP32(x[i]);
+    }
+    *s = sum;
+}
+
 inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     float max = -INFINITY;
@@ -3750,16 +3760,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ARGMAX",
     "REPEAT",
     "REPEAT_BACK",
-    "ABS",
-    "SGN",
-    "NEG",
-    "STEP",
-    "TANH",
-    "ELU",
-    "RELU",
-    "GELU",
-    "GELU_QUICK",
-    "SILU",
     "SILU_BACK",
     "NORM",
     "RMS_NORM",
@@ -3798,6 +3798,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_PART",
     "WIN_UNPART",
 
+    "UNARY",
+
     "MAP_UNARY",
     "MAP_BINARY",
 
@@ -3809,7 +3811,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 59, "GGML_OP_COUNT != 59");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3830,16 +3832,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "argmax(x)",
     "repeat(x)",
     "repeat_back(x)",
-    "abs(x)",
-    "sgn(x)",
-    "-x",
-    "step(x)",
-    "tanh(x)",
-    "elu(x)",
-    "relu(x)",
-    "gelu(x)",
-    "gelu_quick(x)",
-    "silu(x)",
     "silu_back(x)",
     "norm(x)",
     "rms_norm(x)",
@@ -3878,6 +3870,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_part(x)",
     "win_unpart(x)",
 
+    "unary(x)",
+
     "f(x)",
     "f(x,y)",
 
@@ -3889,7 +3883,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+static_assert(GGML_OP_COUNT == 59, "GGML_OP_COUNT != 59");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4145,6 +4139,10 @@ const char * ggml_op_name(enum ggml_op op) {
     return GGML_OP_NAME[op];
 }
 
+const char * ggml_op_symbol(enum ggml_op op) {
+    return GGML_OP_SYMBOL[op];
+}
+
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
     return GGML_TYPE_SIZE[tensor->type];
 }
@@ -4443,6 +4441,10 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
     return result;
 }
 
+bool ggml_get_no_alloc(struct ggml_context * ctx) {
+    return ctx->no_alloc;
+}
+
 void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
     ctx->no_alloc = no_alloc;
 }
@@ -4480,7 +4482,7 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
 // this is an error prone process, but it is necessary to support inplace
 // operators when using scratch buffers
 // TODO: implement a better way
-void ggml_scratch_save(struct ggml_context * ctx) {
+static void ggml_scratch_save(struct ggml_context * ctx) {
     // this is needed to allow opt tensors to store their data
     // TODO: again, need to find a better way
     ctx->no_alloc_save = ctx->no_alloc;
@@ -4490,7 +4492,7 @@ void ggml_scratch_save(struct ggml_context * ctx) {
     ctx->scratch.data = NULL;
 }
 
-void ggml_scratch_load(struct ggml_context * ctx) {
+static void ggml_scratch_load(struct ggml_context * ctx) {
     ctx->no_alloc = ctx->no_alloc_save;
 
     ctx->scratch = ctx->scratch_save;
@@ -4498,7 +4500,7 @@ void ggml_scratch_load(struct ggml_context * ctx) {
 
 ////////////////////////////////////////////////////////////////////////////////
 
-struct ggml_tensor * ggml_new_tensor_impl(
+static struct ggml_tensor * ggml_new_tensor_impl(
         struct ggml_context * ctx,
         enum   ggml_type type,
         int    n_dims,
@@ -4621,6 +4623,21 @@ struct ggml_tensor * ggml_new_tensor_impl(
     return result;
 }
 
+static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
+    assert(params_size <= GGML_MAX_OP_PARAMS);
+    memcpy(tensor->op_params, params, params_size);
+}
+
+static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    return ((const int32_t *)(tensor->op_params))[i];
+}
+
+static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    ((int32_t *)(tensor->op_params))[i] = value;
+}
+
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,
@@ -4952,6 +4969,16 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
     return (float *)(tensor->data);
 }
 
+enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_UNARY);
+    return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
+}
+
+static void ggml_set_unary_op(struct ggml_tensor * tensor, enum ggml_unary_op op) {
+    GGML_ASSERT(tensor->op = GGML_OP_UNARY);
+    ggml_set_op_params_i32(tensor, 0, (int32_t) op);
+}
+
 const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
@@ -4970,11 +4997,6 @@ struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char *
     return tensor;
 }
 
-static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
-    assert(params_size <= GGML_MAX_OP_PARAMS);
-    memcpy(tensor->op_params, params, params_size);
-}
-
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
         const struct ggml_tensor * src) {
@@ -5010,7 +5032,7 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam
 
 // ggml_dup
 
-struct ggml_tensor * ggml_dup_impl(
+static struct ggml_tensor * ggml_dup_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         bool inplace) {
@@ -5043,7 +5065,7 @@ struct ggml_tensor * ggml_dup_inplace(
 
 // ggml_add
 
-struct ggml_tensor * ggml_add_impl(
+static struct ggml_tensor * ggml_add_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
@@ -5086,7 +5108,7 @@ struct ggml_tensor * ggml_add_inplace(
 
 // ggml_add1
 
-struct ggml_tensor * ggml_add1_impl(
+static struct ggml_tensor * ggml_add1_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
@@ -5126,7 +5148,7 @@ struct ggml_tensor * ggml_add1_inplace(
 
 // ggml_acc
 
-struct ggml_tensor * ggml_acc_impl(
+static struct ggml_tensor * ggml_acc_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
@@ -5183,7 +5205,7 @@ struct ggml_tensor * ggml_acc_inplace(
 
 // ggml_sub
 
-struct ggml_tensor * ggml_sub_impl(
+static struct ggml_tensor * ggml_sub_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
@@ -5222,7 +5244,7 @@ struct ggml_tensor * ggml_sub_inplace(
 
 // ggml_mul
 
-struct ggml_tensor * ggml_mul_impl(
+static struct ggml_tensor * ggml_mul_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
@@ -5269,7 +5291,7 @@ struct ggml_tensor * ggml_mul_inplace(
 
 // ggml_div
 
-struct ggml_tensor * ggml_div_impl(
+static struct ggml_tensor * ggml_div_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
@@ -5312,7 +5334,7 @@ struct ggml_tensor * ggml_div_inplace(
 
 // ggml_sqr
 
-struct ggml_tensor * ggml_sqr_impl(
+static struct ggml_tensor * ggml_sqr_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         bool inplace) {
@@ -5345,7 +5367,7 @@ struct ggml_tensor * ggml_sqr_inplace(
 
 // ggml_sqrt
 
-struct ggml_tensor * ggml_sqrt_impl(
+static struct ggml_tensor * ggml_sqrt_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         bool inplace) {
@@ -5379,7 +5401,7 @@ struct ggml_tensor * ggml_sqrt_inplace(
 
 // ggml_log
 
-struct ggml_tensor * ggml_log_impl(
+static struct ggml_tensor * ggml_log_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         bool inplace) {
@@ -5559,333 +5581,142 @@ struct ggml_tensor * ggml_repeat_back(
 
 // ggml_abs
 
-struct ggml_tensor * ggml_abs_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_ABS;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_abs(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_abs_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_ABS);
 }
 
 struct ggml_tensor * ggml_abs_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_abs_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);
 }
 
-
 // ggml_sgn
 
-struct ggml_tensor * ggml_sgn_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_SGN;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_sgn(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_sgn_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SGN);
 }
 
 struct ggml_tensor * ggml_sgn_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_sgn_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN);
 }
 
 // ggml_neg
 
-struct ggml_tensor * ggml_neg_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_NEG;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_neg(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_neg_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_NEG);
 }
 
 struct ggml_tensor * ggml_neg_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_neg_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG);
 }
 
 // ggml_step
 
-struct ggml_tensor * ggml_step_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_STEP;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_step(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_step_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_STEP);
 }
 
 struct ggml_tensor * ggml_step_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_step_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP);
 }
 
 // ggml_tanh
 
-struct ggml_tensor * ggml_tanh_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_TANH;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_tanh(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_tanh_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_TANH);
 }
 
 struct ggml_tensor * ggml_tanh_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_tanh_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH);
 }
 
 // ggml_elu
 
-struct ggml_tensor * ggml_elu_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_ELU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_elu(
     struct ggml_context * ctx,
     struct ggml_tensor  * a) {
-    return ggml_elu_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_ELU);
 }
 
 struct ggml_tensor * ggml_elu_inplace(
     struct ggml_context * ctx,
     struct ggml_tensor  * a) {
-    return ggml_elu_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU);
 }
 
 // ggml_relu
 
-struct ggml_tensor * ggml_relu_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_RELU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_relu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_relu_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_RELU);
 }
 
 struct ggml_tensor * ggml_relu_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_relu_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }
 
 // ggml_gelu
 
-struct ggml_tensor * ggml_gelu_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_GELU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_gelu_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU);
 }
 
 struct ggml_tensor * ggml_gelu_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_gelu_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
 }
 
 // ggml_gelu_quick
 
-struct ggml_tensor * ggml_gelu_quick_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_GELU_QUICK;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_gelu_quick(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_gelu_quick_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK);
 }
 
 struct ggml_tensor * ggml_gelu_quick_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_gelu_quick_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK);
 }
 
 // ggml_silu
 
-struct ggml_tensor * ggml_silu_impl(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        bool inplace) {
-    bool is_node = false;
-
-    if (!inplace && (a->grad)) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    result->op   = GGML_OP_SILU;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 struct ggml_tensor * ggml_silu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_silu_impl(ctx, a, false);
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
 }
 
 struct ggml_tensor * ggml_silu_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_silu_impl(ctx, a, true);
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
 }
 
 // ggml_silu_back
@@ -5913,7 +5744,7 @@ struct ggml_tensor * ggml_silu_back(
 
 // ggml_norm
 
-struct ggml_tensor * ggml_norm_impl(
+static struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         bool inplace) {
@@ -5947,7 +5778,7 @@ struct ggml_tensor * ggml_norm_inplace(
     return ggml_norm_impl(ctx, a, true);
 }
 
-struct ggml_tensor * ggml_rms_norm_impl(
+static struct ggml_tensor * ggml_rms_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         bool inplace) {
@@ -6056,7 +5887,7 @@ struct ggml_tensor * ggml_out_prod(
 
 // ggml_scale
 
-struct ggml_tensor * ggml_scale_impl(
+static struct ggml_tensor * ggml_scale_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -6096,7 +5927,7 @@ struct ggml_tensor * ggml_scale_inplace(
 
 // ggml_set
 
-struct ggml_tensor * ggml_set_impl(
+static struct ggml_tensor * ggml_set_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -6186,7 +6017,7 @@ struct ggml_tensor * ggml_set_2d_inplace(
 
 // ggml_cpy
 
-struct ggml_tensor * ggml_cpy_impl(
+static struct ggml_tensor * ggml_cpy_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -6231,7 +6062,7 @@ struct ggml_tensor * ggml_cpy_inplace(
 
 // ggml_cont
 
-struct ggml_tensor * ggml_cont_impl(
+static struct ggml_tensor * ggml_cont_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         bool inplace) {
@@ -6701,7 +6532,7 @@ struct ggml_tensor * ggml_diag(
 
 // ggml_diag_mask_inf
 
-struct ggml_tensor * ggml_diag_mask_inf_impl(
+static struct ggml_tensor * ggml_diag_mask_inf_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past,
@@ -6741,7 +6572,7 @@ struct ggml_tensor * ggml_diag_mask_inf_inplace(
 
 // ggml_diag_mask_zero
 
-struct ggml_tensor * ggml_diag_mask_zero_impl(
+static struct ggml_tensor * ggml_diag_mask_zero_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past,
@@ -6780,7 +6611,7 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
 
 // ggml_soft_max
 
-struct ggml_tensor * ggml_soft_max_impl(
+static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         bool                  inplace) {
@@ -6814,7 +6645,7 @@ struct ggml_tensor * ggml_soft_max_inplace(
 
 // ggml_soft_max_back
 
-struct ggml_tensor * ggml_soft_max_back_impl(
+static struct ggml_tensor * ggml_soft_max_back_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -6851,7 +6682,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
 
 // ggml_rope
 
-struct ggml_tensor * ggml_rope_impl(
+static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past,
@@ -7363,9 +7194,47 @@ struct ggml_tensor * ggml_win_unpart(
     return result;
 }
 
+// gmml_unary
+
+static struct ggml_tensor * ggml_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        enum ggml_unary_op op,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_unary_op(result, op);
+
+    result->op   = GGML_OP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op op) {
+    return ggml_unary_impl(ctx, a, op, false);
+}
+
+struct ggml_tensor * ggml_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op op) {
+    return ggml_unary_impl(ctx, a, op, true);
+}
+
 // ggml_map_unary
 
-struct ggml_tensor * ggml_map_unary_impl_f32(
+static struct ggml_tensor * ggml_map_unary_impl_f32(
         struct ggml_context        * ctx,
         struct ggml_tensor         * a,
         const  ggml_unary_op_f32_t fun,
@@ -7403,7 +7272,7 @@ struct ggml_tensor * ggml_map_unary_inplace_f32(
 
 // ggml_map_binary
 
-struct ggml_tensor * ggml_map_binary_impl_f32(
+static struct ggml_tensor * ggml_map_binary_impl_f32(
         struct ggml_context         * ctx,
         struct ggml_tensor          * a,
         struct ggml_tensor          * b,
@@ -7447,7 +7316,7 @@ struct ggml_tensor * ggml_map_binary_inplace_f32(
 
 // ggml_map_custom1
 
-struct ggml_tensor * ggml_map_custom1_impl_f32(
+static struct ggml_tensor * ggml_map_custom1_impl_f32(
         struct ggml_context          * ctx,
         struct ggml_tensor           * a,
         const  ggml_custom1_op_f32_t   fun,
@@ -7485,7 +7354,7 @@ struct ggml_tensor * ggml_map_custom1_inplace_f32(
 
 // ggml_map_custom2
 
-struct ggml_tensor * ggml_map_custom2_impl_f32(
+static struct ggml_tensor * ggml_map_custom2_impl_f32(
         struct ggml_context          * ctx,
         struct ggml_tensor           * a,
         struct ggml_tensor           * b,
@@ -7527,7 +7396,7 @@ struct ggml_tensor * ggml_map_custom2_inplace_f32(
 
 // ggml_map_custom3
 
-struct ggml_tensor * ggml_map_custom3_impl_f32(
+static struct ggml_tensor * ggml_map_custom3_impl_f32(
         struct ggml_context          * ctx,
         struct ggml_tensor           * a,
         struct ggml_tensor           * b,
@@ -9292,7 +9161,7 @@ static void ggml_compute_forward_sum_f32(
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = 0; i01 < ne01; i01++) {
-                ggml_vec_sum_ggf(ne00,
+                ggml_vec_sum_f32_ggf(ne00,
                         &row_sum,
                         (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
                 sum += row_sum;
@@ -9302,6 +9171,38 @@ static void ggml_compute_forward_sum_f32(
     ((float *) dst->data)[0] = sum;
 }
 
+static void ggml_compute_forward_sum_f16(
+    const struct ggml_compute_params * params,
+    const struct ggml_tensor * src0,
+          struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(ggml_is_scalar(dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb);
+
+    float sum = 0;
+    float row_sum = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f16_ggf(ne00,
+                    &row_sum,
+                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
+}
+
 static void ggml_compute_forward_sum(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -9311,6 +9212,10 @@ static void ggml_compute_forward_sum(
             {
                 ggml_compute_forward_sum_f32(params, src0, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sum_f16(params, src0, dst);
+            } break;
         default:
             {
                 GGML_ASSERT(false);
@@ -10077,7 +9982,6 @@ static void ggml_compute_forward_silu(
     }
 }
 
-
 // ggml_compute_forward_silu_back
 
 static void ggml_compute_forward_silu_back_f32(
@@ -14122,6 +14026,62 @@ static void ggml_compute_forward_win_unpart(
     }
 }
 
+//gmml_compute_forward_unary
+
+static void ggml_compute_forward_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    const enum ggml_unary_op op = ggml_get_unary_op(dst);
+
+    switch (op) {
+        case GGML_UNARY_OP_ABS:
+            {
+                ggml_compute_forward_abs(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_SGN:
+            {
+                ggml_compute_forward_sgn(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_NEG:
+            {
+                ggml_compute_forward_neg(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_STEP:
+            {
+                ggml_compute_forward_step(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_TANH:
+            {
+                ggml_compute_forward_tanh(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_ELU:
+            {
+                ggml_compute_forward_elu(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_RELU:
+            {
+                ggml_compute_forward_relu(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_GELU:
+            {
+                ggml_compute_forward_gelu(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, src0, dst);
+            } break;
+        case GGML_UNARY_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -14682,46 +14642,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
             } break;
-        case GGML_OP_ABS:
-            {
-                ggml_compute_forward_abs(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_SGN:
-            {
-                ggml_compute_forward_sgn(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_NEG:
-            {
-                ggml_compute_forward_neg(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_STEP:
-            {
-                ggml_compute_forward_step(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_TANH:
-            {
-                ggml_compute_forward_tanh(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_ELU:
-            {
-                ggml_compute_forward_elu(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_RELU:
-            {
-                ggml_compute_forward_relu(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_GELU:
-            {
-                ggml_compute_forward_gelu(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_GELU_QUICK:
-            {
-                ggml_compute_forward_gelu_quick(params, tensor->src[0], tensor);
-            } break;
-        case GGML_OP_SILU:
-            {
-                ggml_compute_forward_silu(params, tensor->src[0], tensor);
-            } break;
         case GGML_OP_SILU_BACK:
             {
                 ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
@@ -14864,6 +14784,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_win_unpart(params, tensor->src[0], tensor);
             } break;
+        case GGML_OP_UNARY:
+            {
+                ggml_compute_forward_unary(params, tensor->src[0], tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -15112,73 +15036,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             inplace);
                 }
             } break;
-        case GGML_OP_ABS:
-            {
-                if (src0->grad) {
-                    src0->grad =
-                        ggml_add_impl(ctx,
-                                src0->grad,
-                                ggml_mul(ctx,
-                                    ggml_sgn(ctx, src0),
-                                    tensor->grad),
-                                inplace);
-                }
-            } break;
-        case GGML_OP_SGN:
-            {
-                if (src0->grad) {
-                    // noop
-                }
-            } break;
-        case GGML_OP_NEG:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
-                }
-            } break;
-        case GGML_OP_STEP:
-            {
-                if (src0->grad) {
-                    // noop
-                }
-            } break;
-        case GGML_OP_TANH:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_ELU:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_RELU:
-            {
-                if (src0->grad) {
-                    src0->grad = ggml_sub_impl(ctx,
-                            src0->grad,
-                            ggml_mul(ctx,
-                                ggml_step(ctx, src0),
-                                tensor->grad),
-                            inplace);
-                }
-            } break;
-        case GGML_OP_GELU:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_GELU_QUICK:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_SILU:
-            {
-                // necessary for llama
-                if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
-                            src0->grad,
-                            ggml_silu_back(ctx, src0, tensor->grad),
-                            inplace);
-                }
-            } break;
         case GGML_OP_SILU_BACK:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -15440,9 +15297,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
                         inplace);
                 }
-                if (src1->grad) {
-                    // noop
-                }
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
             {
@@ -15454,9 +15308,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
                         inplace);
                 }
-                if (src1->grad) {
-                    // noop
-                }
             } break;
         case GGML_OP_SOFT_MAX:
             {
@@ -15491,9 +15342,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_ctx),
                             inplace);
                 }
-                if (src1->grad) {
-                    // noop
-                }
             } break;
         case GGML_OP_ROPE_BACK:
             {
@@ -15512,9 +15360,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_ctx),
                             inplace);
                 }
-                if (src1->grad) {
-                    // noop
-                }
             } break;
         case GGML_OP_ALIBI:
             {
@@ -15707,6 +15552,80 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
+        case GGML_OP_UNARY:
+            {
+                switch (ggml_get_unary_op(tensor)) {
+                    case GGML_UNARY_OP_ABS:
+                        {
+                            if (src0->grad) {
+                                src0->grad =
+                                    ggml_add_impl(ctx,
+                                            src0->grad,
+                                            ggml_mul(ctx,
+                                                ggml_sgn(ctx, src0),
+                                                tensor->grad),
+                                            inplace);
+                            }
+                        } break;
+                    case GGML_UNARY_OP_SGN:
+                        {
+                            if (src0->grad) {
+                                // noop
+                            }
+                        } break;
+                    case GGML_UNARY_OP_NEG:
+                        {
+                            if (src0->grad) {
+                                src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+                            }
+                        } break;
+                    case GGML_UNARY_OP_STEP:
+                        {
+                            if (src0->grad) {
+                                // noop
+                            }
+                        } break;
+                    case GGML_UNARY_OP_TANH:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
+                    case GGML_UNARY_OP_ELU:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
+                    case GGML_UNARY_OP_RELU:
+                        {
+                            if (src0->grad) {
+                                src0->grad = ggml_add_impl(ctx,
+                                        src0->grad,
+                                        ggml_mul(ctx,
+                                            ggml_step(ctx, src0),
+                                            tensor->grad),
+                                        inplace);
+                            }
+                        } break;
+                    case GGML_UNARY_OP_GELU:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
+                    case GGML_UNARY_OP_GELU_QUICK:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
+                    case GGML_UNARY_OP_SILU:
+                        {
+                            // necessary for llama
+                            if (src0->grad) {
+                                src0->grad = ggml_add_impl(ctx,
+                                        src0->grad,
+                                        ggml_silu_back(ctx, src0, tensor->grad),
+                                        inplace);
+                            }
+                        } break;
+                    default:
+                        GGML_ASSERT(false);
+                }
+            } break;
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1:
@@ -15937,7 +15856,7 @@ typedef pthread_t ggml_thread_t;
 
 // Android's libc implementation "bionic" does not support setting affinity
 #if defined(__linux__) && !defined(__BIONIC__)
-void set_numa_thread_affinity(int thread_n, int n_threads) {
+static void set_numa_thread_affinity(int thread_n, int n_threads) {
     if (!ggml_is_numa()) {
         return;
     }
@@ -15962,7 +15881,7 @@ void set_numa_thread_affinity(int thread_n, int n_threads) {
     CPU_FREE(cpus);
 }
 
-void clear_numa_thread_affinity(void) {
+static void clear_numa_thread_affinity(void) {
     if (!ggml_is_numa()) {
         return;
     }
@@ -15986,8 +15905,8 @@ void clear_numa_thread_affinity(void) {
 #else
 // TODO: Windows etc.
 // (the linux implementation may also work on BSD, someone should test)
-void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads);  }
-void clear_numa_thread_affinity(void) {}
+static void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads);  }
+static void clear_numa_thread_affinity(void) {}
 #endif
 
 struct ggml_compute_state_shared {
@@ -16199,21 +16118,34 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_ARGMAX:
             case GGML_OP_REPEAT:
             case GGML_OP_REPEAT_BACK:
-            case GGML_OP_ABS:
-            case GGML_OP_SGN:
-            case GGML_OP_NEG:
-            case GGML_OP_STEP:
-            case GGML_OP_TANH:
-            case GGML_OP_ELU:
-            case GGML_OP_RELU:
-                {
+            {
                     n_tasks = 1;
                 } break;
-            case GGML_OP_MUL:
-            case GGML_OP_GELU:
-            case GGML_OP_GELU_QUICK:
-            case GGML_OP_SILU:
+
+            case GGML_OP_UNARY:
+                {
+                    switch (ggml_get_unary_op(node)) {
+                        case GGML_UNARY_OP_ABS:
+                        case GGML_UNARY_OP_SGN:
+                        case GGML_UNARY_OP_NEG:
+                        case GGML_UNARY_OP_STEP:
+                        case GGML_UNARY_OP_TANH:
+                        case GGML_UNARY_OP_ELU:
+                        case GGML_UNARY_OP_RELU:
+                            {
+                                n_tasks = 1;
+                            } break;
+
+                        case GGML_UNARY_OP_GELU:
+                        case GGML_UNARY_OP_GELU_QUICK:
+                        case GGML_UNARY_OP_SILU:
+                            {
+                                n_tasks = n_threads;
+                            } break;
+                    }
+                } break;
             case GGML_OP_SILU_BACK:
+            case GGML_OP_MUL:
             case GGML_OP_NORM:
             case GGML_OP_RMS_NORM:
             case GGML_OP_RMS_NORM_BACK:
@@ -16728,7 +16660,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
                     fwrite(&nb, sizeof(uint64_t), 1, fout);
                 }
 
-                fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
+                fwrite(tensor->name,      sizeof(char), GGML_MAX_NAME,      fout);
+                fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout);
 
                 // dump the data
                 // TODO: pad this to 32 byte boundary
@@ -16761,7 +16694,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
                     fwrite(&nb, sizeof(uint64_t), 1, fout);
                 }
 
-                fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout);
+                fwrite(tensor->name,      sizeof(char), GGML_MAX_NAME,      fout);
+                fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout);
 
                 // output the op arguments
                 {
@@ -16942,7 +16876,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
 
                 tensor->op = (enum ggml_op) op;
 
-                memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME;
+                memcpy(tensor->name,      ptr, GGML_MAX_NAME);      ptr += GGML_MAX_NAME;
+                memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS;
 
                 tensor->data = (void *) ptr;
 
@@ -16987,7 +16922,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                     nb[j] = nb_cur;
                 }
 
-                const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
+                const char * ptr_name      = ptr; ptr += GGML_MAX_NAME;
+                const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS;
 
                 const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t);
 
@@ -17024,8 +16960,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                         {
                             tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0);
 
-                            uint64_t offs;
-                            memcpy(&offs, tensor->op_params, sizeof(offs));
+                            size_t offs;
+                            memcpy(&offs, ptr_op_params, sizeof(offs));
 
                             tensor->data = ((char *) tensor->data) + offs;
                         } break;
@@ -17045,7 +16981,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
                         } break;
                 }
 
-                memcpy(tensor->name, ptr_name, GGML_MAX_NAME);
+                memcpy(tensor->name,      ptr_name,      GGML_MAX_NAME);
+                memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS);
 
                 for (int j = 0; j < GGML_MAX_DIMS; ++j) {
                     tensor->nb[j] = nb[j];
@@ -17079,7 +17016,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
                 i,
                 node->ne[0], node->ne[1], node->ne[2],
-                GGML_OP_NAME[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
+                ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms(),
                 (double) node->perf_cycles  / (double) ggml_cycles_per_ms() / (double) node->perf_runs,
                 (double) node->perf_time_us / 1000.0,
@@ -17093,7 +17030,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
                 i,
                 node->ne[0], node->ne[1],
-                GGML_OP_NAME[node->op]);
+                ggml_op_name(node->op));
     }
 
     for (int i = 0; i < GGML_OP_COUNT; i++) {
@@ -17101,7 +17038,7 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
             continue;
         }
 
-        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_NAME[i], (double) perf_total_per_op_us[i] / 1000.0);
+        GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0);
     }
 
     GGML_PRINT("========================================\n");
@@ -17195,13 +17132,13 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         }
 
         if (node->n_dims == 2) {
-            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], GGML_OP_SYMBOL[node->op]);
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op));
         } else {
-            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], GGML_OP_SYMBOL[node->op]);
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
         }
 
         if (node->grad) {
-            fprintf(fp, " | <g>%s\"; ]\n", GGML_OP_SYMBOL[node->grad->op]);
+            fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(node->grad->op));
         } else {
             fprintf(fp, "\"; ]\n");
         }
diff --git a/ggml.h b/ggml.h
index 871c85a89aae796f841fce9bd9bee0284929f2c6..de44fba9e0961886ba2d35046e60dc7934f05961 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -330,16 +330,6 @@ extern "C" {
         GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
-        GGML_OP_ABS,
-        GGML_OP_SGN,
-        GGML_OP_NEG,
-        GGML_OP_STEP,
-        GGML_OP_TANH,
-        GGML_OP_ELU,
-        GGML_OP_RELU,
-        GGML_OP_GELU,
-        GGML_OP_GELU_QUICK,
-        GGML_OP_SILU,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
@@ -378,6 +368,8 @@ extern "C" {
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
 
+        GGML_OP_UNARY,
+
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
 
@@ -391,6 +383,18 @@ extern "C" {
         GGML_OP_COUNT,
     };
 
+    enum ggml_unary_op {
+        GGML_UNARY_OP_ABS,
+        GGML_UNARY_OP_SGN,
+        GGML_UNARY_OP_NEG,
+        GGML_UNARY_OP_STEP,
+        GGML_UNARY_OP_TANH,
+        GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_QUICK,
+        GGML_UNARY_OP_SILU,
+    };
 
     // ggml object
     struct ggml_object {
@@ -535,6 +539,7 @@ extern "C" {
 
     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
+    GGML_API const char * ggml_op_symbol(enum ggml_op   op);
 
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
@@ -558,6 +563,7 @@ extern "C" {
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
 
     GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);
     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
 
     GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
@@ -617,9 +623,11 @@ extern "C" {
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
-    GGML_API const char *         ggml_get_name(const struct ggml_tensor * tensor);
-    GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
-    GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...);
+    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+
+    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
+    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);
 
     //
     // operations on tensors with backpropagation
@@ -629,6 +637,11 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_dup_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_add(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -952,11 +965,22 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // a -> b, in-place, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // make contiguous
     GGML_API struct ggml_tensor * ggml_cont(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // make contiguous, in-place
+    GGML_API struct ggml_tensor * ggml_cont_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return view(a), b specifies the new shape
     // TODO: when we start computing gradient, make a copy instead of view
     GGML_API struct ggml_tensor * ggml_reshape(
@@ -1268,6 +1292,16 @@ extern "C" {
     typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
     typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
 
+    GGML_API struct ggml_tensor * ggml_unary(
+            struct ggml_context * ctx,
+             struct ggml_tensor * a,
+             enum ggml_unary_op op);
+
+    GGML_API struct ggml_tensor * ggml_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op op);
+
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context        * ctx,
             struct ggml_tensor         * a,
index 01467bc184372e5c36c069c22a6791d4ec6f2163..ef20bce516662e645395475e3dc6fdf01756b4ab 100644 (file)
@@ -64,7 +64,7 @@ void get_random_dims(int64_t * dims, int ndims) {
     }
 }
 
-struct ggml_tensor * get_random_tensor(
+struct ggml_tensor * get_random_tensor_f32(
         struct ggml_context * ctx0,
         int ndims,
         int64_t ne[],
@@ -112,7 +112,55 @@ struct ggml_tensor * get_random_tensor(
     return result;
 }
 
-struct ggml_tensor * get_random_tensor_int(
+struct ggml_tensor * get_random_tensor_f16(
+        struct ggml_context * ctx0,
+        int ndims,
+        int64_t ne[],
+        float fmin,
+        float fmax) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne);
+
+    switch (ndims) {
+        case 1:
+            for (int i0 = 0; i0 < ne[0]; i0++) {
+                ((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+            }
+            break;
+        case 2:
+            for (int i1 = 0; i1 < ne[1]; i1++) {
+                for (int i0 = 0; i0 < ne[0]; i0++) {
+                    ((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+                }
+            }
+            break;
+        case 3:
+            for (int i2 = 0; i2 < ne[2]; i2++) {
+                for (int i1 = 0; i1 < ne[1]; i1++) {
+                    for (int i0 = 0; i0 < ne[0]; i0++) {
+                        ((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+                    }
+                }
+            }
+            break;
+        case 4:
+            for (int i3 = 0; i3 < ne[3]; i3++) {
+                for (int i2 = 0; i2 < ne[2]; i2++) {
+                    for (int i1 = 0; i1 < ne[1]; i1++) {
+                        for (int i0 = 0; i0 < ne[0]; i0++) {
+                            ((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
+                        }
+                    }
+                }
+            }
+            break;
+        default:
+            assert(false);
+    };
+
+    return result;
+}
+
+struct ggml_tensor * get_random_tensor_i32(
         struct ggml_context * ctx0,
         int ndims,
         int64_t ne[],
@@ -160,23 +208,6 @@ struct ggml_tensor * get_random_tensor_int(
     return result;
 }
 
-float get_element(const struct ggml_tensor * t, int idx) {
-    if (t->type == GGML_TYPE_F32) {
-        return ((float *)t->data)[idx];
-    }
-
-    if (t->type == GGML_TYPE_I32) {
-        return ((int32_t *)t->data)[idx];
-    }
-
-    assert(false);
-    return INFINITY;
-}
-
-void set_element(struct ggml_tensor * t, int idx, float value) {
-    ((float *)t->data)[idx] = value;
-}
-
 void print_elements(const char* label, const struct ggml_tensor * t) {
     if (!t) {
         printf("%s: %s = null\n", __func__, label);
@@ -186,7 +217,7 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
     printf("%s: %s = [", __func__, label);
     for (int k = 0; k < nelements; ++k) {
         if (k > 0) { printf(", "); }
-        printf("%.5f", get_element(t, k));
+        printf("%.5f", ggml_get_f32_1d(t, k));
     }
     printf("] shape: [");
     for (int k = 0; k < t->n_dims; ++k) {
@@ -237,23 +268,23 @@ bool check_gradient(
         const int nelements = ggml_nelements(x[i]);
         for (int k = 0; k < nelements; ++k) {
             // compute gradient using finite differences
-            const float x0 = get_element(x[i], k);
+            const float x0 = ggml_get_f32_1d(x[i], k);
             const float xm = x0 - eps;
             const float xp = x0 + eps;
-            set_element(x[i], k, xp);
+            ggml_set_f32_1d(x[i], k, xp);
 
             ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f0 = ggml_get_f32_1d(f, 0);
 
-            set_element(x[i], k, xm);
+            ggml_set_f32_1d(x[i], k, xm);
 
             ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f1 = ggml_get_f32_1d(f, 0);
             const float g0 = (f0 - f1)/(2.0f*eps);
 
-            set_element(x[i], k, x0);
+            ggml_set_f32_1d(x[i], k, x0);
 
             // compute gradient using backward graph
             ggml_graph_reset  (&gf);
@@ -261,7 +292,7 @@ bool check_gradient(
 
             ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
-            const float g1 = get_element(x[i]->grad, k);
+            const float g1 = ggml_get_f32_1d(x[i]->grad, k);
 
             const float error_abs = fabsf(g0 - g1);
             const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
@@ -392,19 +423,35 @@ int main(int argc, const char ** argv) {
 
         struct ggml_tensor * x[MAX_NARGS];
 
-        // add
+        // add f32
         {
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
 
-                check_gradient("add", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
+                check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
+            }
+        }
+
+        // add f16
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
+
+                check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
             }
         }
 
@@ -414,7 +461,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -430,7 +477,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -446,7 +493,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, 0.5f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -462,7 +509,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -478,7 +525,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -494,7 +541,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -510,7 +557,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -527,7 +574,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 4; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -537,6 +584,40 @@ int main(int argc, const char ** argv) {
             }
         }
 
+        // mean, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
+
+                check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // argmax
+        if (0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
+
+                check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
         // repeat
         {
             int64_t ne2[4];
@@ -549,15 +630,36 @@ int main(int argc, const char ** argv) {
 
             const int nargs = 1;
             for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
 
                 check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
             }
+        }
+
+        // repeat back
+        {
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+
+            ne2[0] = ne[0] * ne2[0];
+            ne2[1] = ne[1] * ne2[1];
+            ne2[2] = 1;
+            ne2[3] = 1;
+
+            const int nargs = 1;
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                ggml_set_param(ctx0, x[0]);
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
 
+                check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
+            }
         }
 
         // abs (finite differences do not work)
@@ -566,7 +668,7 @@ int main(int argc, const char ** argv) {
 
         //    for (int ndims = 1; ndims <= 2; ++ndims) {
         //        for (int i = 0; i < nargs; ++i) {
-        //            x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+        //            x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
         //            ggml_set_param(ctx0, x[i]);
         //        }
 
@@ -576,17 +678,82 @@ int main(int argc, const char ** argv) {
         //    }
         //}
 
+        // sgn
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
+
+                check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // neg
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
+
+                check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // step
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
+
+                check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // tanh, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
+
+                check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
         // mul_mat
         {
             const int nargs = 2;
 
             for (int ndims = 2; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                 {
                     int64_t ne2[4];
                     get_random_dims(ne2, 4);
                     ne2[0] = ne[0];
-                    x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+                    x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
                 }
 
                 ggml_set_param(ctx0, x[0]);
@@ -602,13 +769,63 @@ int main(int argc, const char ** argv) {
             }
         }
 
+        // elu, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
+
+                check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
+        // relu
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
+
+                check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // gelu, not yet fully implemented
+        if(0)
+        {
+            const int nargs = 1;
+
+            for (int ndims = 1; ndims <= 4; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+
+                struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
+
+                check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
+            }
+        }
+
         // silu
         {
             const int nargs = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -629,7 +846,7 @@ int main(int argc, const char ** argv) {
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
 
@@ -647,8 +864,8 @@ int main(int argc, const char ** argv) {
             ne2[0] = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
 
                 ggml_set_param(ctx0, x[0]);
                 ggml_set_param(ctx0, x[1]);
@@ -659,20 +876,37 @@ int main(int argc, const char ** argv) {
             }
         }
 
-        // cpy
+        // cpy f32
         {
             const int nargs = 2;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
                 for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                     ggml_set_param(ctx0, x[i]);
                 }
                 // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
 
-                check_gradient("cpy", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+                check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
+            }
+        }
+
+        // cpy f16
+        {
+            const int nargs = 2;
+
+            for (int ndims = 1; ndims <= 2; ++ndims) {
+                for (int i = 0; i < nargs; ++i) {
+                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
+                    ggml_set_param(ctx0, x[i]);
+                }
+                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
+
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
+
+                check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
             }
         }
 
@@ -689,8 +923,8 @@ int main(int argc, const char ** argv) {
                 for (int i = 0; i < ndims; ++i) {
                     ne2[0] *= ne[i];
                 }
-                x[0] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
-                x[1] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
 
@@ -712,8 +946,8 @@ int main(int argc, const char ** argv) {
                 for (int i = 0; i < ndims; ++i) {
                     ne2[0] *= ne[i];
                 }
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
 
@@ -729,7 +963,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 2;
             for (int ndims = 1; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 get_random_dims(ne2, 1);
@@ -737,7 +971,7 @@ int main(int argc, const char ** argv) {
                     get_random_dims(ne2, 1);
                 }
 
-                x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[1]);
 
                 const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
@@ -758,7 +992,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 2;
             for (int ndims = 2; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 get_random_dims(ne2, 2);
@@ -766,7 +1000,7 @@ int main(int argc, const char ** argv) {
                     get_random_dims(ne2, 2);
                 }
 
-                x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[1]);
 
                 max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@@ -790,7 +1024,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 2;
             for (int ndims = 3; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 get_random_dims(ne2, 3);
@@ -798,7 +1032,7 @@ int main(int argc, const char ** argv) {
                     get_random_dims(ne2, 3);
                 }
 
-                x[1] = get_random_tensor(ctx0, 3, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[1]);
 
                 max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@@ -824,7 +1058,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 2;
             for (int ndims = 4; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 get_random_dims(ne2, 4);
@@ -832,7 +1066,7 @@ int main(int argc, const char ** argv) {
                     get_random_dims(ne2, 4);
                 }
 
-                x[1] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[1]);
 
                 max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@@ -858,7 +1092,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 2;
             for (int ndims = 1; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 get_random_dims(ne2, 1);
@@ -866,7 +1100,7 @@ int main(int argc, const char ** argv) {
                     get_random_dims(ne2, 1);
                 }
 
-                x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[1]);
 
                 const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
@@ -887,7 +1121,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 1;
             for (int ndims = 2; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 get_random_dims(ne2, 2);
@@ -895,7 +1129,7 @@ int main(int argc, const char ** argv) {
                     get_random_dims(ne2, 2);
                 }
 
-                x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[1]);
 
                 max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@@ -915,7 +1149,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 1;
             for (int ndims = 1; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
 
                 ggml_set_param(ctx0, x[0]);
 
@@ -941,7 +1175,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 1;
             for (int ndims = 1; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
 
                 get_random_dims(ne2, 2);
                 while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
@@ -971,7 +1205,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 1;
             for (int ndims = 1; ndims <= 4; ++ndims) {
 
-                x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
 
                 get_random_dims(ne2, 3);
                 while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
@@ -1010,7 +1244,7 @@ int main(int argc, const char ** argv) {
                 for (int i=ndims; i<4; ++i) {
                     ne2[i] = 1;
                 }
-                x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
 
                 ggml_set_param(ctx0, x[0]);
 
@@ -1043,7 +1277,7 @@ int main(int argc, const char ** argv) {
                 for (int i=ndims; i<4; ++i) {
                     ne2[i] = 1;
                 }
-                x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
 
                 ggml_set_param(ctx0, x[0]);
 
@@ -1060,8 +1294,8 @@ int main(int argc, const char ** argv) {
             int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
             const int nargs = 1;
             const int ndims = 2;
-            x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
-            x[1] = get_random_tensor_int(ctx0, 1, ne3, 0, ne2[1]);
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+            x[1] = get_random_tensor_i32(ctx0, 1, ne3, 0, ne2[1]);
 
             ggml_set_param(ctx0, x[0]);
 
@@ -1075,7 +1309,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 1;
             const int ndims = 2;
 
-            x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
             ggml_set_param(ctx0, x[0]);
 
             int n_past = irand(ne[0]);
@@ -1090,7 +1324,7 @@ int main(int argc, const char ** argv) {
             const int nargs = 1;
             const int ndims = 2;
 
-            x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
+            x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
             ggml_set_param(ctx0, x[0]);
 
             int n_past = irand(ne[0]);
@@ -1108,7 +1342,7 @@ int main(int argc, const char ** argv) {
             get_random_dims(ne2, 4);
 
             for (int ndims = 1; ndims <= 3; ++ndims) {
-                x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
@@ -1125,8 +1359,8 @@ int main(int argc, const char ** argv) {
             get_random_dims(ne2, 4);
 
             for (int ndims = 1; ndims <= 3; ++ndims) {
-                x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
-                x[1] = get_random_tensor(ctx0, ndims, ne2, 0.0f, 1.0f);
+                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
+                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
                 ggml_set_param(ctx0, x[0]);
 
                 struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
@@ -1136,7 +1370,7 @@ int main(int argc, const char ** argv) {
             }
         }
 
-        // rope
+        // rope f32
         {
             const int nargs = 1;
 
@@ -1148,7 +1382,7 @@ int main(int argc, const char ** argv) {
             for (int ndims = 3; ndims <= 4; ++ndims) {
                 for (int mode = 0; mode < 4; ++mode) {
                     for (int n_past = 1; n_past < ne2[2]; ++n_past) {
-                        x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
+                        x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
 
                         ggml_set_param(ctx0, x[0]);
 
@@ -1163,14 +1397,89 @@ int main(int argc, const char ** argv) {
 
                         struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
 
-                        GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
+                        GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
+                        check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
+                    }
+                }
+            }
+        }
+
+        // rope f16
+        {
+            const int nargs = 1;
+
+            int64_t ne2[4];
+            get_random_dims(ne2, 4);
+            ne2[0] += ne2[0] % 2;
+            int n_rot = ne2[0];
+
+            for (int ndims = 3; ndims <= 4; ++ndims) {
+                for (int mode = 0; mode < 4; ++mode) {
+                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
+                        x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
+
+                        ggml_set_param(ctx0, x[0]);
+
+                        const bool skip_past = (mode & 1);
+                        if (skip_past) {
+                            // we have no past, so this would have to work on uninitialized memory.
+                            // we only test the gradients here;
+                            // skip_past should have no influence on gradient computation.
+                            // so when other modes work, we assume that this does as well.
+                            continue;
+                        }
+
+                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
+
+                        GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
+                        check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
+                    }
+                }
+            }
+        }
+
+        // flash_attn f32
+        {
+            const int nargs = 3;
+
+            int64_t ne2[4];
+
+            get_random_dims(ne2, 4);
+            int64_t D = ne2[0];
+            int64_t N = ne2[1];
+            int64_t M = ne2[2] + N;
+            int64_t B = ne2[3];
+
+            for (int masked = 0; masked <= 1; ++masked) {
+                for (int ndims = 2; ndims <= 4; ++ndims) {
+                    int64_t neq[4] = { D, N, B, ne[3] };
+                    int64_t nek[4] = { D, M, B, ne[3] };
+                    int64_t nev[4] = { M, D, B, ne[3] };
+                    if (ndims == 2) {
+                        neq[2] = 1; neq[3] = 1;
+                        nek[2] = 1; nek[3] = 1;
+                        nev[2] = 1; nev[3] = 1;
+                    } else if (ndims == 3) {
+                        neq[3] = 1;
+                        nek[3] = 1;
+                        nev[3] = 1;
                     }
+                    x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
+                    x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
+                    x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
+                    ggml_set_param(ctx0, x[0]);
+                    ggml_set_param(ctx0, x[1]);
+                    ggml_set_param(ctx0, x[2]);
+
+                    struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
+
+                    check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
                 }
             }
         }
 
-        // flash_attn
+        // flash_attn f16, not yet fully implemented
+        if(0)
         {
             const int nargs = 3;
 
@@ -1196,16 +1505,16 @@ int main(int argc, const char ** argv) {
                         nek[3] = 1;
                         nev[3] = 1;
                     }
-                    x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f);
-                    x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f);
-                    x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f);
+                    x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
+                    x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
+                    x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
                     ggml_set_param(ctx0, x[0]);
                     ggml_set_param(ctx0, x[1]);
                     ggml_set_param(ctx0, x[2]);
 
                     struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
 
-                    check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
+                    check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
                 }
             }
         }
index 5531814c48c997e1d2e0d5e563ffa9235877fbf1..4eef62bcfb96b11eec2b8cf4845d577db93525f7 100644 (file)
@@ -125,9 +125,9 @@ int main(void) {
     };
     struct ggml_context * ctx = ggml_init(params);
 
-    int64_t ne1[4] = {4, 1024, 1, 1};
-    int64_t ne2[4] = {4, 2048, 1, 1};;
-    int64_t ne3[4] = {1024, 2048, 1, 1};
+    int64_t ne1[4] = {4, 128, 1, 1};
+    int64_t ne2[4] = {4, 256, 1, 1};;
+    int64_t ne3[4] = {128, 256, 1, 1};
 
     struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
     struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);