]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
backend : offload large batches to GPU (llama/6083)
authorslaren <redacted>
Mon, 18 Mar 2024 10:03:04 +0000 (11:03 +0100)
committerGeorgi Gerganov <redacted>
Wed, 27 Mar 2024 11:20:00 +0000 (13:20 +0200)
* backend : offload large batches to GPU

* fix hip

* code cleanup

* fix CUDA split buffers

* Update ggml-backend-impl.h

Co-authored-by: Johannes Gäßler <redacted>
* cuda : fix memset without set_device

* imatrix : remove sched affix from weight names

* sched : add a new split if the current one has too many inputs
reduce max inputs per split
more cleanup

* update backends

ggml-ci

---------

Co-authored-by: Johannes Gäßler <redacted>
include/ggml/ggml-backend.h
src/ggml-alloc.c
src/ggml-backend-impl.h
src/ggml-backend.c
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml-kompute.cpp
src/ggml-metal.m
src/ggml-sycl.cpp
src/ggml-vulkan.cpp
src/ggml.c

index 099d9c258794ed0b3ad460549ca446a10c875130..422457ab6fc50a245640ec06ea635745125c0ecc 100644 (file)
@@ -70,11 +70,11 @@ extern "C" {
     GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
     GGML_API void                      ggml_backend_graph_plan_free  (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 
-    GGML_API enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-    GGML_API enum ggml_status ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
-
-    GGML_API bool ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
     GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
+    GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
 
     // tensor copy between different backends
     GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
index 60b86c27581621b508ba468415d3aa82a117810e..7ceafec309d0c5517d0491ae1598139b68eceaff 100644 (file)
@@ -548,7 +548,11 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
 
-        if (ggml_is_view(node)) {
+        // TODO: better way to add external dependencies
+        // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
+        // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
+        // itself is never used and should not be considered a dependency
+        if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
             struct ggml_tensor * view_src = node->view_src;
             ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
         }
@@ -565,8 +569,8 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
 
             ggml_gallocr_hash_get(galloc, src)->n_children += 1;
 
-            // allocate explicit inputs and leafs
-            if (src->flags & GGML_TENSOR_FLAG_INPUT || src->op == GGML_OP_NONE) {
+            // allocate explicit inputs
+            if (src->flags & GGML_TENSOR_FLAG_INPUT) {
                 ggml_gallocr_allocate_node(galloc, src, get_node_buffer_id(node_buffer_ids, i));
             }
         }
index e475e20e5f46a55d6ca5f3b23ffb2d3549ce9c2f..f121e1de420facd4b5a9a8487480f60b4cc14e4a 100644 (file)
@@ -103,6 +103,11 @@ extern "C" {
         // check if the backend supports an operation
         bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
 
+        // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
+        // these should be expensive operations with large batch sizes that may benefit from running on this backend
+        // even if the weight has to be copied from the CPU temporarily
+        bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op);
+
         // (optional) event synchronization
         ggml_backend_event_t (*GGML_CALL event_new)         (ggml_backend_t backend);
         void                 (*GGML_CALL event_free)        (ggml_backend_event_t event);
index 31f8d5a6dd30befb6a4bea8676c457e8259c982f..9f0084df78942fffc5117dd195205fa56fc33144 100644 (file)
@@ -278,7 +278,7 @@ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_
     return err;
 }
 
-bool ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     return backend->iface.graph_compute(backend, cgraph);
 }
 
@@ -286,6 +286,13 @@ bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor *
     return backend->iface.supports_op(backend, op);
 }
 
+bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    if (backend->iface.offload_op != NULL) {
+        return backend->iface.offload_op(backend, op);
+    }
+    return false;
+}
+
 // backend copy
 
 static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
@@ -761,6 +768,10 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg
 
     if (cpu_plan->cplan.work_size > 0) {
         cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
+        if (cpu_plan->cplan.work_data == NULL) {
+            free(cpu_plan);
+            return NULL;
+        }
     }
 
     cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
@@ -834,6 +845,7 @@ static struct ggml_backend_i cpu_backend_i = {
     /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
     /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
     /* .supports_op             = */ ggml_backend_cpu_supports_op,
+    /* .offload_op              = */ NULL,
     /* .event_new               = */ NULL,
     /* .event_free              = */ NULL,
     /* .event_record            = */ NULL,
@@ -999,11 +1011,11 @@ static bool ggml_is_view_op(enum ggml_op op) {
 #endif
 
 #ifndef GGML_SCHED_MAX_SPLITS
-#define GGML_SCHED_MAX_SPLITS 256
+#define GGML_SCHED_MAX_SPLITS 2048
 #endif
 
 #ifndef GGML_SCHED_MAX_SPLIT_INPUTS
-#define GGML_SCHED_MAX_SPLIT_INPUTS 16
+#define GGML_SCHED_MAX_SPLIT_INPUTS 4
 #endif
 
 #ifndef GGML_SCHED_MAX_COPIES
@@ -1043,8 +1055,9 @@ struct ggml_backend_sched {
     struct ggml_cgraph * graph;
 
     // graph splits
-    struct ggml_backend_sched_split splits[GGML_SCHED_MAX_SPLITS];
+    struct ggml_backend_sched_split * splits;
     int n_splits;
+    int splits_capacity;
 
     // pipeline parallelism support
     int n_copies;
@@ -1114,40 +1127,48 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
     // TODO: use supports_op to check if the backend supports the op
 
     // assign pre-allocated nodes to their backend
-    // dst
-    int cur_backend = ggml_backend_sched_backend_from_buffer(sched, tensor);
-    if (cur_backend != -1) {
+    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor);
+    if (cur_backend_id != -1) {
         SET_CAUSE(tensor, "1.dst");
-        return cur_backend;
+        return cur_backend_id;
     }
 
     // view_src
     if (tensor->view_src != NULL) {
-        cur_backend = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src);
-        if (cur_backend != -1) {
+        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src);
+        if (cur_backend_id != -1) {
             SET_CAUSE(tensor, "1.vsrc");
-            return cur_backend;
+            return cur_backend_id;
         }
     }
 
-    // input
+    // graph input
     if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
-        cur_backend = sched->n_backends - 1; // last backend (assumed CPU)
+        cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU)
         SET_CAUSE(tensor, "1.inp");
-        return cur_backend;
+        return cur_backend_id;
     }
 
     // assign nodes that use weights to the backend of the weights
+    // operations with weights are preferably run on the same backend as the weights
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         const struct ggml_tensor * src = tensor->src[i];
         if (src == NULL) {
             continue;
         }
         if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
-            int src_backend = ggml_backend_sched_backend_from_buffer(sched, src);
-            // operations with weights are always run on the same backend as the weights
+            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src);
+            // check if a backend with higher prio wants to offload the op
+            if (src_backend_id == sched->n_backends - 1) {
+                for (int b = 0; b < src_backend_id; b++) {
+                    if (ggml_backend_offload_op(sched->backends[b], tensor)) {
+                        SET_CAUSE(tensor, "1.off");
+                        return b;
+                    }
+                }
+            }
             SET_CAUSE(tensor, "1.wgt%d", i);
-            return src_backend;
+            return src_backend_id;
         }
     }
 
@@ -1227,28 +1248,31 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
     // pass 1: assign backends to ops with pre-allocated inputs
     for (int i = 0; i < graph->n_leafs; i++) {
         struct ggml_tensor * leaf = graph->leafs[i];
-        if (tensor_backend_id(leaf) != -1) {
+        int * leaf_backend_id = &tensor_backend_id(leaf);
+        if (*leaf_backend_id != -1) {
             // do not overwrite user assignments
             continue;
         }
-        tensor_backend_id(leaf) = ggml_backend_sched_backend_id_from_cur(sched, leaf);
+        *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
     }
 
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
-        if (tensor_backend_id(node) != -1) {
+        int * node_backend_id = &tensor_backend_id(node);
+        if (*node_backend_id != -1) {
             // do not overwrite user assignments
             continue;
         }
-        tensor_backend_id(node) = ggml_backend_sched_backend_id_from_cur(sched, node);
+        *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
         // src
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
                 continue;
             }
-            if (tensor_backend_id(src) == -1) {
-                tensor_backend_id(src) = ggml_backend_sched_backend_id_from_cur(sched, src);
+            int * src_backend_id = &tensor_backend_id(src);
+            if (*src_backend_id == -1) {
+                *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src);
             }
         }
     }
@@ -1270,21 +1294,20 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             if (ggml_is_view_op(node->op)) {
                 continue;
             }
-            int tensor_backend_id = tensor_backend_id(node);
-            if (tensor_backend_id != -1) {
-                if (tensor_backend_id == sched->n_backends - 1) {
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                if (*node_backend_id == sched->n_backends - 1) {
                     // skip cpu (lowest prio backend)
                     cur_backend_id = -1;
                 } else {
-                    cur_backend_id = tensor_backend_id;
+                    cur_backend_id = *node_backend_id;
                 }
             } else {
-                tensor_backend_id(node) = cur_backend_id;
+                *node_backend_id = cur_backend_id;
                 SET_CAUSE(node, "2.2");
             }
         }
     }
-
     // pass 2.1 expand gpu up
     {
         int cur_backend_id = -1;
@@ -1293,22 +1316,20 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             if (ggml_is_view_op(node->op)) {
                 continue;
             }
-            int tensor_backend_id = tensor_backend_id(node);
-            if (tensor_backend_id != -1) {
-                if (tensor_backend_id == sched->n_backends - 1) {
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                if (*node_backend_id == sched->n_backends - 1) {
                     // skip cpu (lowest prio backend)
                     cur_backend_id = -1;
                 } else {
-                    cur_backend_id = tensor_backend_id;
+                    cur_backend_id = *node_backend_id;
                 }
             } else {
-                tensor_backend_id(node) = cur_backend_id;
+                *node_backend_id = cur_backend_id;
                 SET_CAUSE(node, "2.1");
             }
         }
     }
-
-
     // pass 2.4 expand rest down
     {
         int cur_backend_id = -1;
@@ -1317,16 +1338,16 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             if (ggml_is_view_op(node->op)) {
                 continue;
             }
-            int tensor_backend_id = tensor_backend_id(node);
-            if (tensor_backend_id != -1) {
-                cur_backend_id = tensor_backend_id;
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                cur_backend_id = *node_backend_id;
             } else {
-                tensor_backend_id(node) = cur_backend_id;
+                *node_backend_id = cur_backend_id;
                 SET_CAUSE(node, "2.4");
             }
         }
     }
-        // pass 2.3 expand rest up
+    // pass 2.3 expand rest up
     {
         int cur_backend_id = -1;
         for (int i = graph->n_nodes - 1; i >= 0; i--) {
@@ -1334,11 +1355,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             if (ggml_is_view_op(node->op)) {
                 continue;
             }
-            int tensor_backend_id = tensor_backend_id(node);
-            if (tensor_backend_id != -1) {
-                cur_backend_id = tensor_backend_id;
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                cur_backend_id = *node_backend_id;
             } else {
-                tensor_backend_id(node) = cur_backend_id;
+                *node_backend_id = cur_backend_id;
                 SET_CAUSE(node, "2.3");
             }
         }
@@ -1351,9 +1372,9 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
     // pass 3: assign backends to remaining src from dst and view_src
     for (int i = 0; i < graph->n_nodes; i++) {
         struct ggml_tensor * node = graph->nodes[i];
-        int cur_backend_id = tensor_backend_id(node);
-        if (node->view_src != NULL && cur_backend_id == -1) {
-            cur_backend_id = tensor_backend_id(node) = tensor_backend_id(node->view_src);
+        int * cur_backend_id = &tensor_backend_id(node);
+        if (node->view_src != NULL && *cur_backend_id == -1) {
+            *cur_backend_id = tensor_backend_id(node->view_src);
             SET_CAUSE(node, "3.vsrc");
         }
         for (int j = 0; j < GGML_MAX_SRC; j++) {
@@ -1361,14 +1382,14 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             if (src == NULL) {
                 continue;
             }
-            int src_backend_id = tensor_backend_id(src);
-            if (src_backend_id == -1) {
+            int * src_backend_id = &tensor_backend_id(src);
+            if (*src_backend_id == -1) {
                 if (src->view_src != NULL) {
                     // views are always on the same backend as the source
-                    tensor_backend_id(src) = tensor_backend_id(src->view_src);
+                    *src_backend_id = tensor_backend_id(src->view_src);
                     SET_CAUSE(src, "3.vsrc");
                 } else {
-                    tensor_backend_id(src) = cur_backend_id;
+                    *src_backend_id = *cur_backend_id;
                     SET_CAUSE(src, "3.cur");
                 }
             }
@@ -1380,19 +1401,20 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
     // pass 4: split graph, find tensors that need to be copied
     {
-        int cur_split = 0;
+        int i_split = 0;
+        struct ggml_backend_sched_split * split = &sched->splits[0];
         // find the backend of the first split, skipping view ops
         for (int i = 0; i < graph->n_nodes; i++) {
             struct ggml_tensor * node = graph->nodes[i];
             if (!ggml_is_view_op(node->op)) {
-                sched->splits[0].backend_id = tensor_backend_id(node);
+                split->backend_id = tensor_backend_id(node);
                 break;
             }
         }
-        sched->splits[0].i_start = 0;
-        sched->splits[0].n_inputs = 0;
-        memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
-        int cur_backend_id = sched->splits[0].backend_id;
+        split->i_start = 0;
+        split->n_inputs = 0;
+        memset(split->inputs, 0, sizeof(split->inputs)); //HACK
+        int cur_backend_id = split->backend_id;
         for (int i = 0; i < graph->n_nodes; i++) {
             struct ggml_tensor * node = graph->nodes[i];
 
@@ -1400,18 +1422,54 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 continue;
             }
 
-            int tensor_backend_id = tensor_backend_id(node);
+            const int node_backend_id = tensor_backend_id(node);
 
-            GGML_ASSERT(tensor_backend_id != -1); // all nodes should be assigned by now
+            GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now
 
-            if (tensor_backend_id != cur_backend_id) {
-                sched->splits[cur_split].i_end = i;
-                cur_split++;
-                GGML_ASSERT(cur_split < GGML_SCHED_MAX_SPLITS);
-                sched->splits[cur_split].backend_id = tensor_backend_id;
-                sched->splits[cur_split].i_start = i;
-                sched->splits[cur_split].n_inputs = 0;
-                cur_backend_id = tensor_backend_id;
+            // check if we should start a new split based on the sources of the current node
+            bool need_new_split = false;
+            if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    struct ggml_tensor * src = node->src[j];
+                    if (src == NULL) {
+                        continue;
+                    }
+                    // check if a weight is on a different backend
+                    // by starting a new split, the memory of the previously offloaded weights can be reused
+                    if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+                        int src_backend_id = tensor_backend_id(src);
+                        if (src_backend_id != -1 && src_backend_id != cur_backend_id) {
+                            need_new_split = true;
+                            break;
+                        }
+                    }
+                    // check if the split has too many inputs
+                    if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
+                        const size_t id = hash_id(src);
+                        int src_backend_id = sched->tensor_backend_id[id];
+                        if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) {
+                            //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
+                            need_new_split = true;
+                            break;
+                        }
+                    }
+                }
+            }
+
+            if (node_backend_id != cur_backend_id || need_new_split) {
+                split->i_end = i;
+                i_split++;
+                if (i_split >= sched->splits_capacity) {
+                    sched->splits_capacity *= 2;
+                    sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
+                    GGML_ASSERT(sched->splits != NULL);
+                }
+                GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
+                split = &sched->splits[i_split];
+                split->backend_id = node_backend_id;
+                split->i_start = i;
+                split->n_inputs = 0;
+                cur_backend_id = node_backend_id;
             }
 
             // find inputs that are not on the same backend
@@ -1421,10 +1479,10 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                     continue;
                 }
 
-                int src_backend_id = tensor_backend_id(src);
+                const int src_backend_id = tensor_backend_id(src);
                 assert(src_backend_id != -1); // all inputs should be assigned by now
 
-                if (src->flags & GGML_TENSOR_FLAG_INPUT)  {
+                if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1)  {
                     size_t id = hash_id(src);
                     if (sched->tensor_copies[id][src_backend_id][0] == NULL) {
                         ggml_backend_t backend = sched->backends[src_backend_id];
@@ -1441,7 +1499,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                                 ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
                             }
                             sched->tensor_copies[id][src_backend_id][c] = tensor_copy;
-                            tensor_backend_id(tensor_copy) = src_backend_id;
                             SET_CAUSE(tensor_copy, "4.cpy");
                         }
                         int n_graph_inputs = sched->n_graph_inputs++;
@@ -1450,9 +1507,9 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                     }
                 }
 
-                if (src_backend_id != tensor_backend_id) {
+                if (src_backend_id != node_backend_id) {
                     // create a copy of the input in the split's backend
-                    size_t id = hash_id(src);
+                    const size_t id = hash_id(src);
                     if (sched->tensor_copies[id][cur_backend_id][0] == NULL) {
                         ggml_backend_t backend = sched->backends[cur_backend_id];
                         for (int c = 0; c < sched->n_copies; c++) {
@@ -1463,76 +1520,42 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                                 ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
                             }
                             sched->tensor_copies[id][cur_backend_id][c] = tensor_copy;
-                            tensor_backend_id(tensor_copy) = cur_backend_id;
                             SET_CAUSE(tensor_copy, "4.cpy");
                         }
-                        int n_inputs = sched->splits[cur_split].n_inputs++;
+                        int n_inputs = split->n_inputs++;
                         GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
-                        sched->splits[cur_split].inputs[n_inputs] = src;
+                        split->inputs[n_inputs] = src;
                     }
                     node->src[j] = sched->tensor_copies[id][cur_backend_id][sched->cur_copy];
                 }
             }
         }
