}
}
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
}
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_UNUSED(backend);
}
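+// Record the properties of a ggml graph node (op, shape, strides, data addresses)
+// so they can be compared against the corresponding node on the next token.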
+static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+ graph_node_properties->node_address = node->data;
+ graph_node_properties->node_op = node->op;
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ graph_node_properties->ne[i] = node->ne[i];
+ graph_node_properties->nb[i] = node->nb[i];
+ }
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+ }
+}
+
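+// Return true if the node still matches the properties recorded for it,
+// i.e. the captured CUDA graph remains valid for this node.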
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
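+    // The data address of CPY and VIEW nodes is allowed to change between tokens,
+    // so it is excluded from the checks below (CPY kernel arguments are patched separately).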
+ if (node->data != graph_node_properties->node_address &&
+ node->op != GGML_OP_CPY &&
+ node->op != GGML_OP_VIEW) {
+ return false;
+ }
+
+ if (node->op != graph_node_properties->node_op) {
+ return false;
+ }
+
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (node->ne[i] != graph_node_properties->ne[i]) {
+ return false;
+ }
+ if (node->nb[i] != graph_node_properties->nb[i]) {
+ return false;
+ }
+ }
+
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (node->src[i] &&
+ node->src[i]->data != graph_node_properties->src_address[i] &&
+ node->op != GGML_OP_CPY &&
+ node->op != GGML_OP_VIEW
+ ) {
+ return false;
+ }
+ }
+ return true;
+}
+
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
- for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_tensor * node = cgraph->nodes[i];
+#ifdef USE_CUDA_GRAPH
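+    // Allow the user to disable CUDA graphs at runtime, e.g. by running with GGML_CUDA_DISABLE_GRAPHS=1.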
+ static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
- if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
- continue;
+    // Lazily create the objects required for CUDA graphs on first use.
+ if (cuda_ctx->cuda_graph == nullptr) {
+ cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+ }
+
+ bool use_cuda_graph = true;
+ bool cuda_graph_update_required = false;
+    // pointer to the CUDA cpy kernel, required to identify
+    // the kernel parameters which need to be updated in the graph for each token
+ void * ggml_cuda_cpy_fn_ptr = nullptr;
+
+ if (cuda_ctx->cuda_graph->graph == nullptr) {
+ if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
+ cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+ fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+ }
+ }
+
+    // Disable CUDA graphs in the presence of the env var, an old GPU, a use-case which is changing too rapidly,
+    // or a previous graph capture failure.
+    // Also disable for multi-GPU for now. TODO: investigate.
+ if (disable_cuda_graphs_due_to_env
+ || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+ || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+ || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+ use_cuda_graph = false;
+ }
+
+ if (use_cuda_graph) {
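+        // No executable graph instance exists yet, so this first evaluation must capture one.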
+ if (cuda_ctx->cuda_graph->instance == nullptr) {
+ cuda_graph_update_required = true;
+ }
+
+ // Check if the graph size has changed
+ if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+ cuda_graph_update_required = true;
+ cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+ }
+
+ // Loop over nodes in GGML graph to determine if CUDA graph update is required
+ // and store properties to allow this comparison for the next token
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ bool has_matching_properties = true;
+ if (!cuda_graph_update_required) {
+ has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+ }
+ if (!has_matching_properties) {
+ cuda_graph_update_required = true;
+ }
+ set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+ }
+
+ // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+ cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+
+ if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
+ use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+ fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+ }
+
+ if (node->op == GGML_OP_MUL_MAT_ID) {
+ use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+ fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+ }
+
+ if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+ // disable CUDA graphs for batch size > 1 for now.
+ // Changes in batch size or context size can cause changes to the grid size of some kernels.
+ use_cuda_graph = false;
+#ifndef NDEBUG
+            fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%lld %lld %lld %lld]\n", __func__, node->name, (long long) node->ne[0], (long long) node->ne[1], (long long) node->ne[2], (long long) node->ne[3]);
+#endif
+ }
+
+ if (node->op == GGML_OP_CPY) {
+ // store the copy op parameter which changes with each token.
+ cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+ if (ggml_cuda_cpy_fn_ptr == nullptr) {
+ // store a pointer to the copy op CUDA kernel to identify it later
+ ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+ }
+ }
+
+ if (!use_cuda_graph) {
+ break;
+ }
+ }
+
+ // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+ if (cuda_graph_update_required) {
+ cuda_ctx->cuda_graph->number_consecutive_updates++;
+ } else {
+ cuda_ctx->cuda_graph->number_consecutive_updates = 0;
}
+ if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+ cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+ fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+ }
+ }
+
+ if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+ CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+ }
+
+#else
+ bool use_cuda_graph = false;
+ bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+ bool graph_evaluated_or_captured = false;
+
+ while (!graph_evaluated_or_captured) {
+ // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+ // With the use of CUDA graphs, the execution will be performed by the graph launch.
+ if (!use_cuda_graph || cuda_graph_update_required) {
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ continue;
+ }
+
#ifndef NDEBUG
- assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
- for (int j = 0; j < GGML_MAX_SRC; j++) {
- if (node->src[j] != nullptr) {
- assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+ assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j] != nullptr) {
+ assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+ }
+ }
+#endif
+
+ bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+ if (!ok) {
+ fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ }
+ GGML_ASSERT(ok);
}
}
+
+#ifdef USE_CUDA_GRAPH
+ if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+ if (cuda_ctx->cuda_graph->graph != nullptr) {
+ CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+ cuda_ctx->cuda_graph->graph = nullptr;
+ }
+ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
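+        // Disabled placeholder for falling back on a failed graph capture;
+        // disable_cuda_graphs_due_to_failed_capture is not currently set anywhere.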
+ if (disable_cuda_graphs_due_to_failed_capture) {
+ use_cuda_graph = false;
+ cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+ fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
#endif
+ } else {
+ graph_evaluated_or_captured = true; // CUDA graph has been captured
+ }
+#endif
+ graph_evaluated_or_captured = true; // CUDA graph has been captured
+ } else {
+ graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+ }
+ }
- bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
- if (!ok) {
- fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ if (use_cuda_graph) {
+ if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+ CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
}
- GGML_ASSERT(ok);
+
+        // Perform an update to the graph (if required for this token), and change the copy parameter (required for every token)
+
+ if (cuda_graph_update_required) {
+ // Extract nodes from graph
+ if (cuda_ctx->cuda_graph->num_nodes == 0) {
+ // First call with null argument gets number of nodes in graph
+ CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+ }
+ // Subsequent call with non-null argument gets nodes
+ cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+ cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+ if (cuda_ctx->cuda_graph->num_nodes > 0) {
+ CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+
+ // Loop over nodes, and extract kernel parameters from each node
+ for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+ cudaGraphNodeType node_type;
+ CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+ if (node_type == cudaGraphNodeTypeKernel) {
+ cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+ if (stat == cudaErrorInvalidDeviceFunction) {
+                        // This fails due to incorrect handling of the cuBLAS node by the CUDA runtime.
+                        // We don't need to update cuBLAS nodes, so clear the error and move on.
+ cudaGetLastError();
+ } else {
+ GGML_ASSERT(stat == cudaSuccess);
+ }
+ }
+ }
+ }
+ }
+
+ // One of the arguments to the copy kernel is updated for each token, hence we need to
+ // replace that argument with the updated value in the CUDA graph
+ if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
+ int k = 0;
+ for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+ if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+ char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
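+                    // kernelParams[] holds pointers to the argument values; entry 1 of the copy kernel
+                    // is its destination pointer, which changes with every token.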
+ cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+ CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+ }
+ }
+ }
+
+ // Update graph executable
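+        // cudaGraphExecUpdate is much cheaper than re-instantiating when only kernel parameters have changed.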
+ cudaGraphExecUpdateResultInfo result_info;
+ cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+ if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+ fprintf(stderr, "%s: CUDA graph update failed\n", __func__);
+#endif
+            // The pre-existing graph exec cannot be updated due to violated constraints,
+            // so instead clear the error and re-instantiate.
+ cudaGetLastError();
+ CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+ cuda_ctx->cuda_graph->instance = nullptr;
+ CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+ } else {
+ GGML_ASSERT(stat == cudaSuccess);
+ }
+ // Launch graph
+ CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+ graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
}
return GGML_STATUS_SUCCESS;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
#if QK_K == 256
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;