From 3f1db3a1298bd4ab04e4861e7d970387068a889c Mon Sep 17 00:00:00 2001 From: botbw Date: Fri, 8 Nov 2024 05:14:07 +0000 Subject: [PATCH] [chore] enable notify when submitting tensor write task --- csrc/pthread_backend.cpp | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/csrc/pthread_backend.cpp b/csrc/pthread_backend.cpp index ede6da5..fc43425 100644 --- a/csrc/pthread_backend.cpp +++ b/csrc/pthread_backend.cpp @@ -89,22 +89,31 @@ void PthreadAsyncIO::sync_h2d() { void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned) { auto stream = c10::cuda::getCurrentCUDAStream(); + if (!t.is_cuda()) { + this->h2d_in_progress.fetch_sub(1); // already moved to cpu + if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step() + std::lock_guard lock(this->mtx); + cv.notify_one(); + } + } auto fut = this->pool.submit_task( [this, fd, t, offset, pinned, stream] { - at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html torch::Tensor cpu_tensor; if (t.is_cuda()) { + at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html if (pinned.has_value()) { pinned.value().copy_(t, /*non_blocking*/ false); cpu_tensor = pinned.value(); } else { cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu() } - } - this->h2d_in_progress.fetch_sub(1); - if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step() - std::lock_guard lock(this->mtx); - cv.notify_one(); + this->h2d_in_progress.fetch_sub(1); + if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step() + std::lock_guard lock(this->mtx); + cv.notify_one(); + } + } else { + cpu_tensor = t; } void *buf = cpu_tensor.data_ptr(); size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();