-        sched->splits[cur_split].i_end = graph->n_nodes;
-        sched->n_splits = cur_split + 1;
+        split->i_end = graph->n_nodes;
+        sched->n_splits = i_split + 1;
     }
 #ifdef DEBUG_PASS4
     fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph);
 #endif
 
-#ifndef NDEBUG
-    // sanity check: all sources should have the same backend as the node
-    for (int i = 0; i < graph->n_nodes; i++) {
-        struct ggml_tensor * node = graph->nodes[i];
-        ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-        if (tensor_backend == NULL) {
-            fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
-        }
-        if (node->view_src != NULL && tensor_backend != ggml_backend_sched_get_tensor_backend(sched, node->view_src)) {
-            fprintf(stderr, "!!!!!!! %s has backend %s, view_src %s has backend %s\n",
-                node->name, tensor_backend ? ggml_backend_name(tensor_backend) : "NULL",
-                node->view_src->name, ggml_backend_sched_get_tensor_backend(sched, node->view_src) ?
-                    ggml_backend_name(ggml_backend_sched_get_tensor_backend(sched, node->view_src)) : "NULL");
-        }
-        for (int j = 0; j < GGML_MAX_SRC; j++) {
-            struct ggml_tensor * src = node->src[j];
-            if (src == NULL) {
-                continue;
-            }
-            ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);
-            if (src_backend != tensor_backend /* && src_backend != NULL */) {
-                fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
-                    node->name, tensor_backend ? ggml_backend_name(tensor_backend) : "NULL",
-                    j, src->name, src_backend ? ggml_backend_name(src_backend) : "NULL");
-            }
-            if (src->view_src != NULL && src_backend != ggml_backend_sched_get_tensor_backend(sched, src->view_src)) {
-                fprintf(stderr, "!!!!!!! [src] %s has backend %s, view_src %s has backend %s\n",
-                    src->name, src_backend ? ggml_backend_name(src_backend) : "NULL",
-                    src->view_src->name, ggml_backend_sched_get_tensor_backend(sched, src->view_src) ?
-                        ggml_backend_name(ggml_backend_sched_get_tensor_backend(sched, src->view_src)) : "NULL");
-            }
-        }
-    }
-    fflush(stderr);
-#endif
-
     // create copies of the graph for each split
     // TODO: avoid this copy
-    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS, false);
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2, false);
     for (int i = 0; i < sched->n_splits; i++) {
         struct ggml_backend_sched_split * split = &sched->splits[i];
         split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
 
         // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
         for (int j = 0; j < split->n_inputs; j++) {
+            assert(graph_copy->size > (graph_copy->n_nodes + 1));
+
             struct ggml_tensor * input = split->inputs[j];
-            struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split->backend_id][sched->cur_copy];
+            const size_t input_id = hash_id(input);
+            struct ggml_tensor * input_cpy = sched->tensor_copies[input_id][split->backend_id][sched->cur_copy];
 
             // add a dependency to the input source so that it is not freed before the copy is done
             struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);
             input_dep->src[0] = input;
-            sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(input);
+            sched->node_backend_ids[graph_copy->n_nodes] = sched->tensor_backend_id[input_id];
             graph_copy->nodes[graph_copy->n_nodes++] = input_dep;
 
             // add a dependency to the input copy so that it is allocated at the start of the split
@@ -1541,6 +1564,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         }
 
         for (int j = split->i_start; j < split->i_end; j++) {
+            assert(graph_copy->size > graph_copy->n_nodes);
             sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]);
             graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
         }
@@ -1625,13 +1649,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 }
                 ggml_backend_tensor_copy(input, input_cpy);
             } else {
+                // wait for the split backend to finish using the input before overwriting it
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                     ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
                 } else {
                     ggml_backend_synchronize(split_backend);
-                    ggml_backend_synchronize(input_backend);
                 }
-
                 ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
             }
         }
@@ -1701,17 +1724,21 @@ ggml_backend_sched_t ggml_backend_sched_new(
     struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1);
 
     // initialize hash table
-    sched->hash_set          = ggml_hash_set_new(graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS);
+    sched->hash_set          = ggml_hash_set_new(graph_size);
     sched->tensor_backend_id = calloc(sizeof(sched->tensor_backend_id[0]), sched->hash_set.size);
     sched->tensor_copies     = calloc(sizeof(sched->tensor_copies[0]), sched->hash_set.size);
-    sched->node_backend_ids  = calloc(sizeof(sched->node_backend_ids[0]), graph_size);
-    sched->leaf_backend_ids  = calloc(sizeof(sched->leaf_backend_ids[0]), graph_size);
+
+    const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2;
+    sched->node_backend_ids  = calloc(sizeof(sched->node_backend_ids[0]), nodes_size);
+    sched->leaf_backend_ids  = calloc(sizeof(sched->leaf_backend_ids[0]), nodes_size);
 
     sched->n_backends = n_backends;
 
     sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
 
-    GGML_ASSERT(sched->n_copies <= GGML_SCHED_MAX_COPIES);
+    const int initial_splits_capacity = 16;
+    sched->splits = calloc(sizeof(sched->splits[0]), initial_splits_capacity);
+    sched->splits_capacity = initial_splits_capacity;
 
     for (int b = 0; b < n_backends; b++) {
         sched->backends[b] = backends[b];
@@ -1742,6 +1769,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
     }
     ggml_gallocr_free(sched->galloc);
     ggml_free(sched->ctx);
+    free(sched->splits);
     free(sched->hash_set.keys);
     free(sched->tensor_backend_id);
     free(sched->tensor_copies);
@@ -1762,6 +1790,8 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
 }
 
 bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+    GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes);
+
     ggml_backend_sched_split_graph(sched, measure_graph);
 
     // TODO: extract this to a separate function
@@ -1776,7 +1806,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
 }
 
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
-    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS);
+    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes);
 
     ggml_backend_sched_split_graph(sched, graph);
 
