ggml_tallocr_t src_allocr = node_allocr(src);
GGML_ASSERT(src_allocr != NULL); // all inputs should be assigned by now
if (src_allocr != node_allocr) {
+ // create a copy of the input in the split's backend
+ size_t id = hash_id(src);
+ if (sched->node_copies[id][cur_backend_id] == NULL) {
+ ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
+ struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+ ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
+
+ sched->node_copies[id][cur_backend_id] = tensor_copy;
+ node_allocr(tensor_copy) = cur_allocr;
+ SET_CAUSE(tensor_copy, "4.cpy");
+
+ int n_inputs = sched->splits[cur_split].n_inputs++;
+ GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
+ sched->splits[cur_split].inputs[n_inputs] = src;
+ }
+ node->src[j] = sched->node_copies[id][cur_backend_id];
+
+#if 0
// check if the input is already in the split
bool found = false;
for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) {
GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
sched->splits[cur_split].inputs[n_inputs] = src;
}
-
- // create a copy of the input in the split's backend
- size_t id = hash_id(src);
- if (sched->node_copies[id][cur_backend_id] == NULL) {
- ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
- struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
- ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
-
- sched->node_copies[id][cur_backend_id] = tensor_copy;
- node_allocr(tensor_copy) = cur_allocr;
- SET_CAUSE(tensor_copy, "4.cpy");
- }
- node->src[j] = sched->node_copies[id][cur_backend_id];
+#endif
}
}
}
uint64_t compute_start_us = ggml_time_us();
if (!sched->callback_eval) {
ggml_backend_graph_compute(split_backend, &split->graph);
- //ggml_backend_synchronize(split_backend); // necessary to measure compute time
+ //ggml_backend_synchronize(split_backend); // necessary to measure compute time
} else {
// similar to ggml_backend_compare_graph_backend
for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {