From 312c75ac555b7379bce3e3698f478142613622ea Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 4 Jun 2024 06:03:58 -0700 Subject: [PATCH] PR #13265: [GPU][NFC] Outline multihost HLO runner creation logic to enable better tests. Imported from GitHub PR https://github.com/openxla/xla/pull/13265 Copybara import of the project: -- d687c6f117ba11d98653288fb0c1cb8cebd61ed6 by Ilia Sergachev : Give GPU compiler class access to PJRT key-value store. -- 04cb8447c90d46f171e5a0f0261e81f8690d1ab4 by Ilia Sergachev : [GPU][NFC] Outline multihost HLO runner creation logic to enable better tests. Merging this change closes #13265 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/13265 from openxla:refactor_runner 04cb8447c90d46f171e5a0f0261e81f8690d1ab4 PiperOrigin-RevId: 640126763 --- xla/tools/multihost_hlo_runner/BUILD | 5 ++ .../functional_hlo_runner.cc | 50 +++++++++++++++++ .../functional_hlo_runner.h | 8 ++- .../multihost_hlo_runner/hlo_runner_main.cc | 54 ++----------------- 4 files changed, 67 insertions(+), 50 deletions(-) diff --git a/xla/tools/multihost_hlo_runner/BUILD b/xla/tools/multihost_hlo_runner/BUILD index f610579078eff..f52625c1d82c8 100644 --- a/xla/tools/multihost_hlo_runner/BUILD +++ b/xla/tools/multihost_hlo_runner/BUILD @@ -40,6 +40,7 @@ xla_cc_binary( "//xla/pjrt:pjrt_client", "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:service", "//xla/service:cpu_plugin", "//xla/tsl/util:command_line_flags", @@ -76,6 +77,7 @@ xla_cc_binary( "//xla/pjrt:pjrt_client", "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/pjrt/distributed:service", "//xla/service:cpu_plugin", "//xla/service:gpu_plugin", @@ -115,7 +117,10 @@ cc_library( "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt/cpu:cpu_client", + "//xla/pjrt/distributed", "//xla/pjrt/distributed:client", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/pjrt/distributed:service", "//xla/pjrt/gpu:se_gpu_pjrt_client", "//xla/service:computation_layout", "//xla/service:computation_placer_hdr", diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index 8402e8d92752a..9ab4e40080f02 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -47,6 +47,9 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" +#include "xla/pjrt/distributed/service.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/pjrt_client.h" @@ -72,6 +75,53 @@ limitations under the License. namespace xla { +absl::StatusOr> GetPjRtClient( + absl::string_view device_type, absl::string_view address, int node_id, + int num_nodes, bool enable_mock_nccl, + std::unique_ptr& service, + std::shared_ptr& kv_store) { + if (device_type == "host") { + CHECK_EQ(num_nodes, 1); + return xla::FunctionalHloRunner::CreateHostClient(); + } + + if (device_type != "gpu") { + return absl::UnimplementedError(device_type); + } + + if (enable_mock_nccl) { + CHECK_GT(num_nodes, 1); + return xla::FunctionalHloRunner::CreateMockGpuClient(num_nodes); + } else { + if (num_nodes == 1) { + return xla::FunctionalHloRunner::CreateGpuClient(); + } else { + CHECK_GT(address.length(), 0); + // Multinode. Start service on task 0. + if (node_id == 0) { + std::string coordinator_bind_address = + "[::]:" + std::string(address).substr(address.rfind(':') + 1); + xla::CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + auto status_or = xla::GetDistributedRuntimeService( + coordinator_bind_address, options); + TF_QCHECK_OK(status_or.status()); + service = std::move(status_or.value()); + } + xla::DistributedRuntimeClient::Options options; + options.node_id = node_id; + options.init_timeout = absl::Seconds(300); + auto distributed_client = + GetDistributedRuntimeClient(std::string(address), options); + TF_QCHECK_OK(distributed_client->Connect()); + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"gpu:"); + return xla::FunctionalHloRunner::CreateGpuClient(distributed_client, + node_id, num_nodes); + } + } +} + namespace { // Creates an HloModule from the given proto. absl::StatusOr> HloTextToModule( diff --git a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index 5711ba1c60398..d651c2bc52c12 100644 --- a/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -28,7 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" -#include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/statusor.h" @@ -38,6 +38,12 @@ limitations under the License. namespace xla { +absl::StatusOr> GetPjRtClient( + absl::string_view device_type, absl::string_view address, int node_id, + int num_nodes, bool enable_mock_nccl, + std::unique_ptr& service, + std::shared_ptr& kv_store); + // Supported input formats for the input HLO module. enum class InputFormat { kText, // Text format returned by HloModule::ToString(). diff --git a/xla/tools/multihost_hlo_runner/hlo_runner_main.cc b/xla/tools/multihost_hlo_runner/hlo_runner_main.cc index dc512badf5067..249dfab4c975d 100644 --- a/xla/tools/multihost_hlo_runner/hlo_runner_main.cc +++ b/xla/tools/multihost_hlo_runner/hlo_runner_main.cc @@ -18,20 +18,17 @@ limitations under the License. #include #include #include -#include #include #include #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "absl/time/time.h" #include "xla/debug_options_flags.h" #include "xla/pjrt/distributed/client.h" -#include "xla/pjrt/distributed/distributed.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/distributed/service.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/statusor.h" #include "xla/tools/multihost_hlo_runner/functional_hlo_runner.h" #include "xla/tools/multihost_hlo_runner/hlo_runner_flags.h" #include "xla/tsl/util/command_line_flags.h" @@ -71,49 +68,6 @@ Mock GPU usage: Tip: If the input generation takes too long or uses too much host memory, consider using --hlo_argument_mode=uninitialized. )"; - -absl::StatusOr> GetClient( - const std::string& device_type_str, bool enable_mock_nccl, int num_nodes, - const std::string& address_str, int task_id, - std::unique_ptr* service) { - if (device_type_str == "host") { - CHECK_EQ(num_nodes, 1); - return xla::FunctionalHloRunner::CreateHostClient(); - } - - CHECK_EQ(device_type_str, "gpu"); - - if (enable_mock_nccl) { - CHECK_GT(num_nodes, 1); - return xla::FunctionalHloRunner::CreateMockGpuClient(num_nodes); - } else { - if (num_nodes == 1) { - return xla::FunctionalHloRunner::CreateGpuClient(); - } else { - CHECK_GT(address_str.length(), 0); - // Multinode. Start service on task 0. - if (task_id == 0) { - std::string coordinator_bind_address = - "[::]:" + address_str.substr(address_str.rfind(":") + 1); - xla::CoordinationServiceImpl::Options options; - options.num_nodes = num_nodes; - auto status_or = xla::GetDistributedRuntimeService( - coordinator_bind_address, options); - TF_QCHECK_OK(status_or.status()); - *service = std::move(status_or.value()); - } - xla::DistributedRuntimeClient::Options options; - options.node_id = task_id; - options.init_timeout = absl::Seconds(300); - auto distributed_client = - xla::GetDistributedRuntimeClient(address_str, options); - TF_QCHECK_OK(distributed_client->Connect()); - return xla::FunctionalHloRunner::CreateGpuClient(distributed_client, - task_id, num_nodes); - } - } -} - } // namespace int main(int argc, char** argv) { @@ -187,9 +141,11 @@ int main(int argc, char** argv) { << "Can only dump output literal when single input file is specified"; std::unique_ptr service; + std::shared_ptr kv_store; + absl::StatusOr> client = - GetClient(device_type_str, enable_mock_nccl, num_nodes, address_str, - task_id, &service); + xla::GetPjRtClient(device_type_str, address_str, task_id, num_nodes, + enable_mock_nccl, service, kv_store); TF_QCHECK_OK(client.status()); for (int c = 1; c < argc; c++) {