index db595409aa735abd1e0600c4f497b03288a20d5a..139025588ed6119381b689d1d290534489f2b458 100644 (file)
 #define cudaGetDeviceProperties hipGetDeviceProperties
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
+#define cudaHostRegister hipHostRegister
+#define cudaHostRegisterPortable hipHostRegisterPortable
+#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
+#define cudaHostUnregister hipHostUnregister
 #define cudaLaunchHostFunc hipLaunchHostFunc
 #ifdef GGML_HIP_UMA
 #define cudaMalloc hipMallocManaged
@@ -7787,11 +7791,7 @@ struct cuda_pool_alloc {
 
 static bool g_cublas_loaded = false;
 
-GGML_CALL bool ggml_cublas_loaded(void) {
-    return g_cublas_loaded;
-}
-
-GGML_CALL void ggml_init_cublas() {
+static void ggml_init_cublas() {
     static bool initialized = false;
 
     if (!initialized) {
@@ -7880,7 +7880,7 @@ GGML_CALL void ggml_init_cublas() {
     }
 }
 
-GGML_CALL void * ggml_cuda_host_malloc(size_t size) {
+static void * ggml_cuda_host_malloc(size_t size) {
     if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
         return nullptr;
     }
@@ -7890,7 +7890,7 @@ GGML_CALL void * ggml_cuda_host_malloc(size_t size) {
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
-        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
+        fprintf(stderr, "%s: warning: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
             size/1024.0/1024.0, cudaGetErrorString(err));
         return nullptr;
     }
@@ -7898,7 +7898,7 @@ GGML_CALL void * ggml_cuda_host_malloc(size_t size) {
     return ptr;
 }
 
-GGML_CALL void ggml_cuda_host_free(void * ptr) {
+static void ggml_cuda_host_free(void * ptr) {
     CUDA_CHECK(cudaFreeHost(ptr));
 }
 
@@ -9036,21 +9036,13 @@ static void ggml_cuda_op_soft_max(
 
     // positions tensor
     float * src2_dd = nullptr;
-    cuda_pool_alloc<float> src2_f;
 
     ggml_tensor * src2 = dst->src[2];
     const bool use_src2 = src2 != nullptr;
 
     if (use_src2) {
-        const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
-
-        if (src2_on_device) {
-            ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
-            src2_dd = (float *) src2_extra->data_device[g_main_device];
-        } else {
-            src2_dd = src2_f.alloc(ggml_nelements(src2));
-            CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
-        }
+        ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
+        src2_dd = (float *) src2_extra->data_device[g_main_device];
     }
 
     soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
@@ -9107,55 +9099,24 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
     ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *)  dst->extra;
 
-    const bool src0_on_device =             src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
-    const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_TYPE_GPU;
-    const bool  dst_on_device =              dst->backend == GGML_BACKEND_TYPE_GPU;
-
     // dd = data device
     float * src0_ddf = nullptr;
     float * src1_ddf = nullptr;
     float *  dst_ddf = nullptr;
 
-    cuda_pool_alloc<float> src0_f;
-    cuda_pool_alloc<float> src1_f;
-    cuda_pool_alloc<float>  dst_f;
-
     ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
-    if (src0_on_device) {
-        src0_ddf = (float *) src0_extra->data_device[g_main_device];
-    } else {
-        src0_ddf = src0_f.alloc(ggml_nelements(src0));
-        CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
-    }
+    src0_ddf = (float *) src0_extra->data_device[g_main_device];
 
     if (use_src1) {
-        if (src1_on_device) {
-            src1_ddf = (float *) src1_extra->data_device[g_main_device];
-        } else {
-            src1_ddf = src1_f.alloc(ggml_nelements(src1));
-            CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
-        }
-    }
-    if (dst_on_device) {
-        dst_ddf = (float *) dst_extra->data_device[g_main_device];
-    } else {
-        dst_ddf = dst_f.alloc(ggml_nelements(dst));
+        src1_ddf = (float *) src1_extra->data_device[g_main_device];
     }
+    dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
     // do the computation
     op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
     CUDA_CHECK(cudaGetLastError());
-
-    // copy dst to host if necessary
-    if (!dst_on_device) {
-        CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
-    }
-
-    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-        CUDA_CHECK(cudaDeviceSynchronize());
-    }
 }
 
 static void ggml_cuda_set_peer_access(const int n_tokens) {
@@ -9251,7 +9212,6 @@ static void ggml_cuda_op_mul_mat(
     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
     ggml_tensor_extra_gpu *  dst_extra = (ggml_tensor_extra_gpu *)  dst->extra;
 
-    const bool src0_on_device = src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
     const bool src0_is_contiguous = ggml_is_contiguous(src0);
     const bool src1_is_contiguous = ggml_is_contiguous(src1);
 
@@ -9322,13 +9282,13 @@ static void ggml_cuda_op_mul_mat(
 
         used_devices++;
 
-        const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
-        const bool  dst_on_device =  dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
+        const bool src1_on_device = id == g_main_device; // TODO: check from buffer
+        const bool  dst_on_device = id == g_main_device;
 
         ggml_cuda_set_device(id);
         cudaStream_t stream = g_cudaStreams[id][0];
 
-        if (src0_on_device && src0_is_contiguous) {
+        if (src0_is_contiguous) {
             dev[id].src0_dd = (char *) src0_extra->data_device[id];
         } else {
             dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ggml_nbytes(src0));
@@ -9374,8 +9334,8 @@ static void ggml_cuda_op_mul_mat(
                 continue;
             }
 
-            const bool src1_on_device = src1->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
-            const bool  dst_on_device =  dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device;
+            const bool src1_on_device = id == g_main_device; // TODO: check from buffer
+            const bool  dst_on_device = id == g_main_device;
             const int64_t row_diff = dev[id].row_high - dev[id].row_low;
 
             ggml_cuda_set_device(id);
@@ -9400,12 +9360,12 @@ static void ggml_cuda_op_mul_mat(
 
                 // the main device memory buffer can be on VRAM scratch, with space for all partial results
                 // in that case an offset on dst_ddf_i is needed
-                if (dst->backend == GGML_BACKEND_TYPE_GPU && id == g_main_device) {
+                if (id == g_main_device) {
                     dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
                 }
 
                 // copy src0, src1 to device if necessary
-                if (src1->backend == GGML_BACKEND_TYPE_GPU && src1_is_contiguous) {
+                if (src1_is_contiguous) {
                     if (id != g_main_device) {
                         if (convert_src1_to_q8_1) {
                             char * src1_ddq_i_source = dev[g_main_device].src1_ddq + src1_ddq_i_offset;
@@ -9418,19 +9378,19 @@ static void ggml_cuda_op_mul_mat(
                                                             src1_ncols*ne10*sizeof(float), stream));
                         }
                     }
-                } else if (src1->backend == GGML_BACKEND_TYPE_CPU || (src1_on_device && !src1_is_contiguous)) {
+                } else if (src1_on_device && !src1_is_contiguous) {
                     CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
                                 src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
                 } else {
                     GGML_ASSERT(false);
                 }
 
-                if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_TYPE_CPU || !src1_is_contiguous)) {
+                if (convert_src1_to_q8_1 && !src1_is_contiguous) {
                     quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
                     CUDA_CHECK(cudaGetLastError());
                 }
 
-                if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
+                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
                     CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
                 }
 
@@ -9441,17 +9401,7 @@ static void ggml_cuda_op_mul_mat(
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
-                    void * dst_off_device;
-                    cudaMemcpyKind kind;
-                    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-                        dst_off_device = dst->data;
-                        kind = cudaMemcpyDeviceToHost;
-                    } else if (dst->backend == GGML_BACKEND_TYPE_GPU) {
-                        dst_off_device = dst_extra->data_device[g_main_device];
-                        kind = cudaMemcpyDeviceToDevice;
-                    } else {
-                        GGML_ASSERT(false);
-                    }
+                    void * dst_off_device = dst_extra->data_device[g_main_device];
                     if (split) {
                         // src0 = weight matrix is saved as a transposed matrix for better memory layout.
                         // dst is NOT transposed.
@@ -9462,28 +9412,26 @@ static void ggml_cuda_op_mul_mat(
                         GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                         dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
 #if !defined(GGML_USE_HIPBLAS)
-                        if (kind == cudaMemcpyDeviceToDevice) {
-                            // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
-                            cudaMemcpy3DPeerParms p = {};
-                            p.dstDevice = g_main_device;
-                            p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
-                            p.srcDevice = id;
-                            p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
-                            p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
-                            CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
-                        } else
+                        // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+                        cudaMemcpy3DPeerParms p = {};
+                        p.dstDevice = g_main_device;
+                        p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
+                        p.srcDevice = id;
+                        p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
+                        p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
+                        CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
+#else
+                        // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
+                        CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
+                                                        dst_dd_i, row_diff*sizeof(float),
+                                                        row_diff*sizeof(float), src1_ncols,
+                                                        cudaMemcpyDeviceToDevice, stream));
 #endif
-                        {
-                            CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
-                                                            dst_dd_i, row_diff*sizeof(float),
-                                                            row_diff*sizeof(float), src1_ncols,
-                                                            kind, stream));
-                        }
                     } else {
                         float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                         GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                         dhf_dst_i += src1_col_0*ne0;
-                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream));
+                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
                     }
                 }
 
@@ -9510,11 +9458,6 @@ static void ggml_cuda_op_mul_mat(
             }
         }
     }
