]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Introduction of CUDA Graphs to LLama.cpp (llama/6766)
authoragray3 <redacted>
Wed, 8 May 2024 20:55:49 +0000 (21:55 +0100)
committerGeorgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
* DRAFT: Introduction of CUDA Graphs to LLama.cpp

* FIx issues raised in comments

* Tidied to now only use CUDA runtime (not mixed with driver calls)

* disable for multi-gpu and batch size > 1

* Disable CUDA graphs for old GPU arch and with env var

* added missing CUDA_CHECKs

* Addressed comments

* further addressed comments

* limit to GGML_ALLOW_CUDA_GRAPHS defined in llama.cpp cmake

* Added more comprehensive graph node checking

* With mechanism to fall back if graph capture fails

* Revert "With mechanism to fall back if graph capture fails"

This reverts commit eb9f15fb6fcb81384f732c4601a5b25c016a5143.

* Fall back if graph capture fails and address other comments

* - renamed GGML_ALLOW_CUDA_GRAPHS to GGML_CUDA_USE_GRAPHS

- rename env variable to disable CUDA graphs to GGML_CUDA_DISABLE_GRAPHS

- updated Makefile build to enable CUDA graphs

- removed graph capture failure checking in ggml_cuda_error
  using a global variable to track this is not thread safe, but I am also not safistied with checking an error by string
  if this is necessary to workaround some issues with graph capture with eg. cuBLAS, we can pass the ggml_backend_cuda_context to the error checking macro and store the result in the context

- fixed several resource leaks

- fixed issue with zero node graphs

- changed fixed size arrays to vectors

- removed the count of number of evaluations before start capturing, and instead changed the capture mode to relaxed

- removed the check for multiple devices so that it is still possible to use a single device, instead checks for split buffers to disable cuda graphs with -sm row

- changed the op for checking batch size to GGML_OP_ADD, should be more reliable than GGML_OP_SOFT_MAX

- code style fixes

- things to look into
  - VRAM usage of the cudaGraphExec_t, if it is significant we may need to make it optional
  - possibility of using cudaStreamBeginCaptureToGraph to keep track of which ggml graph nodes correspond to which cuda graph nodes

* fix build without cuda graphs

* remove outdated comment

* replace minimum cc value with a constant

---------

Co-authored-by: slaren <redacted>
ggml-cuda.cu
ggml-cuda/clamp.cu
ggml-cuda/common.cuh
ggml-cuda/convert.cu
ggml-cuda/cpy.cu
ggml-cuda/cpy.cuh
ggml-cuda/mmq.cu
ggml-cuda/mmvq.cu
ggml-cuda/scale.cu

index 8739baa2a791404eb6355c87d5372ba8f4e17fc4..ceb66170edd69fc106055d4f4689cff11a3f3f27 100644 (file)
@@ -1647,7 +1647,7 @@ static void ggml_cuda_op_mul_mat(
     }
 }
 
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -1670,7 +1670,7 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
 
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
@@ -2413,32 +2413,304 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     GGML_UNUSED(backend);
 }
 
