Skip to content

Commit

Permalink
[chore] enable notify when submitting tensor write task
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw committed Nov 8, 2024
1 parent 5da46cd commit 3f1db3a
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions csrc/pthread_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,31 @@ void PthreadAsyncIO::sync_h2d() {

void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
auto stream = c10::cuda::getCurrentCUDAStream();
if (!t.is_cuda()) {
this->h2d_in_progress.fetch_sub(1); // already moved to cpu
if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step()
std::lock_guard<std::mutex> lock(this->mtx);
cv.notify_one();
}
}
auto fut = this->pool.submit_task(
[this, fd, t, offset, pinned, stream] {
at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
torch::Tensor cpu_tensor;
if (t.is_cuda()) {
at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
if (pinned.has_value()) {
pinned.value().copy_(t, /*non_blocking*/ false);
cpu_tensor = pinned.value();
} else {
cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
}
}
this->h2d_in_progress.fetch_sub(1);
if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step()
std::lock_guard<std::mutex> lock(this->mtx);
cv.notify_one();
this->h2d_in_progress.fetch_sub(1);
if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step()
std::lock_guard<std::mutex> lock(this->mtx);
cv.notify_one();
}
} else {
cpu_tensor = t;
}
void *buf = cpu_tensor.data_ptr();
size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
Expand Down

0 comments on commit 3f1db3a

Please sign in to comment.