]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-opt: fix data corruption (ggml/1022)
authorJohannes Gäßler <redacted>
Wed, 20 Nov 2024 13:56:04 +0000 (14:56 +0100)
committerGeorgi Gerganov <redacted>
Thu, 21 Nov 2024 07:22:02 +0000 (09:22 +0200)
ggml/src/ggml-backend.cpp
ggml/src/ggml-impl.h
ggml/src/ggml-opt.cpp
ggml/src/ggml.c
tests/test-backend-ops.cpp

index 9dcde8d11952ae4fd6c0d395081f75760c843912..3433d082e775b0a8a391a0b83898c23f7a87bdc6 100644 (file)
@@ -252,6 +252,7 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
 }
 
 void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     if (size == 0) {
@@ -266,6 +267,7 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz
 }
 
 void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
     if (size == 0) {
index 92a64fe5a59303b2fae76f70b4503e8b842bf446..3965be78755a1a3d24aa59c36df67c230f5cd43e 100644 (file)
@@ -295,6 +295,9 @@ struct ggml_cgraph {
     enum ggml_cgraph_eval_order order;
 };
 
+// returns a slice of cgraph with nodes [i0, i1)
+// the slice does not have leafs or gradients
+// if you need the gradients, get them from the original graph
 struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);
 
 // Memory allocation
index 040205a31bab0d628b8e9c940a851f48a5c14dba..7c3e24103a250ae47fd2cfee653eed1f85d338be 100644 (file)
 #include <vector>
 
 struct ggml_opt_dataset {
-    struct ggml_context   * ctx;
-    ggml_backend_buffer_t   buf;
-    struct ggml_tensor    * data;
-    struct ggml_tensor    * labels;
+    struct ggml_context   * ctx    = nullptr;
+    ggml_backend_buffer_t   buf    = nullptr;
+    struct ggml_tensor    * data   = nullptr;
+    struct ggml_tensor    * labels = nullptr;
 
-    int64_t ndata;
-    int64_t ndata_shard;
-    size_t  nbs_data;
-    size_t  nbs_labels;
+    int64_t ndata       = -1;
+    int64_t ndata_shard = -1;
+    size_t  nbs_data    = -1;
+    size_t  nbs_labels  = -1;
 
     std::vector<int64_t> permutation;
 };
 
 struct ggml_opt_context {
-    ggml_backend_sched_t    backend_sched;
-    ggml_cgraph           * allocated_graph;
-    ggml_cgraph           * allocated_graph_copy;
-    struct ggml_context   * ctx_static;
-    struct ggml_context   * ctx_static_cpu;
-    struct ggml_context   * ctx_compute;
-    struct ggml_context   * ctx_copy;
-    ggml_backend_buffer_t   buf_static;
-    ggml_backend_buffer_t   buf_static_cpu;
+    ggml_backend_sched_t    backend_sched        = nullptr;
+    ggml_cgraph           * allocated_graph      = nullptr;
+    ggml_cgraph           * allocated_graph_copy = nullptr;
+    struct ggml_context   * ctx_static           = nullptr;
+    struct ggml_context   * ctx_static_cpu       = nullptr;
+    struct ggml_context   * ctx_compute          = nullptr;
+    struct ggml_context   * ctx_copy             = nullptr;
+    ggml_backend_buffer_t   buf_static           = nullptr;
+    ggml_backend_buffer_t   buf_static_cpu       = nullptr;
     std::mt19937            rng;
 
-    struct ggml_tensor * inputs;
-    struct ggml_tensor * outputs;
-    struct ggml_tensor * labels;
+    struct ggml_tensor * inputs  = nullptr;
+    struct ggml_tensor * outputs = nullptr;
+    struct ggml_tensor * labels  = nullptr;
 
-    struct ggml_tensor * loss;
-    struct ggml_tensor * pred;
-    struct ggml_tensor * ncorrect;
+    struct ggml_tensor * loss     = nullptr;
+    struct ggml_tensor * pred     = nullptr;
+    struct ggml_tensor * ncorrect = nullptr;
 
-    struct ggml_cgraph * gf;
-    struct ggml_cgraph * gb_grad;
-    struct ggml_cgraph * gb_opt;
+    struct ggml_cgraph * gf      = nullptr;
+    struct ggml_cgraph * gb_grad = nullptr;
+    struct ggml_cgraph * gb_opt  = nullptr;
 
-    int64_t iter;
-    int32_t opt_period;
-    int32_t opt_i;
-    bool    loss_per_datapoint;
+    int64_t iter               = 1;
+    int32_t opt_period         = 1;
+    int32_t opt_i              = 0;
+    bool    loss_per_datapoint = false;
 
-    ggml_opt_get_optimizer_params get_opt_pars;
-    void * get_opt_pars_ud;
-    struct ggml_tensor * adamw_params;
+    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
+    void * get_opt_pars_ud                     = nullptr;
+    struct ggml_tensor * adamw_params          = nullptr;
 };
 
 struct ggml_opt_result {
@@ -67,8 +67,8 @@ struct ggml_opt_result {
     std::vector<int32_t> pred;
     int64_t              ncorrect = 0;
 
-    bool loss_per_datapoint = false;
-    int64_t opt_period = -1;
+    int64_t opt_period         = -1;
+    bool    loss_per_datapoint = false;
 };
 
 // ====== Dataset ======
@@ -188,11 +188,11 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
 }
 
 struct ggml_opt_params ggml_opt_default_params(
-        ggml_backend_sched_t backend_sched,
-        struct ggml_context * ctx_compute,
-        struct ggml_tensor * inputs,
-        struct ggml_tensor * outputs,
-        enum ggml_opt_loss_type loss_type) {
+        ggml_backend_sched_t      backend_sched,
+        struct ggml_context     * ctx_compute,
+        struct ggml_tensor      * inputs,
+        struct ggml_tensor      * outputs,
+        enum ggml_opt_loss_type   loss_type) {
     return {
         /*backend_sched   =*/ backend_sched,
         /*ctx_compute     =*/ ctx_compute,
@@ -237,25 +237,33 @@ static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_
     return new_tensor;
 }
 
-static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) {
+static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
     std::map<ggml_tensor *, ggml_tensor *> tensor_map;
 
-    ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
+    ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true);
 