-
-    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-        ggml_cuda_set_device(g_main_device);
-        CUDA_CHECK(cudaDeviceSynchronize());
-    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9599,36 +9542,19 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
 static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *)  dst->extra;
 
-    const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
-
     // dd = data device
     float * src0_ddf = nullptr;
     float * src1_ddf = nullptr;
     float *  dst_ddf = nullptr;
 
-    cuda_pool_alloc<float>  dst_f;
-
     ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
-    if (dst_on_device) {
-        dst_ddf = (float *) dst_extra->data_device[g_main_device];
-    } else {
-        dst_ddf = dst_f.alloc(ggml_nelements(dst));
-    }
+    dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
     // do the computation
     ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
     CUDA_CHECK(cudaGetLastError());
-
-    // copy dst to host if necessary
-    if (!dst_on_device) {
-        CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
-    }
-
-    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-        CUDA_CHECK(cudaDeviceSynchronize());
-    }
 }
 
 static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9639,21 +9565,6 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }
 
-GGML_CALL bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (!g_cublas_loaded) return false;
-
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    // TODO: find the optimal values for these
-    return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
-            src1->type == GGML_TYPE_F32 &&
-             dst->type == GGML_TYPE_F32 &&
-            (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
-}
-
 static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
@@ -9891,11 +9802,6 @@ static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggm
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const bool all_on_device =
-        (src0->backend == GGML_BACKEND_TYPE_GPU || src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT) &&
-        (src1->backend == GGML_BACKEND_TYPE_GPU) &&
-        ( dst->backend == GGML_BACKEND_TYPE_GPU);
-
     const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
 
     int64_t min_compute_capability = INT_MAX;
