]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
tests : remove unnecessary funcs
authorGeorgi Gerganov <redacted>
Mon, 24 Jul 2023 11:24:53 +0000 (14:24 +0300)
committerGeorgi Gerganov <redacted>
Mon, 24 Jul 2023 11:24:53 +0000 (14:24 +0300)
tests/test-grad0.c

index 7e03b5426d57ca1a6ca441601e1b41610c7dfd15..ef20bce516662e645395475e3dc6fdf01756b4ab 100644 (file)
@@ -208,42 +208,6 @@ struct ggml_tensor * get_random_tensor_i32(
     return result;
 }
 
-float get_element(const struct ggml_tensor * t, int idx) {
-    switch (t->type) {
-        case GGML_TYPE_F32:
-            return ((float *)t->data)[idx];
-        case GGML_TYPE_I32:
-            return ((int32_t *)t->data)[idx];
-        case GGML_TYPE_F16:
-            return ggml_fp16_to_fp32(((ggml_fp16_t *)t->data)[idx]);
-        case GGML_TYPE_I16:
-            return ((int16_t *)t->data)[idx];
-        default:
-            assert(false);
-    }
-    return INFINITY;
-}
-
-void set_element(struct ggml_tensor * t, int idx, float value) {
-    switch (t->type) {
-        case GGML_TYPE_F32:
-            ((float *)t->data)[idx] = value;
-            break;
-        case GGML_TYPE_I32:
-            ((int32_t *)t->data)[idx] = value;
-            break;
-        case GGML_TYPE_F16:
-            ((ggml_fp16_t*)t->data)[idx] = ggml_fp32_to_fp16(value);
-            break;
-        case GGML_TYPE_I16:
-            ((int16_t *)t->data)[idx] = value;
-            break;
-        default:
-            assert(false);
-    }
-    ;
-}
-
 void print_elements(const char* label, const struct ggml_tensor * t) {
     if (!t) {
         printf("%s: %s = null\n", __func__, label);
@@ -253,7 +217,7 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
     printf("%s: %s = [", __func__, label);
     for (int k = 0; k < nelements; ++k) {
         if (k > 0) { printf(", "); }
-        printf("%.5f", get_element(t, k));
+        printf("%.5f", ggml_get_f32_1d(t, k));
     }
     printf("] shape: [");
     for (int k = 0; k < t->n_dims; ++k) {
@@ -304,23 +268,23 @@ bool check_gradient(
         const int nelements = ggml_nelements(x[i]);
         for (int k = 0; k < nelements; ++k) {
             // compute gradient using finite differences
-            const float x0 = get_element(x[i], k);
+            const float x0 = ggml_get_f32_1d(x[i], k);
             const float xm = x0 - eps;
             const float xp = x0 + eps;
-            set_element(x[i], k, xp);
+            ggml_set_f32_1d(x[i], k, xp);
 
             ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f0 = ggml_get_f32_1d(f, 0);
 
-            set_element(x[i], k, xm);
+            ggml_set_f32_1d(x[i], k, xm);
 
             ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
             const float f1 = ggml_get_f32_1d(f, 0);
             const float g0 = (f0 - f1)/(2.0f*eps);
 
-            set_element(x[i], k, x0);
+            ggml_set_f32_1d(x[i], k, x0);
 
             // compute gradient using backward graph
             ggml_graph_reset  (&gf);
@@ -328,7 +292,7 @@ bool check_gradient(
 
             ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
 
-            const float g1 = get_element(x[i]->grad, k);
+            const float g1 = ggml_get_f32_1d(x[i]->grad, k);
 
             const float error_abs = fabsf(g0 - g1);
             const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;