]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml: refactor cross entropy loss CPU impl. (#976)
authorJohannes Gäßler <redacted>
Wed, 2 Oct 2024 13:32:39 +0000 (15:32 +0200)
committerGitHub <redacted>
Wed, 2 Oct 2024 13:32:39 +0000 (15:32 +0200)
include/ggml-backend.h
src/ggml.c

index 71c0bef8ee7ee69a7d59eb881d9a9c67551630a9..f94af936c6ba99fd66c5be500e9af1455e14cf0f 100644 (file)
@@ -185,7 +185,7 @@ extern "C" {
     GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
 
     // Initialize backend buffers from a measure graph
-    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
 
     GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
     GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
@@ -200,7 +200,7 @@ extern "C" {
     GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
     // Allocate and compute graph on the backend scheduler
-    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
index bcbc32d913dec87393c8db7b667802637cb11b8c..fd411a64b1868b412b61f98783b415fa80884477 100644 (file)
@@ -4195,9 +4195,13 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
 }
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+    if (ggml_is_empty(tensor)) {
+        return tensor;
+    }
     if (tensor->buffer) {
         ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
     } else {
+        GGML_ASSERT(tensor->data);
         memset(tensor->data, 0, ggml_nbytes(tensor));
     }
     return tensor;
@@ -16810,41 +16814,40 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
     GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    float * sums = (float *) params->wdata;
-
-    // TODO: handle transposed/permuted matrices
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    float * sums =  (float *) params->wdata;
+    float * st   = ((float *) params->wdata) + nth + ith*nc;
+    float sum_thread = 0.0f;
 
     GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
 
-    if (ith == 0) {
-        memset(sums, 0, sizeof(float) * (nth + nth * nc));
-    }
-    ggml_barrier(params->threadpool);
-
     // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    const int64_t dr = (nr + nth - 1)/nth;
 
     // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float * st = ((float *) params->wdata) + nth + ith*nc;
+    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
+        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
+        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
@@ -16853,23 +16856,24 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
 
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
-        assert(sum >= 0.0);
+        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum_softmax >= 0.0);
 
-        ggml_vec_add1_f32(nc, st, st, -sum);
+        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
         ggml_vec_mul_f32(nc, st, st, s1);
 
-        float st_sum = 0.0f;
-        ggml_vec_sum_f32(nc, &st_sum, st);
-        sums[ith] += st_sum;
+        float sum_st = 0.0f;
+        ggml_vec_sum_f32(nc, &sum_st, st);
+        sum_thread += sum_st;
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             assert(!isnan(st[i]));
             assert(!isinf(st[i]));
         }
 #endif
     }
+    sums[ith] = sum_thread;
     ggml_barrier(params->threadpool);
 
     if (ith == 0) {
@@ -16935,7 +16939,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
@@ -16954,7 +16958,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             assert(!isnan(ds0[i]));
             assert(!isinf(ds0[i]));
         }