-    for (int i = 0; i < graph->n_leafs; i++) {
-        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i]));
+    for (int i = 0; i < src->n_leafs; i++) {
+        ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i]));
     }
-    for (int i = 0; i < graph->n_nodes; i++) {
-        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i]));
+    GGML_ASSERT(dst->n_leafs == src->n_leafs);
+    for (int i = 0; i < src->n_nodes; i++) {
+        ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i]));
     }
-    for (int i = 0; i < graph->n_nodes; ++i) {
-        const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]);
-        const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]);
-        graph->grads[igrad_dst]     = new_graph->grads[igrad_src];
-        graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src];
+    GGML_ASSERT(dst->n_nodes == src->n_nodes);
+    for (int i = 0; i < src->n_nodes; ++i) {
+        const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
+        const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+
+        GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
+        GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
+
+        dst->grads[igrad_dst]     = src->grads[igrad_src];
+        dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
     }
 
-    return new_graph;
+    return dst;
 }
 
 static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
@@ -284,18 +292,13 @@ static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph
 
 ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     ggml_opt_context_t result = new struct ggml_opt_context;
-    result->backend_sched        = params.backend_sched;
-    result->allocated_graph      = nullptr;
-    result->allocated_graph_copy = nullptr;
-    result->ctx_compute          = params.ctx_compute;
-    result->ctx_copy             = nullptr;
-    result->inputs               = params.inputs;
-    result->outputs              = params.outputs;
-    result->iter                 = 1;
-    result->opt_period           = params.opt_period;
-    result->opt_i                = 0;
-    result->get_opt_pars         = params.get_opt_pars;
-    result->get_opt_pars_ud      = params.get_opt_pars_ud;
+    result->backend_sched   = params.backend_sched;
+    result->ctx_compute     = params.ctx_compute;
+    result->inputs          = params.inputs;
+    result->outputs         = params.outputs;
+    result->opt_period      = params.opt_period;
+    result->get_opt_pars    = params.get_opt_pars;
+    result->get_opt_pars_ud = params.get_opt_pars_ud;
 
     GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
     GGML_ASSERT(result->opt_period >= 1);