@@ -9972,13 +9878,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && all_on_device && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_batched_cublas(src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
@@ -10178,6 +10084,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     ggml_cuda_mul_mat_id_cublas(dst);
     // TODO: mmq/mmv support
 #endif
+    cudaStream_t stream = g_cudaStreams[g_main_device][0];
 
     const size_t nb11 = src1->nb[1];
     const size_t nb1  =  dst->nb[1];
@@ -10187,16 +10094,9 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     const int32_t n_as = ((int32_t *) dst->op_params)[1];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
-
-    cudaStream_t stream = g_cudaStreams[g_main_device][0];
-
-    if (ids->backend == GGML_BACKEND_TYPE_GPU) {
-        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
-        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-        CUDA_CHECK(cudaStreamSynchronize(stream));
-    } else {
-        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
-    }
+    const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
 
     const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
     const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
@@ -10213,20 +10113,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     src1_row.extra = &src1_row_extra;
     dst_row.extra = &dst_row_extra;
 
-    char * src1_original = src1->backend == GGML_BACKEND_TYPE_CPU ?
-        (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
-    char * dst_original  =  dst->backend == GGML_BACKEND_TYPE_CPU ?
-        (char *)  dst->data : (char *)  dst_extra->data_device[g_main_device];
+    char * src1_original = (char *) src1_extra->data_device[g_main_device];
+    char * dst_original  = (char *)  dst_extra->data_device[g_main_device];
 
     if (src1->ne[1] == 1) {
-        GGML_ASSERT(src1->backend == GGML_BACKEND_TYPE_GPU);
-        GGML_ASSERT(dst->backend  == GGML_BACKEND_TYPE_GPU);
-
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            //int32_t row_id;
-            //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-            //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
-
             const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
             GGML_ASSERT(row_id >= 0 && row_id < n_as);
@@ -10248,11 +10139,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
         src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
         dst_row_extra.data_device[g_main_device]  =  dst_contiguous.get();
 
-        const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_TYPE_CPU ?
-            cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
-        const cudaMemcpyKind dst_kind  =  dst->backend == GGML_BACKEND_TYPE_CPU ?
-            cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice;
-
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
             const struct ggml_tensor * src0_row = dst->src[row_id + 2];
 
@@ -10267,7 +10153,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
                 CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, src1_kind, stream));
+                                        nb11, cudaMemcpyDeviceToDevice, stream));
                 num_src1_rows++;
             }
 
@@ -10299,15 +10185,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
                 CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                                        nb1, dst_kind, stream));
+                                        nb1, cudaMemcpyDeviceToDevice, stream));
                 num_src1_rows++;
             }
         }
     }
-
-    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-        CUDA_CHECK(cudaStreamSynchronize(stream));
-    }
 }
 
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10435,7 +10317,7 @@ static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_spl
     return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
 }
 
-GGML_CALL static void ggml_cuda_set_main_device(const int main_device) {
+static void ggml_cuda_set_main_device(const int main_device) {
     if (main_device >= g_device_count) {
         fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
                 main_device, g_device_count, g_main_device);
@@ -10450,18 +10332,9 @@ GGML_CALL static void ggml_cuda_set_main_device(const int main_device) {
     }
 }
 
-GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static bool ggml_cuda_compute_forward(struct ggml_tensor * tensor) {
     if (!g_cublas_loaded) return false;
 
-    ggml_cuda_func_t func;
-    const bool any_on_device = tensor->backend == GGML_BACKEND_TYPE_GPU
-        || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
-        || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
-
-    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
-        return false;
-    }
-
     if (tensor->op == GGML_OP_MUL_MAT) {
         if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
 #ifndef NDEBUG
@@ -10471,6 +10344,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         }
     }
 
+    ggml_cuda_func_t func;
+
     switch (tensor->op) {
         case GGML_OP_REPEAT:
             func = ggml_cuda_repeat;
@@ -10548,15 +10423,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
             func = ggml_cuda_rms_norm;
             break;
         case GGML_OP_MUL_MAT:
-            if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
-                return false;
-            }
             func = ggml_cuda_mul_mat;
             break;
         case GGML_OP_MUL_MAT_ID:
-            if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) {
-                return false;
-            }
             func = ggml_cuda_mul_mat_id;
             break;
         case GGML_OP_SCALE:
@@ -10613,17 +10482,11 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         ggml_cuda_set_peer_access(tensor->src[1]->ne[1]);
     }
 
-    if (params->ith != 0) {
-        return true;
-    }
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return true;
-    }
     func(tensor->src[0], tensor->src[1], tensor);
     return true;
 }
 
-GGML_CALL int ggml_cuda_get_device_count() {
+static int ggml_cuda_get_device_count() {
     int device_count;
     if (cudaGetDeviceCount(&device_count) != cudaSuccess) {
         return 0;
@@ -10631,7 +10494,7 @@ GGML_CALL int ggml_cuda_get_device_count() {
     return device_count;
 }
 
-GGML_CALL void ggml_cuda_get_device_description(int device, char * description, size_t description_size) {
+static void ggml_cuda_get_device_description(int device, char * description, size_t description_size) {
     cudaDeviceProp prop;
     CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
     snprintf(description, description_size, "%s", prop.name);
@@ -10736,6 +10599,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
         size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
 
         if (padded_size > original_size && tensor->view_src == nullptr) {
+            ggml_cuda_set_device(ctx->device);
             CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
         }
     }
@@ -10873,6 +10737,8 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
 };
 
 GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
+    ggml_init_cublas();
+
     // FIXME: this is not thread safe
     if (device >= ggml_backend_cuda_get_device_count()) {
         return nullptr;
@@ -11157,6 +11023,8 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface
 };
 
 GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) {
+    ggml_init_cublas();
+
     // FIXME: this is not thread safe
     static std::map<std::array<float, GGML_CUDA_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
 
@@ -11348,9 +11216,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
     ggml_cuda_set_main_device(cuda_ctx->device);
 
-    ggml_compute_params params = {};
-    params.type = GGML_TASK_TYPE_COMPUTE;
-    params.ith = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -11372,7 +11237,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         }
 #endif
 
-        bool ok = ggml_cuda_compute_forward(&params, node);
+        bool ok = ggml_cuda_compute_forward(node);
         if (!ok) {
             fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
         }
@@ -11509,6 +11374,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     UNUSED(backend);
 }
 
+GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
+    const int min_batch_size = 32;
+
+    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+
+    UNUSED(backend);
+}
+
 static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
@@ -11571,6 +11444,7 @@ static ggml_backend_i ggml_backend_cuda_interface = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
     /* .supports_op             = */ ggml_backend_cuda_supports_op,
+    /* .offload_op              = */ ggml_backend_cuda_offload_op,
     /* .event_new               = */ ggml_backend_cuda_event_new,
     /* .event_free              = */ ggml_backend_cuda_event_free,
     /* .event_record            = */ ggml_backend_cuda_event_record,
