Skip to content

Commit

Permalink
PR openxla#13265: [GPU][NFC] Outline multihost HLO runner creation lo…
Browse files Browse the repository at this point in the history
…gic to enable better tests.

Imported from GitHub PR openxla#13265

Copybara import of the project:

--
d687c6f by Ilia Sergachev <[email protected]>:

Give GPU compiler class access to PJRT key-value store.

--
04cb844 by Ilia Sergachev <[email protected]>:

[GPU][NFC] Outline multihost HLO runner creation logic to enable better tests.

Merging this change closes openxla#13265

COPYBARA_INTEGRATE_REVIEW=openxla#13265 from openxla:refactor_runner 04cb844
PiperOrigin-RevId: 640126763
  • Loading branch information
sergachev authored and copybara-github committed Jun 4, 2024
1 parent 9f818aa commit 312c75a
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 50 deletions.
5 changes: 5 additions & 0 deletions xla/tools/multihost_hlo_runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ xla_cc_binary(
"//xla/pjrt:pjrt_client",
"//xla/pjrt/distributed",
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/pjrt/distributed:service",
"//xla/service:cpu_plugin",
"//xla/tsl/util:command_line_flags",
Expand Down Expand Up @@ -76,6 +77,7 @@ xla_cc_binary(
"//xla/pjrt:pjrt_client",
"//xla/pjrt/distributed",
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/pjrt/distributed:service",
"//xla/service:cpu_plugin",
"//xla/service:gpu_plugin",
Expand Down Expand Up @@ -115,7 +117,10 @@ cc_library(
"//xla/pjrt:pjrt_executable",
"//xla/pjrt:pjrt_future",
"//xla/pjrt/cpu:cpu_client",
"//xla/pjrt/distributed",
"//xla/pjrt/distributed:client",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/pjrt/distributed:service",
"//xla/pjrt/gpu:se_gpu_pjrt_client",
"//xla/service:computation_layout",
"//xla/service:computation_placer_hdr",
Expand Down
50 changes: 50 additions & 0 deletions xla/tools/multihost_hlo_runner/functional_hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/distributed/service.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/host_memory_spaces.h"
#include "xla/pjrt/pjrt_client.h"
Expand All @@ -72,6 +75,53 @@ limitations under the License.

namespace xla {

// Creates the PjRt client matching `device_type` and the requested topology.
//
// Supported configurations:
//   * "host": CPU client; `num_nodes` must be 1 (CHECK-enforced).
//   * "gpu" with `enable_mock_nccl`: mock GPU client; `num_nodes` must be
//     greater than 1 (CHECK-enforced).
//   * "gpu" with `num_nodes` == 1: single-node GPU client.
//   * "gpu" with `num_nodes` > 1: multi-node GPU client. `address` must be
//     non-empty (CHECK-enforced). Node 0 additionally starts the distributed
//     runtime service and stores it in `service`. Every node connects a
//     distributed runtime client to `address` and `kv_store` is set to the
//     key-value store built on that client (key prefix "gpu:").
//
// `service` and `kv_store` are output parameters; they are written only on the
// multi-node GPU path described above and are left untouched otherwise.
//
// Returns an Unimplemented error for any other `device_type`. Failures to
// start the distributed runtime service or to connect the distributed client
// are returned as a non-OK status instead of terminating the process, so that
// callers (including tests) can decide how to react.
absl::StatusOr<std::unique_ptr<xla::PjRtClient>> GetPjRtClient(
    absl::string_view device_type, absl::string_view address, int node_id,
    int num_nodes, bool enable_mock_nccl,
    std::unique_ptr<xla::DistributedRuntimeService>& service,
    std::shared_ptr<xla::KeyValueStoreInterface>& kv_store) {
  if (device_type == "host") {
    CHECK_EQ(num_nodes, 1);
    return xla::FunctionalHloRunner::CreateHostClient();
  }

  if (device_type != "gpu") {
    return absl::UnimplementedError(device_type);
  }

  if (enable_mock_nccl) {
    CHECK_GT(num_nodes, 1);
    return xla::FunctionalHloRunner::CreateMockGpuClient(num_nodes);
  }

  if (num_nodes == 1) {
    return xla::FunctionalHloRunner::CreateGpuClient();
  }

  CHECK_GT(address.length(), 0);
  // Multinode. Start service on task 0.
  if (node_id == 0) {
    // Bind on all interfaces, reusing the port that follows the last ':' in
    // `address`. If `address` contains no ':', rfind() yields npos and
    // npos + 1 wraps to 0, so the whole string is used as the port.
    const std::string coordinator_bind_address =
        "[::]:" + std::string(address.substr(address.rfind(':') + 1));
    xla::CoordinationServiceImpl::Options service_options;
    service_options.num_nodes = num_nodes;
    auto status_or = xla::GetDistributedRuntimeService(
        coordinator_bind_address, service_options);
    if (!status_or.ok()) {
      return status_or.status();
    }
    service = std::move(status_or).value();
  }

  xla::DistributedRuntimeClient::Options client_options;
  client_options.node_id = node_id;
  client_options.init_timeout = absl::Seconds(300);
  auto distributed_client =
      GetDistributedRuntimeClient(std::string(address), client_options);
  auto connect_status = distributed_client->Connect();
  if (!connect_status.ok()) {
    return connect_status;
  }
  kv_store = GetDistributedKeyValueStore(distributed_client,
                                         /*key_prefix=*/"gpu:");
  return xla::FunctionalHloRunner::CreateGpuClient(distributed_client, node_id,
                                                   num_nodes);
}

namespace {
// Creates an HloModule from the given proto.
absl::StatusOr<std::unique_ptr<HloModule>> HloTextToModule(
Expand Down
8 changes: 7 additions & 1 deletion xla/tools/multihost_hlo_runner/functional_hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/statusor.h"
Expand All @@ -38,6 +38,12 @@ limitations under the License.

namespace xla {

// Creates a PjRt client for `device_type`.
//
// "host" yields a host client and requires `num_nodes` == 1. "gpu" yields a
// mock GPU client when `enable_mock_nccl` is set (requires `num_nodes` > 1), a
// single-node GPU client when `num_nodes` == 1, and otherwise a multi-node GPU
// client connected to the distributed runtime at `address` (which must be
// non-empty). Any other `device_type` results in an Unimplemented error.
//
// `service` and `kv_store` are output parameters used only on the multi-node
// GPU path: the node with `node_id` == 0 starts the distributed runtime
// service and stores it in `service`, and `kv_store` receives the distributed
// key-value store created from the connected distributed runtime client.
absl::StatusOr<std::unique_ptr<xla::PjRtClient>> GetPjRtClient(
absl::string_view device_type, absl::string_view address, int node_id,
int num_nodes, bool enable_mock_nccl,
std::unique_ptr<xla::DistributedRuntimeService>& service,
std::shared_ptr<xla::KeyValueStoreInterface>& kv_store);

// Supported input formats for the input HLO module.
enum class InputFormat {
kText, // Text format returned by HloModule::ToString().
Expand Down
54 changes: 5 additions & 49 deletions xla/tools/multihost_hlo_runner/hlo_runner_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,17 @@ limitations under the License.
#include <iostream>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/time/time.h"
#include "xla/debug_options_flags.h"
#include "xla/pjrt/distributed/client.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/distributed/service.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/statusor.h"
#include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h"
#include "xla/tools/multihost_hlo_runner/hlo_runner_flags.h"
#include "xla/tsl/util/command_line_flags.h"
Expand Down Expand Up @@ -71,49 +68,6 @@ Mock GPU usage:
Tip: If the input generation takes too long or uses too much host memory,
consider using --hlo_argument_mode=uninitialized.
)";

// Builds the PjRt client selected by the runner's command-line flags.
//
// A "host" device type produces a host client (single node only). Otherwise
// the device type must be "gpu": with mock NCCL enabled a mock GPU client is
// returned, with a single node a plain GPU client, and with several nodes a
// GPU client attached to the distributed runtime at `address_str`. In the
// multi-node case task 0 also starts the distributed runtime service and hands
// it to the caller through `*service`.
absl::StatusOr<std::unique_ptr<xla::PjRtClient>> GetClient(
    const std::string& device_type_str, bool enable_mock_nccl, int num_nodes,
    const std::string& address_str, int task_id,
    std::unique_ptr<xla::DistributedRuntimeService>* service) {
  if (device_type_str == "host") {
    CHECK_EQ(num_nodes, 1);
    return xla::FunctionalHloRunner::CreateHostClient();
  }

  CHECK_EQ(device_type_str, "gpu");

  if (enable_mock_nccl) {
    CHECK_GT(num_nodes, 1);
    return xla::FunctionalHloRunner::CreateMockGpuClient(num_nodes);
  }

  if (num_nodes == 1) {
    return xla::FunctionalHloRunner::CreateGpuClient();
  }

  // From here on: real multi-node GPU execution.
  CHECK_GT(address_str.length(), 0);

  const bool is_coordinator = (task_id == 0);
  if (is_coordinator) {
    // The coordinator listens on every interface, on the port taken from the
    // text after the last ':' of the coordinator address.
    const std::string::size_type port_begin = address_str.rfind(":") + 1;
    std::string bind_address = "[::]:";
    bind_address += address_str.substr(port_begin);

    xla::CoordinationServiceImpl::Options service_options;
    service_options.num_nodes = num_nodes;
    auto status_or =
        xla::GetDistributedRuntimeService(bind_address, service_options);
    TF_QCHECK_OK(status_or.status());
    *service = std::move(status_or.value());
  }

  xla::DistributedRuntimeClient::Options client_options;
  client_options.node_id = task_id;
  client_options.init_timeout = absl::Seconds(300);
  auto distributed_client =
      xla::GetDistributedRuntimeClient(address_str, client_options);
  TF_QCHECK_OK(distributed_client->Connect());
  return xla::FunctionalHloRunner::CreateGpuClient(distributed_client, task_id,
                                                   num_nodes);
}

} // namespace

int main(int argc, char** argv) {
Expand Down Expand Up @@ -187,9 +141,11 @@ int main(int argc, char** argv) {
<< "Can only dump output literal when single input file is specified";

std::unique_ptr<xla::DistributedRuntimeService> service;
std::shared_ptr<xla::KeyValueStoreInterface> kv_store;

absl::StatusOr<std::unique_ptr<xla::PjRtClient>> client =
GetClient(device_type_str, enable_mock_nccl, num_nodes, address_str,
task_id, &service);
xla::GetPjRtClient(device_type_str, address_str, task_id, num_nodes,
enable_mock_nccl, service, kv_store);
TF_QCHECK_OK(client.status());

for (int c = 1; c < argc; c++) {
Expand Down

0 comments on commit 312c75a

Please sign in to comment.