+static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    graph_node_properties->node_address = node->data;
+    graph_node_properties->node_op = node->op;
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        graph_node_properties->ne[i] = node->ne[i];
+        graph_node_properties->nb[i] = node->nb[i];
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+    }
+}
+
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    if (node->data != graph_node_properties->node_address &&
+          node->op != GGML_OP_CPY &&
+          node->op != GGML_OP_VIEW) {
+        return false;
+    }
+
+    if (node->op != graph_node_properties->node_op) {
+        return false;
+    }
+
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (node->ne[i] != graph_node_properties->ne[i]) {
+            return false;
+        }
+        if (node->nb[i] != graph_node_properties->nb[i]) {
+            return false;
+        }
+    }
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (node->src[i] &&
+            node->src[i]->data != graph_node_properties->src_address[i] &&
+            node->op != GGML_OP_CPY &&
+            node->op != GGML_OP_VIEW
+        ) {
+            return false;
+        }
+    }
+    return true;
+}
+
 GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
     ggml_cuda_set_device(cuda_ctx->device);
 
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_tensor * node = cgraph->nodes[i];
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
-        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-            continue;
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+    // pointer to CUDA cpy kernel, which is required to identify
+    // kernel parameters which need updated in the graph for each token
+    void * ggml_cuda_cpy_fn_ptr = nullptr;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+        }
+    }
+
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
+    // Also disable for multi-gpu for now. TO DO investigate
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) {
+            cuda_graph_update_required = true;
+        }
+
+        // Check if the graph size has changed
+        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+            cuda_graph_update_required = true;
+            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        }
+
+        // Loop over nodes in GGML graph to determine if CUDA graph update is required
+        // and store properties to allow this comparison for the next token
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            bool has_matching_properties = true;
+            if (!cuda_graph_update_required) {
+                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+            }
+            if (!has_matching_properties) {
+                cuda_graph_update_required = true;
+            }
+            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        }
+
+        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            ggml_tensor * node = cgraph->nodes[i];
+
+            if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
+                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+            }
+
+            if (node->op == GGML_OP_MUL_MAT_ID) {
+                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+            }
+
+            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+                // disable CUDA graphs for batch size > 1 for now.
+                // Changes in batch size or context size can cause changes to the grid size of some kernels.
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+            }
+
+            if (node->op == GGML_OP_CPY) {
+                // store the copy op parameter which changes with each token.
+                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+                if (ggml_cuda_cpy_fn_ptr == nullptr) {
+                    // store a pointer to the copy op CUDA kernel to identify it later
+                    ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                }
+            }
+
+            if (!use_cuda_graph) {
+                break;
+            }
+        }
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
         }
 
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+            fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+        }
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
+#else
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+    bool graph_evaluated_or_captured = false;
+
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+
 #ifndef NDEBUG
-        assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
-        for (int j = 0; j < GGML_MAX_SRC; j++) {
-            if (node->src[j] != nullptr) {
-                assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+                    }
+                }
+#endif
+
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+                if (!ok) {
+                    fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
             }
         }
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+            if (disable_cuda_graphs_due_to_failed_capture) {
+                use_cuda_graph = false;
+                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
 #endif
+            } else {
+                graph_evaluated_or_captured = true; // CUDA graph has been captured
+            }
+#endif
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
 
-        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
-        if (!ok) {
-            fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
         }
-        GGML_ASSERT(ok);
+
+        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+
+        if (cuda_graph_update_required) {
+            // Extract nodes from graph
+            if (cuda_ctx->cuda_graph->num_nodes == 0) {
+                // First call with null argument gets number of nodes in graph
+                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+            }
+            // Subsequent call with non-null argument gets nodes
+            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+            if (cuda_ctx->cuda_graph->num_nodes > 0) {
+                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+
+                // Loop over nodes, and extract kernel parameters from each node
+                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                    cudaGraphNodeType node_type;
+                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                    if (node_type == cudaGraphNodeTypeKernel) {
+                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                        if (stat == cudaErrorInvalidDeviceFunction) {
+                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                            // We don't need to update blas nodes, so clear error and move on.
+                            cudaGetLastError();
+                        } else {
+                            GGML_ASSERT(stat == cudaSuccess);
+                        }
+                    }
+                }
+            }
+        }
+
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
+            int k = 0;
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+                }
+            }
+        }
+
+        // Update graph executable
+        cudaGraphExecUpdateResultInfo result_info;
+        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+        if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: CUDA graph update failed\n", __func__);
+#endif
+            // The pre-existing graph exec cannot be updated due to violated constraints
+            // so instead clear error and re-instantiate
+            cudaGetLastError();
+            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+            cuda_ctx->cuda_graph->instance = nullptr;
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        } else {
+            GGML_ASSERT(stat == cudaSuccess);
+        }
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
     }
 
     return GGML_STATUS_SUCCESS;
