Skip to content

Commit

Permalink
[xla:gpu] Do not use ncclSend and ncclRecv directly and use NcclApi part #2
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 599037622
  • Loading branch information
ezhulenev authored and copybara-github committed Jan 17, 2024
1 parent f842a3e commit 588171c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 53 deletions.
33 changes: 7 additions & 26 deletions xla/service/gpu/nccl_recv_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,18 @@ limitations under the License.

#include "xla/service/gpu/nccl_recv_thunk.h"

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/nccl_api.h"
#include "xla/stream_executor/stream.h"

#if XLA_ENABLE_XCCL
#include "xla/stream_executor/gpu/gpu_stream.h"
#endif
#include "tsl/platform/errors.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -102,7 +100,6 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
DeviceBufferPair& buffer, se::Stream& stream,
ncclComm_t comm, absl::string_view device_string,
int64_t current_id) {
#if XLA_ENABLE_XCCL
// Determine the source IDs for this instance. The source ID is the ID for
// the peer that will copy its data to this instance. If there is no source,
// just memzero() the destination buffer.
Expand All @@ -116,23 +113,12 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d", device_string,
current_id, source_id.value_or(-1));

TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
ToNcclDataTypeAndCountMultiplier(
buffer.element_type, Thunk::kNcclCollectivePermute));
ncclDataType_t dtype = dtype_and_multiplier.first;
int64_t element_count = buffer.element_count * dtype_and_multiplier.second;

se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

// Receive data from the source peer to the destination buffer.
if (source_id) {
VLOG(3) << absl::StreamFormat(
"%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
"stream=%p)",
device_string, dest_addr.opaque(), element_count, *source_id,
static_cast<const void*>(comm), gpu_stream);
XLA_NCCL_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
*source_id, comm, gpu_stream));
TF_RETURN_IF_ERROR(NcclApi::Recv(
dest_addr, buffer.element_type, buffer.element_count, *source_id,
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));

} else {
// If there is no source peer, i.e. no sender to this instance, zero out
// the destination buffer.
Expand All @@ -141,11 +127,6 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
stream.ThenMemZero(&dest_addr, dest_addr.size());
}
return absl::OkStatus();
#else // XLA_ENABLE_XCCL
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
#endif // XLA_ENABLE_XCCL
}

} // namespace gpu
Expand Down
34 changes: 7 additions & 27 deletions xla/service/gpu/nccl_send_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,18 @@ limitations under the License.

#include "xla/service/gpu/nccl_send_thunk.h"

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/nccl_api.h"
#include "xla/stream_executor/stream.h"

#if XLA_ENABLE_XCCL
#include "xla/stream_executor/gpu/gpu_stream.h"
#endif
#include "tsl/platform/errors.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -102,10 +100,8 @@ absl::Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target,
DeviceBufferPair& buffer, se::Stream& stream,
ncclComm_t comm, absl::string_view device_string,
int64_t current_id) {
#if XLA_ENABLE_XCCL
// Determine the target IDs for this instance. The target ID is the ID
// to which this instance will copy its data.

int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing collective permute from device ordinal: "
<< device_ordinal << "current_id " << current_id;
Expand All @@ -116,30 +112,14 @@ absl::Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target,
VLOG(3) << absl::StreamFormat("%s : id = %d, target_id = %d", device_string,
current_id, target_id.value_or(-1));

TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
ToNcclDataTypeAndCountMultiplier(
buffer.element_type, Thunk::kNcclCollectivePermute));
ncclDataType_t dtype = dtype_and_multiplier.first;
int64_t element_count = buffer.element_count * dtype_and_multiplier.second;

se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

// Send source buffer to target peer if needed.
if (target_id) {
VLOG(3) << absl::StreamFormat(
"%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
"comm=%p, stream=%p)",
device_string, src_addr.opaque(), element_count, *target_id,
static_cast<const void*>(comm), gpu_stream);
XLA_NCCL_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
*target_id, comm, gpu_stream));
TF_RETURN_IF_ERROR(NcclApi::Send(
src_addr, buffer.element_type, buffer.element_count, *target_id,
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
}

return absl::OkStatus();
#else // XLA_ENABLE_XCCL
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
#endif // XLA_ENABLE_XCCL
}

} // namespace gpu
Expand Down

0 comments on commit 588171c

Please sign in to comment.