@@ -348,7 +351,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
 
     switch (params.loss_type) {
         case GGML_OPT_LOSS_TYPE_MEAN: {
-            result->labels = nullptr;
             result->loss = ggml_sum(result->ctx_static, result->outputs);
             ggml_set_name(result->loss, "loss_sum");
             const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
@@ -358,7 +360,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
             break;
         }
         case GGML_OPT_LOSS_TYPE_SUM: {
-            result->labels = nullptr;
             result->loss = ggml_sum(result->ctx_static, result->outputs);
             ggml_set_name(result->loss, "loss_sum");
             result->loss_per_datapoint = false;
@@ -413,14 +414,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     }
 
     if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
-        result->gb_grad = nullptr;
-        result->gb_opt  = nullptr;
-
         result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
-        result->buf_static_cpu = nullptr;
-
-        ggml_opt_alloc_graph(result, result->gf);
-
         return result;
     }
 
@@ -429,14 +423,8 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
 
     if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
-        result->gb_opt  = nullptr;
-
         result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
-        result->buf_static_cpu = nullptr;
-
-        ggml_opt_alloc_graph(result, result->gb_grad);
         ggml_graph_reset(result->gb_grad);
-
         return result;
     }
 
@@ -466,7 +454,6 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
 
     result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
 
-    ggml_opt_alloc_graph(result, result->gb_opt);
     ggml_graph_reset(result->gb_opt);
 
     return result;
index ee72a173e07fd1e38ca3c95a0f0e3db326d1ba5a..719d75c70c75d29e9d3f00e17df8ed9e423e5337 100644 (file)
@@ -5019,8 +5019,10 @@ static void ggml_hash_map_free(struct hash_map * map) {
 }
 
 // utility functions to change gradients
-// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
-// else if a is in zero_table, replace a
+// isrc is the index of tensor in cgraph->visited_has_set.keys
+// the corresponding gradient (accumulators) are also at position isrc
+// if tensor has a gradient accumulator, modify that accumulator in-place
+// else if there is no gradient for tensor, set the corresponding value
 // else, just add/subtract/etc. the gradients
 
 static void ggml_add_or_set(
@@ -5028,11 +5030,14 @@ static void ggml_add_or_set(
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
-        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
     } else {
         cgraph->grads[isrc] = tensor;
     }
+    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
@@ -5040,18 +5045,20 @@ static void ggml_acc_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
-        struct ggml_tensor  * src,
         struct ggml_tensor  * tensor,
         const  size_t         nb1,
         const  size_t         nb2,
         const  size_t         nb3,
         const  size_t         offset) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
     } else {
         struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
         cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
     }
+    ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
@@ -5059,13 +5066,15 @@ static void ggml_add1_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
-        struct ggml_tensor  * src,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
     } else {
         cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
     }
+    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
@@ -5074,11 +5083,14 @@ static void ggml_sub_or_set(
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
     } else {
         cgraph->grads[isrc] = ggml_neg(ctx, tensor);
     }