index 379ded042d897ac09932740c4d13b33a3099eb1f..8009a3e3d8607cfc8b682ac0affee9a84c11e71d 100644 (file)
@@ -31,5 +31,4 @@ void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
     clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
-    CUDA_CHECK(cudaGetLastError());
 }
index b2627b7b4b77ff65a17de8199b3163a6a3c9c6be..a4197f11ba779f52040e6932cb68e2689cc68bda 100644 (file)
@@ -19,6 +19,7 @@
 #include <cassert>
 #include <cfloat>
 #include <string>
+#include <vector>
 
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
@@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
 };
 
+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_graph_node_properties {
+    void * node_address;
+    ggml_op node_op;
+    int64_t ne[GGML_MAX_DIMS];
+    size_t nb[GGML_MAX_DIMS];
+    void * src_address[GGML_MAX_SRC];
+};
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+    ~ggml_cuda_graph() {
+        if (instance != nullptr) {
+            CUDA_CHECK(cudaGraphExecDestroy(instance));
+        }
+        if (graph != nullptr) {
+            CUDA_CHECK(cudaGraphDestroy(graph));
+        }
+    }
+    cudaGraph_t graph = nullptr;
+    cudaGraphExec_t instance = nullptr;
+    size_t num_nodes = 0;
+    std::vector<cudaGraphNode_t> nodes;
+    std::vector<cudaKernelNodeParams> params;
+    bool disable_due_to_gpu_arch = false;
+    bool disable_due_to_too_many_updates = false;
+    bool disable_due_to_failed_graph_capture = false;
+    int number_consecutive_updates = 0;
+    std::vector<ggml_graph_node_properties> ggml_graph_properties;
+    std::vector<char **> updated_kernel_arg;
+#endif
+};
+
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
@@ -534,6 +572,8 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+    std::unique_ptr<ggml_cuda_graph> cuda_graph;
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
index 75e50c98561235c572365217281afeaf5a7ae247..830e2d75661625fdc440ca877b87328b083a41b8 100644 (file)
@@ -727,7 +727,6 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
 }
 
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
-    int id;
     switch (type) {
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
@@ -738,8 +737,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            CUDA_CHECK(cudaGetDevice(&id));
-            if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) {
+            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
index 16d9c8fffb4b39df4a955d7d382f88b0975192d4..12d741f017d3b4bdebae6d5b719d30c737ac2aea 100644 (file)
@@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     ggml_cuda_cpy(ctx, src0, dst);
 }
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+            return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+            return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+            return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+            return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+            return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+            return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+            return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+            return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+            return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+            return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ASSERT(false);
+    }
+}
+
index f0b2c453bfe6ad9dcc5123853051c3935902b662..7961674266ee1c67d69ab4077b84646d5e61009d 100644 (file)
@@ -5,3 +5,5 @@
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
index 60d6616a860f70d7530028a621740347e0ecf72f..7948f1b1237fa81e04a1ecaaaebd5b1c2c31b47e 100644 (file)
@@ -1735,8 +1735,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1780,8 +1779,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1825,8 +1823,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1870,8 +1867,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1915,8 +1911,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1960,8 +1955,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -2007,8 +2001,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
 
 #if QK_K == 256
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -2053,8 +2046,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -2098,8 +2090,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -2143,8 +2134,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
index 3965590017b953a9bac360264986adab6992d840..65cc1bcaad697e93dfe1360dadeef77fc3a3008c 100644 (file)
@@ -89,8 +89,7 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % qk == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
 
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
@@ -328,8 +327,7 @@ void ggml_cuda_op_mul_mat_vec_q(
 
     const int64_t ne0 = dst->ne[0];
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the kernel writes into
index 6e3617d1cdbd503588de0ec1c1bfcde156bd49c1..1405e066e86a29a4e56c6bda59e4f43ab3f40b17 100644 (file)
@@ -28,5 +28,4 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&scale, dst->op_params, sizeof(float));
 
     scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
-    CUDA_CHECK(cudaGetLastError());
 }