From: slaren Date: Mon, 18 Dec 2023 17:05:43 +0000 (+0100) Subject: cuda : fix synchronization with tensor get/set (#659) X-Git-Tag: upstream/0.0.1642~1171 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=c80e07e9d4724392aaf02cdf32d1a1fb7228bea9;p=pkg%2Fggml%2Fsources%2Fggml cuda : fix synchronization with tensor get/set (#659) --- diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index 2e07bc66..06da7b30 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -7057,6 +7057,7 @@ inline void ggml_cuda_op_upscale( (void) src1; (void) dst; + (void) src1_dd; } inline void ggml_cuda_op_pad( @@ -7073,6 +7074,7 @@ inline void ggml_cuda_op_pad( (void) src1; (void) dst; + (void) src1_dd; } inline void ggml_cuda_op_rms_norm( @@ -8953,7 +8955,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { char * buf; CUDA_CHECK(cudaMalloc(&buf, size)); - char * buf_host = (char*)data + offset_split; + char * buf_host = (char *)data + offset_split; // set padding to 0 to avoid possible NaN values if (size > original_size) { @@ -9426,9 +9428,12 @@ static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, gg GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); - CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice)); + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - UNUSED(buffer); + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice)); } static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -9436,9 +9441,12 @@ static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, co GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); - CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost)); + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - UNUSED(buffer); + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaDeviceSynchronize()); + + CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost)); } static struct ggml_backend_buffer_i cuda_backend_buffer_interface = {