@@ -11584,7 +11458,7 @@ static ggml_guid_t ggml_backend_cuda_guid() {
 }
 
 GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
-    ggml_init_cublas(); // TODO: remove from ggml.c
+    ggml_init_cublas();
 
     if (device < 0 || device >= ggml_cuda_get_device_count()) {
         fprintf(stderr, "%s: error: invalid device %d\n", __func__, device);
@@ -11627,6 +11501,31 @@ GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, si
     CUDA_CHECK(cudaMemGetInfo(free, total));
 }
 
+GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
+    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
+        return false;
+    }
+
+    cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+
+        fprintf(stderr, "%s: warning: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+                size/1024.0/1024.0, cudaGetErrorString(err));
+        return false;
+    }
+    return true;
+}
+
+GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
+    cudaError_t err = cudaHostUnregister(buffer);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+    }
+}
+
 // backend registry
 GGML_CALL static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
     ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
index b1ebd61d7fb665c3e659685a0480f7ba1428fb3d..5eb4af40f4d1fef5847c37bed27dfc99d3a07a4a 100644 (file)
@@ -17,29 +17,17 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
-// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
-GGML_API GGML_CALL void   ggml_init_cublas(void);
-
-// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
-GGML_API GGML_CALL bool   ggml_cublas_loaded(void);
-
-GGML_API GGML_CALL void * ggml_cuda_host_malloc(size_t size);
-GGML_API GGML_CALL void   ggml_cuda_host_free(void * ptr);
-
-GGML_API GGML_CALL bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
-GGML_API GGML_CALL bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
-
-GGML_API GGML_CALL int    ggml_cuda_get_device_count(void);
-GGML_API GGML_CALL void   ggml_cuda_get_device_description(int device, char * description, size_t description_size);
-
 // backend API
 GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
 
 GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
 
+// device buffer
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
 // split tensor buffer that splits matrices by rows across multiple devices
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
+
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
@@ -47,6 +35,9 @@ GGML_API GGML_CALL int  ggml_backend_cuda_get_device_count(void);
 GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
 GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
 
+GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+
 #ifdef  __cplusplus
 }
 #endif
index 4caf2c9e78b0281ffe19e0caa687f31c3c5f91c5..81dd5067864ce73033a003efdcf54cb44eab4eea 100644 (file)
@@ -1951,6 +1951,7 @@ static struct ggml_backend_i kompute_backend_i = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_kompute_graph_compute,
     /* .supports_op             = */ ggml_backend_kompute_supports_op,
+    /* .offload_op              = */ NULL,
     /* .event_new               = */ NULL,
     /* .event_free              = */ NULL,
     /* .event_record            = */ NULL,
index c3451a79b103f0b7a0b8b64c96db64f0ad747674..109e5fe6ba13dd45038a2dc7a881186c9d3b14b0 100644 (file)
@@ -2837,6 +2837,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
     /* .supports_op             = */ ggml_backend_metal_supports_op,
+    /* .offload_op              = */ NULL,
     /* .event_new               = */ NULL,
     /* .event_free              = */ NULL,
     /* .event_record            = */ NULL,
index 6dc5eb20c5a63a11b70048ff22f38cc3b322631d..d51f23b410595ed7653cce928bc724aba26cec7a 100644 (file)
@@ -17390,6 +17390,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_sycl_graph_compute,
     /* .supports_op             = */ ggml_backend_sycl_supports_op,
+    /* .offload_op              = */ NULL,
     /* .event_new               = */ NULL,
     /* .event_free              = */ NULL,
     /* .event_record            = */ NULL,
index 698b31496f417e63609be25e1d1e26551b4cde00..cbceaa19fbacde29d6b950c504edeccf92c11989 100644 (file)
@@ -5699,6 +5699,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_vk_graph_compute,
     /* .supports_op             = */ ggml_backend_vk_supports_op,
+    /* .offload_op              = */ NULL,
     /* .event_new               = */ NULL,
     /* .event_free              = */ NULL,
     /* .event_record            = */ NULL,
index fa23cb3c44f87fb549481207843551c4d95f0dc3..1d5854960a51bb5eb016d8a0aff4359a16c6a49c 100644 (file)
@@ -282,8 +282,6 @@ inline static void * ggml_calloc(size_t num, size_t size) {
 #else
 #include <cblas.h>
 #endif
-#elif defined(GGML_USE_CUBLAS)
-#include "ggml-cuda.h"
 #elif defined(GGML_USE_CLBLAST)
 #include "ggml-opencl.h"
 #elif defined(GGML_USE_VULKAN)
@@ -2640,9 +2638,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
-#if defined(GGML_USE_CUBLAS)
-        ggml_init_cublas();
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
         ggml_cl_init();
 #elif defined(GGML_USE_VULKAN)
         ggml_vk_init_cpu_assist();
@@ -11105,7 +11101,6 @@ static void ggml_compute_forward_out_prod_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
     // TODO: #if defined(GGML_USE_CLBLAST)
 
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -11305,7 +11300,6 @@ static void ggml_compute_forward_out_prod_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-    // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod
     // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
 
     if (params->type == GGML_TASK_TYPE_INIT) {
@@ -16051,14 +16045,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
         return;
     }
 
-#ifdef GGML_USE_CUBLAS
-    bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
-    if (skip_cpu) {
-        return;
-    }
-    GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_TYPE_CPU);
-    GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU);
-#elif defined(GGML_USE_VULKAN)
+#if defined(GGML_USE_VULKAN)
     const bool skip_cpu = ggml_vk_compute_forward_cpu_assist(params, tensor);
 #ifdef GGML_VULKAN_CHECK_RESULTS
     if (skip_cpu) {
@@ -16070,7 +16057,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     }
     GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_TYPE_CPU);
     GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_TYPE_CPU);
-#endif // GGML_USE_CUBLAS
+#endif // GGML_USE_VULKAN
 
 #ifdef GGML_USE_SYCL
     bool skip_cpu = ggml_sycl_compute_forward(params, tensor);