diff --git a/src/torch_ucc.cpp b/src/torch_ucc.cpp index 941fb06..55ae70d 100644 --- a/src/torch_ucc.cpp +++ b/src/torch_ucc.cpp @@ -802,6 +802,11 @@ c10::intrusive_ptr ProcessGroupUCC::createProcessGroupUCC( void ProcessGroupUCC::initComm(c10::Device dev) { if (!comm) { +#ifdef USE_CUDA + if (dev.is_cuda()) { + c10::cuda::set_device(dev.index()); + } +#endif comm = CommPG::get_comm(comm_id, dev, &oob); comm->ucx_connect_eps(eps, &oob); comm->ucc_create_team(team, &oob);