}
}
+#ifdef GGML_SYCL_GRAPH
+static bool check_graph_compatibility(ggml_cgraph * cgraph) {
+ if (ggml_sycl_info().device_count > 1) {
+ // A sycl_ex::command_graph object can only be created for a single device
+ GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
+ return false;
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ const ggml_op node_op = cgraph->nodes[i]->op;
+ switch (node_op) {
+ default:
+ break;
+ case GGML_OP_CONCAT:
+ // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
+ // but wait() can't be called on the events returned by a queue recording
+ // to a graph.
+ [[fallthrough]];
+ case GGML_OP_MUL_MAT_ID:
+ // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
+ // submitting a memcpy operation, but wait() can't be called on a queue that
+ // is recording to a graph.
+ GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
+ ggml_op_name(node_op));
+ return false;
+ }
+ }
+ return true;
+}
+#endif
+
static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
#ifdef GGML_SYCL_GRAPH
- if (!g_ggml_sycl_disable_graph) {
+ bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph);
+ if (use_sycl_graph) {
const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
if (!graph_support) {
GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);