+    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
@@ -5095,12 +5107,12 @@ static void ggml_compute_backward(
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
     struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
-    const size_t isrc0 = ggml_hash_find(hash_set, src0);
-    const size_t isrc1 = ggml_hash_find(hash_set, src1);
-    const size_t isrc2 = ggml_hash_find(hash_set, src2);
-    const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
-    const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
-    const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
+    const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;
+    const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;
+    const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;
+    const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
+    const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
+    const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
 
     switch (tensor->op) {
         case GGML_OP_DUP: {
@@ -5200,7 +5212,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_SUM: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
+                ggml_add1_or_set(ctx, cgraph, isrc0, grad);
             }
         } break;
         case GGML_OP_SUM_ROWS: {
@@ -5210,7 +5222,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
             }
         } break;
         case GGML_OP_REPEAT: {
@@ -5363,7 +5375,7 @@ static void ggml_compute_backward(
                     nb3 = (nb3 / n0) * ng;
                 }
 
-                ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
+                ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
             }
         } break;
         case GGML_OP_PERMUTE: {
@@ -5597,10 +5609,9 @@ void ggml_build_backward_expand(
 
     const int n_nodes_f = cgraph->n_nodes;
 
-    const size_t hash_size = ggml_hash_size(2*cgraph->size);
-    memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));
-    memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
-    bool * grads_needed = calloc(hash_size, sizeof(bool));
+    memset(cgraph->grads,     0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));
 
     {
         bool any_params = false;
@@ -5621,7 +5632,7 @@ void ggml_build_backward_expand(
             continue;
         }
 
-        bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
+        bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);
         bool ignore_src[GGML_MAX_SRC] = {false};
         switch (node->op) {
             // gradients in node->src[0] for one reason or another have no effect on output gradients
@@ -5638,7 +5649,7 @@ void ggml_build_backward_expand(
             } break;
 
             // gradients in node->src[1] for one reason or another have no effect on output gradients
-            case GGML_OP_CPY:           // gradients in CPY target  are irrelevant
+            case GGML_OP_CPY:           // gradients in CPY target are irrelevant
             case GGML_OP_GET_ROWS:      // row indices not differentiable
             case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
             case GGML_OP_ROPE:          // positions not differentiable
@@ -5665,9 +5676,12 @@ void ggml_build_backward_expand(
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
         const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(igrad != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
         if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
-            cgraph->grads[igrad]     = ggml_dup_tensor(ctx_static, node);
-            cgraph->grad_accs[igrad] = cgraph->grads[igrad];
+            cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
+            cgraph->grads[igrad]     = cgraph->grad_accs[igrad];
+            ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
         }
         grads_needed[igrad] = true;
     }
@@ -5761,15 +5775,15 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
 
 struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
     struct ggml_cgraph cgraph = {
-        /*.size         =*/ 0,
-        /*.n_nodes      =*/ i1 - i0,
-        /*.n_leafs      =*/ 0,
-        /*.nodes        =*/ cgraph0->nodes + i0,
-        /*.grads        =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
-        /*.grad_accs    =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL,
-        /*.leafs        =*/ NULL,
-        /*.hash_table   =*/ { 0, NULL, NULL },
-        /*.order        =*/ cgraph0->order,
+        /*.size             =*/ 0,
+        /*.n_nodes          =*/ i1 - i0,
+        /*.n_leafs          =*/ 0,
+        /*.nodes            =*/ cgraph0->nodes + i0,
+        /*.grads            =*/ NULL, // gradients would need visited_hash_set
+        /*.grad_accs        =*/ NULL,
+        /*.leafs            =*/ NULL,
+        /*.visited_hash_set =*/ { 0, NULL, NULL },
+        /*.order            =*/ cgraph0->order,
     };
 
     return cgraph;
@@ -5799,12 +5813,22 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
         }
     }
 
+    if (dst->grads) {
+        memset(dst->grads,     0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
+        memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    }
     if (src->grads) {
         GGML_ASSERT(dst->grads     != NULL);
         GGML_ASSERT(dst->grad_accs != NULL);
         for (int i = 0; i < src->n_nodes; ++i) {
             const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
             const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+
+            GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
+            GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
+            GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
+            GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
+
             dst->grads[igrad_dst]     = src->grads[igrad_src];
             dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
         }
@@ -5839,12 +5863,8 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
 
         if (node->op == GGML_OP_OPT_STEP_ADAMW) {
             // clear momenta
-            if (node->src[2]->data) {
-                ggml_set_zero(node->src[2]);
-            }
-            if (node->src[3]->data) {
-                ggml_set_zero(node->src[3]);
-            }
+            ggml_set_zero(node->src[2]);
+            ggml_set_zero(node->src[3]);
         }
 
         // initial gradients of loss should be 1, 0 otherwise
index 01ac7166e0e224ae6318ff7ef5ec4630e4b7ac29..37342c156d7c13fe93fcb84f4c92b0c6d4801693 100644 (file)
@@ -819,7 +819,6 @@ struct test_case {
             }
         }
 
-        // TODO: refactor so that this check is only needed once
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (!ggml_backend_supports_op(backend, t)) {
                 printf("not supported [%s] ", ggml_backend_name(backend));