diff --git a/CHANGELOG.md b/CHANGELOG.md index d8670d9..5d09e64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,3 +38,12 @@ about the new command, please refer to the [README](README.md). - 🌟 Abstract classes for new models/dataloaders. - 🌟 Allows Federated Learning with Personalization. - Personalization allows you to leverage each client local data to obtain models that are better adjusted to their own data distribution. You can run the `cv` task in order to try out this feature. + + +## [1.0.1] - 2023-07-29 + +🔋 This release removes the restriction on the minimum number of GPUs available in FLUTE, +allowing users to run experiments using a single-GPU worker by instantiating both the Server +and clients on the same device. For more documentation about how to run an experiment +using a single GPU, please refer to the [README](README.md). + diff --git a/README.md b/README.md index 941be13..0e0967c 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Welcome to FLUTE (Federated Learning Utilities for Testing and Experimentation), FLUTE is a pytorch-based orchestration environment enabling GPU or CPU-based FL simulations. The primary goal of FLUTE is to enable researchers to rapidly prototype and validate their ideas. Features include: - large scale simulation (millions of clients, sampling tens of thousands per round) -- multi-GPU and multi-node orchestration +- single/multi-GPU and multi-node orchestration - local or global differential privacy - model quantization - a variety of standard optimizers and aggregation methods @@ -74,11 +74,19 @@ FLUTE uses torch.distributed API as its main communication backbone, supporting After this initial setup, you can use the data created for the integration test for a first local run. Note that this data needs to be downloaded manually inside the `testing` folder; for more instructions please look at [the README file inside `testing`](testing/README.md). +For single-GPU runs: + +``` +python -m torch.distributed.run --nproc_per_node=1 e2e_trainer.py -dataPath ./testing -outputPath scratch -config testing/hello_world_nlg_gru.yaml -task nlg_gru -backend nccl +``` + +For multi-GPU runs (3 GPUs): + ``` python -m torch.distributed.run --nproc_per_node=3 e2e_trainer.py -dataPath ./testing -outputPath scratch -config testing/hello_world_nlg_gru.yaml -task nlg_gru -backend nccl ``` -This config uses 1 node with 3 workers (1 server, 2 clients). The config file `testing/hello_world_nlg_gru.yaml` has some comments explaining the major sections and some important details; essentially, it consists in a very short experiment where a couple of iterations are done for just a few clients. A `scratch` folder will be created containing detailed logs. +The config file `testing/hello_world_nlg_gru.yaml` has some comments explaining the major sections and some important details; essentially, it consists of a very short experiment where a couple of iterations are run for just a few clients. A `scratch` folder will be created containing detailed logs. 
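The two launch commands above differ only in `--nproc_per_node`. With a value of 1, `torch.distributed` reports a world size of 1, so the Server and the single Worker end up on the same device, which is the condition this patch checks through `federated.size() == 1`. As a rough, illustrative sketch only (the helper name below is hypothetical and not part of FLUTE), the same check can be expressed directly against `torch.distributed`:

```python
# Illustrative sketch, not FLUTE code: detect the single-GPU launch once the
# process group has been initialized by torch.distributed.run.
import torch.distributed as dist

def is_single_gpu_run() -> bool:
    # --nproc_per_node=1 starts exactly one process, so the world size is 1 and
    # the Server and the lone Worker share the same device.
    return dist.get_world_size() == 1
```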
## Documentation diff --git a/core/evaluation.py b/core/evaluation.py index 8462f7e..bb4bf9b 100644 --- a/core/evaluation.py +++ b/core/evaluation.py @@ -20,7 +20,7 @@ class Evaluation(): - def __init__(self, config, model_path, process_testvalidate, idx_val_clients, idx_test_clients): + def __init__(self, config, model_path, process_testvalidate, idx_val_clients, idx_test_clients, single_worker): self.config = config self.model_path = model_path @@ -29,6 +29,7 @@ def __init__(self, config, model_path, process_testvalidate, idx_val_clients, id self.idx_val_clients = idx_val_clients self.idx_test_clients = idx_test_clients self.send_dicts = config['server_config'].get('send_dicts', False) + self.single_worker = single_worker super().__init__() def run(self, eval_list, req, metric_logger=None): @@ -155,7 +156,7 @@ def run_distributed_evaluation(self, mode, clients, model): total = 0 self.logits = {'predictions': [], 'probabilities': [], 'labels': []} server_data = (0.0, model, 0) - for result in self.process_testvalidate(clients, server_data, mode): + for result in self.process_testvalidate(clients, server_data, mode, self.single_worker): output, metrics, count = result val_metrics = {key: {'value':0, 'higher_is_better': False} for key in metrics.keys()} if total == 0 else val_metrics @@ -190,7 +191,7 @@ def make_eval_clients(dataset, config): ''' total = sum(dataset.num_samples) - clients = federated.size() - 1 + clients = federated.size() - 1 if federated.size()>1 else federated.size() delta = total / clients + 1 threshold = delta current_users_idxs = list() diff --git a/core/federated.py b/core/federated.py index 03f6d26..d381aee 100644 --- a/core/federated.py +++ b/core/federated.py @@ -4,6 +4,7 @@ import os import cProfile import logging +import threading import torch import torch.distributed as dist @@ -21,6 +22,7 @@ COMMAND_TERMINATE = 10 COMMAND_TESTVAL = 11 COMMAND_SYNC_NODES = 9 +GLOBAL_MESSAGE = None def encode_string(word, string_to_int = True): """ Encodes/Decodes the dictionary keys into an array of integers to be sent @@ -254,7 +256,7 @@ class Server: is actually stored inside of the object. """ @staticmethod - def dispatch_clients(clients, server_data, command, mode=None, do_profiling=False): + def dispatch_clients(clients, server_data, command, mode=None, do_profiling=False, single_worker=None): """Perform the orchestration between Clients and Workers. This function does the following: @@ -285,6 +287,9 @@ def dispatch_clients(clients, server_data, command, mode=None, do_profiling=Fals Returns: Generator of results. 
""" + # Single GPU flag + single_gpu = True if size()==1 else False + print_rank(f"Single GPU flag Server: {single_gpu}", loglevel=logging.DEBUG) # Some cleanup torch.cuda.empty_cache() @@ -298,60 +303,81 @@ def dispatch_clients(clients, server_data, command, mode=None, do_profiling=Fals # Update lr + model parameters each round for all workers lr, model_params, nround = server_data - for worker_rank in range(1, size()): - _send(COMMAND_UPDATE, worker_rank) - _send(lr,worker_rank) - _send_gradients(model_params, worker_rank) - _send(float(nround),worker_rank) - print_rank(f"Finished sending lr {lr} and n_params {len(model_params)} to worker {worker_rank} - round {nround}", loglevel=logging.DEBUG) - - print_rank(f"Finished sending server_data to workers", loglevel=logging.DEBUG) - - client_queue = clients.copy() - print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG) - free_nodes = list(range(1, size())) - results_list, node_request_map = [], [] - - # Initiate computation for all clients - while client_queue: + if not single_gpu: + for worker_rank in range(1, size()): + _send(COMMAND_UPDATE, worker_rank) + _send(lr,worker_rank) + _send_gradients(model_params, worker_rank) + _send(float(nround),worker_rank) + print_rank(f"Finished sending lr {lr} and n_params {len(model_params)} to worker {worker_rank} - round {nround}", loglevel=logging.DEBUG) + print_rank(f"Finished sending server_data to workers", loglevel=logging.DEBUG) + + client_queue = clients.copy() print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG) - assert len(free_nodes) > 0 - node = free_nodes.pop() - index = len(client_queue)-1 - client_to_process = client_queue.pop(index) - print_rank(f"Sending client {index} to worker {node}", loglevel=logging.DEBUG) - _send(command, node) # The command should indicate the worker which function to run on the client + free_nodes = list(range(1, size())) + results_list, node_request_map = [], [] + + # Initiate computation for all clients + while client_queue: + print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG) + assert len(free_nodes) > 0 + node = free_nodes.pop() + index = len(client_queue)-1 + client_to_process = client_queue.pop(index) + print_rank(f"Sending client {index} to worker {node}", loglevel=logging.DEBUG) + _send(command, node) # The command should indicate the worker which function to run on the client + + if command == COMMAND_TESTVAL: + _send(mode,node) # Only for test/val has a value + _send(index, node) # Worker receives the index of the client to pop + elif command == COMMAND_TRAIN: + _send(client_to_process, node) + print_rank(f"Finished assigning worker {node}, free nodes {free_nodes}", loglevel=logging.DEBUG) + + if dist.get_backend() == "nccl": + append_async_requests(node_request_map, node) + idle_nodes = None + else: + idle_nodes = sync_idle_nodes(client_queue, free_nodes) + + # Waits until receive the output from all ranks + if not free_nodes: + print_rank(f"Waiting for a workers, free nodes {free_nodes}, reqs_lst {node_request_map}", loglevel=logging.DEBUG) + while len(free_nodes) == 0: + node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes) + for output in results_list: + yield output + results_list = [] + + # Wait for all workers to finish + while (len(node_request_map)) != 0: + node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes) + + for output in 
results_list: + yield output + results_list = [] + else: + # For a single-GPU execution, there is no P2P communication within the same GPU. Using threads to coordinate. + + global GLOBAL_MESSAGE + GLOBAL_MESSAGE = server_data if command == COMMAND_TESTVAL: - _send(mode,node) # Only for test/val has a value - _send(index, node) # Worker receives the index of the client to pop + t1 = threading.Thread(target=single_worker.trigger_evaluate) + t1.start() + t1.join() + yield GLOBAL_MESSAGE elif command == COMMAND_TRAIN: - _send(client_to_process, node) - print_rank(f"Finished assigning worker {node}, free nodes {free_nodes}", loglevel=logging.DEBUG) - - if dist.get_backend() == "nccl": - append_async_requests(node_request_map, node) - idle_nodes = None - else: - idle_nodes = sync_idle_nodes(client_queue, free_nodes) - - # Waits until receive the output from all ranks - if not free_nodes: - print_rank(f"Waiting for a workers, free nodes {free_nodes}, reqs_lst {node_request_map}", loglevel=logging.DEBUG) - while len(free_nodes) == 0: - node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes) - for output in results_list: - yield output - results_list = [] - - # Wait for all workers to finish - while (len(node_request_map)) != 0: - node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes) - - for output in results_list: - yield output - results_list = [] - + total_clients = clients.copy() + + for client_id in total_clients: + GLOBAL_MESSAGE = lr, model_params, nround, client_id + t1 = threading.Thread(target=single_worker.trigger_train) + t1.start() + t1.join() + result = GLOBAL_MESSAGE + yield result + if do_profiling: profiler.disable() print_profiler(profiler) @@ -361,7 +387,7 @@ def dispatch_clients(clients, server_data, command, mode=None, do_profiling=Fals torch.cuda.synchronize() if torch.cuda.is_available() else None @staticmethod - def process_clients(clients, server_data): + def process_clients(clients, server_data, single_worker): """Ask workers to perform training on Clients. Args: @@ -372,10 +398,10 @@ def process_clients(clients, server_data): Returns: Generator of results. """ - return Server.dispatch_clients(clients, server_data, COMMAND_TRAIN) + return Server.dispatch_clients(clients, server_data, COMMAND_TRAIN, single_worker=single_worker) @staticmethod - def process_testvalidate(clients, server_data, mode): + def process_testvalidate(clients, server_data, mode, single_worker): """Ask workers to perform test/val on Clients. Args: @@ -388,7 +414,7 @@ def process_testvalidate(clients, server_data, mode): """ mode = [-2] if mode == "test" else [2] - return Server.dispatch_clients(clients, server_data, COMMAND_TESTVAL, mode) + return Server.dispatch_clients(clients, server_data, COMMAND_TESTVAL, mode, single_worker=single_worker) @staticmethod def terminate_workers(terminate=True): @@ -438,142 +464,190 @@ def run(self): and performs different actions on the Client assigned depending on the command received. 
""" - - while True: # keeps listening for incoming server calls - - # Initialize tensors -- required by torch.distributed - command, client_idx, mode = 0, 0, 0 # int - lr, nround = torch.zeros(1), torch.zeros(1) # float - - # Read command - command = _recv(command) - print_rank(f"Command received {command} on worker {rank()}", loglevel=logging.DEBUG) - - # Receive server data -- lr, model_params - if command == COMMAND_UPDATE: - print_rank(f"COMMMAND_UPDATE received {rank()}", loglevel=logging.DEBUG) - lr = _recv(lr, 0) - model_params = _recv_gradients(0) - nround = _recv(nround, 0) - server_data = (lr, model_params, int(nround)) - print_rank(f"Received lr: {lr} and n_params: {len(model_params)} - round {nround}", loglevel=logging.DEBUG) - - elif command == COMMAND_TRAIN: - print_rank(f"COMMMAND_TRAIN received {rank()}", loglevel=logging.DEBUG) - - # Init profiler in training worker - profiler = None - if self.do_profiling: - profiler = cProfile.Profile() - profiler.enable() - - # Receive client id from Server - client_idx = _recv(client_idx) - print_rank(f"Cliend idx received from Server: {client_idx}", loglevel=logging.DEBUG) - - # Instantiate client - client_to_process = Client( - [client_idx], - self.config, - self.config['client_config']['type'] == 'optimization') - - # Execute Client.get_data() - client_data = client_to_process.get_client_data() - - # Execute Client.process_round() - output = client_to_process.process_round(client_data, server_data, self.model, self.data_path) - - # Send output back to Server - if dist.get_backend() == "nccl": - # ASYNC mode -- enabled only for nccl backend - ack = to_device(torch.tensor(1)) - dist.isend(tensor=ack, dst=0) - _send_train_output(output) - else: - # SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed - gather_objects = [output for i in range(size())] + # Single GPU flag + single_gpu = True if size()==1 else False + print_rank(f"Single GPU flag Client: {single_gpu}", loglevel=logging.DEBUG) + + if not single_gpu: + while True: # keeps listening for incoming server calls + + # Initialize tensors -- required by torch.distributed + command, client_idx, mode = 0, 0, 0 # int + lr, nround = torch.zeros(1), torch.zeros(1) # float + + # Read command + command = _recv(command) + print_rank(f"Command received {command} on worker {rank()}", loglevel=logging.DEBUG) + + # Receive server data -- lr, model_params + if command == COMMAND_UPDATE: + print_rank(f"COMMMAND_UPDATE received {rank()}", loglevel=logging.DEBUG) + lr = _recv(lr, 0) + model_params = _recv_gradients(0) + nround = _recv(nround, 0) + server_data = (lr, model_params, int(nround)) + print_rank(f"Received lr: {lr} and n_params: {len(model_params)} - round {nround}", loglevel=logging.DEBUG) + + elif command == COMMAND_TRAIN: + print_rank(f"COMMMAND_TRAIN received {rank()}", loglevel=logging.DEBUG) + + # Init profiler in training worker + profiler = None + if self.do_profiling: + profiler = cProfile.Profile() + profiler.enable() + + # Receive client id from Server + client_idx = _recv(client_idx) + print_rank(f"Cliend idx received from Server: {client_idx}", loglevel=logging.DEBUG) + + # Instantiate client + client_to_process = Client( + [client_idx], + self.config, + self.config['client_config']['type'] == 'optimization') + + # Execute Client.get_data() + client_data = client_to_process.get_client_data() + + # Execute Client.process_round() + output = client_to_process.process_round(client_data, server_data, self.model, self.data_path) + + # 
Send output back to Server + if dist.get_backend() == "nccl": + # ASYNC mode -- enabled only for nccl backend + ack = to_device(torch.tensor(1)) + dist.isend(tensor=ack, dst=0) + _send_train_output(output) + else: + # SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed + gather_objects = [output for i in range(size())] + output = [None for _ in gather_objects] + dist.all_gather_object(output, gather_objects[rank()]) + + # Some cleanup + torch.cuda.empty_cache() + torch.cuda.synchronize() if torch.cuda.is_available() else None + + if self.do_profiling: + profiler.disable() + print_profiler(profiler) + + elif command == COMMAND_TESTVAL: + print_rank(f"COMMMAND_TESTVAL received {rank()}", loglevel=logging.DEBUG) + + # Init profiler in validation worker + profiler = None + if self.do_profiling: + profiler = cProfile.Profile() + profiler.enable() + + # Receive mode and client id from Server + mode = _recv(mode) + mode = "test" if mode == -2 else "val" + client_idx = _recv(client_idx) + print_rank(f"Client idx received from Server: {client_idx}, {mode}", loglevel=logging.DEBUG) + + # Get client and dataset + clients = self.val_clients if mode == "val" else self.test_clients + dataset = self.val_dataset if mode == "val" else self.test_dataset + clients_queue = clients.copy() + assert 0 <= client_idx < len(clients_queue) + client_to_process = clients_queue.pop(client_idx) + + # Execute Client.get_data() + client_data = client_to_process.get_client_data(dataset) + + # Execute Client.run_testvalidate() + output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model) + + # Send output back to Server + if dist.get_backend() == "nccl": + # ASYNC mode -- enabled only for nccl backend + _, metrics, num_instances = output + metrics['num']= {'value': float(num_instances), 'higher_is_better': False} + output = metrics + print_rank(f"Worker {rank()} output {output}", loglevel=logging.DEBUG) + ack = to_device(torch.tensor(1)) + dist.isend(tensor=ack, dst=0) + _send_metrics(output) + else: + # SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed + gather_objects = [output for i in range(size())] + output = [None for _ in gather_objects] + dist.all_gather_object(output, gather_objects[rank()]) + print_rank(f"Worker {rank()} sent output back", loglevel=logging.DEBUG) + + # Some cleanup + torch.cuda.empty_cache() + torch.cuda.synchronize() if torch.cuda.is_available() else None + + if self.do_profiling: + profiler.disable() + print_profiler(profiler) + + elif command == COMMAND_TERMINATE: + print_rank(f"COMMMAND_TERMINATE received {rank()}", loglevel=logging.DEBUG) + + # Some cleanup + torch.cuda.empty_cache() + torch.cuda.synchronize() if torch.cuda.is_available() else None + return + + elif command == COMMAND_SYNC_NODES: # Only for sync calls + print_rank(f"COMMMAND_SYNC_NODES received {rank()}", loglevel=logging.DEBUG) + + gather_objects = [None for i in range(size())] output = [None for _ in gather_objects] dist.all_gather_object(output, gather_objects[rank()]) + print_rank(f"Worker IDLE {rank()} sent dummy output back", loglevel=logging.DEBUG) - # Some cleanup - torch.cuda.empty_cache() - torch.cuda.synchronize() if torch.cuda.is_available() else None - - if self.do_profiling: - profiler.disable() - print_profiler(profiler) + # Some cleanup + torch.cuda.empty_cache() + torch.cuda.synchronize() if torch.cuda.is_available() else None + else: + assert False, "unknown command" - elif command == 
COMMAND_TESTVAL: - print_rank(f"COMMMAND_TESTVAL received {rank()}", loglevel=logging.DEBUG) + def trigger_evaluate(self): + global GLOBAL_MESSAGE - # Init profiler in validation worker - profiler = None - if self.do_profiling: - profiler = cProfile.Profile() - profiler.enable() - - # Receive mode and client id from Server - mode = _recv(mode) - mode = "test" if mode == -2 else "val" - client_idx = _recv(client_idx) - print_rank(f"Client idx received from Server: {client_idx}, {mode}", loglevel=logging.DEBUG) - - # Get client and dataset - clients = self.val_clients if mode == "val" else self.test_clients - dataset = self.val_dataset if mode == "val" else self.test_dataset - clients_queue = clients.copy() - assert 0 <= client_idx < len(clients_queue) - client_to_process = clients_queue.pop(client_idx) - - # Execute Client.get_data() - client_data = client_to_process.get_client_data(dataset) - - # Execute Client.run_testvalidate() - output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model) - - # Send output back to Server - if dist.get_backend() == "nccl": - # ASYNC mode -- enabled only for nccl backend - _, metrics, num_instances = output - metrics['num']= {'value': float(num_instances), 'higher_is_better': False} - output = metrics - print_rank(f"Worker {rank()} output {output}", loglevel=logging.DEBUG) - ack = to_device(torch.tensor(1)) - dist.isend(tensor=ack, dst=0) - _send_metrics(output) - else: - # SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed - gather_objects = [output for i in range(size())] - output = [None for _ in gather_objects] - dist.all_gather_object(output, gather_objects[rank()]) - print_rank(f"Worker {rank()} sent output back", loglevel=logging.DEBUG) + lr, model_params, nround = GLOBAL_MESSAGE + server_data = (lr, model_params, int(nround)) + mode = "val" - # Some cleanup - torch.cuda.empty_cache() - torch.cuda.synchronize() if torch.cuda.is_available() else None + # Get client and dataset + clients = self.val_clients if mode == "val" else self.test_clients + dataset = self.val_dataset if mode == "val" else self.test_dataset + clients_queue = clients.copy() + client_to_process = clients_queue.pop() - if self.do_profiling: - profiler.disable() - print_profiler(profiler) + # Execute Client.get_data() + client_data = client_to_process.get_client_data(dataset) - elif command == COMMAND_TERMINATE: - print_rank(f"COMMMAND_TERMINATE received {rank()}", loglevel=logging.DEBUG) + # Execute Client.run_testvalidate() + output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model) + _, metrics, num_instances = output + metrics['num']= {'value': float(num_instances), 'higher_is_better': False} + GLOBAL_MESSAGE = (_, metrics, num_instances) - # Some cleanup - torch.cuda.empty_cache() - torch.cuda.synchronize() if torch.cuda.is_available() else None - return + # Some cleanup + torch.cuda.empty_cache() + torch.cuda.synchronize() if torch.cuda.is_available() else None + + def trigger_train(self): + global GLOBAL_MESSAGE + lr, model_params, nround, client_idx = GLOBAL_MESSAGE + server_data = (lr, model_params, int(nround)) - elif command == COMMAND_SYNC_NODES: # Only for sync calls - print_rank(f"COMMMAND_SYNC_NODES received {rank()}", loglevel=logging.DEBUG) + # Instantiate client + client_to_process = Client([client_idx], self.config, self.config['client_config']['type'] == 'optimization') + + # Execute Client.get_data() + client_data = client_to_process.get_client_data() - 
gather_objects = [None for i in range(size())] - output = [None for _ in gather_objects] - dist.all_gather_object(output, gather_objects[rank()]) - print_rank(f"Worker IDLE {rank()} sent dummy output back", loglevel=logging.DEBUG) + # Execute Client.process_round() + GLOBAL_MESSAGE = client_to_process.process_round(client_data, server_data, self.model, self.data_path) - # Some cleanup - torch.cuda.empty_cache() - torch.cuda.synchronize() if torch.cuda.is_available() else None - else: - assert False, "unknown command" + # Some cleanup + torch.cuda.empty_cache() + torch.cuda.synchronize() if torch.cuda.is_available() else None \ No newline at end of file diff --git a/core/server.py b/core/server.py index a48e9b9..32a65b3 100644 --- a/core/server.py +++ b/core/server.py @@ -46,7 +46,7 @@ class OptimizationServer(federated.Server): def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader, - config, idx_val_clients, idx_test_clients): + config, idx_val_clients, idx_test_clients, single_worker): '''Implement Server's orchestration and aggregation. This is the main Server class, that actually implements orchestration @@ -88,7 +88,7 @@ def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model self.val_freq = server_config['val_freq'] self.req_freq = server_config['rec_freq'] - self.evaluation = Evaluation(config, model_path, self.process_testvalidate, idx_val_clients, idx_test_clients) + self.evaluation = Evaluation(config, model_path, self.process_testvalidate, idx_val_clients, idx_test_clients, single_worker) # TODO: does this need to be adjusted for custom metrics? self.metrics = dict() @@ -106,6 +106,7 @@ def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model self.list_of_train_data = config['client_config']['data_config']['train']['list_of_train_data'] self.data_path = data_path + self.single_worker = single_worker # Get max grad norm from data config if 'train' in server_config['data_config']: @@ -333,7 +334,7 @@ def log_metric(k, v): self.worker_trainer.model.zero_grad() print_rank(f"Clients sampled from server {sampled_idx_clients}", loglevel=logging.DEBUG) - for client_output in self.process_clients(sampled_idx_clients, server_data): + for client_output in self.process_clients(sampled_idx_clients, server_data, self.single_worker): # Process client output client_timestamp = client_output['ts'] client_stats = client_output['cs'] diff --git a/e2e_trainer.py b/e2e_trainer.py index 1df5f04..ef57f79 100644 --- a/e2e_trainer.py +++ b/e2e_trainer.py @@ -86,6 +86,7 @@ def run_worker(model_path, config, task, data_path, local_rank, backend): """ model_config = config["model_config"] server_config = config["server_config"] + client_config = config["client_config"] # Backend initialization WORLD_RANK = federated.rank() @@ -116,6 +117,20 @@ def run_worker(model_path, config, task, data_path, local_rank, backend): # Instantiate the Server object on the first thread if WORLD_RANK == 0: + + single_worker = None + if federated.size() == 1: + # For a single-GPU/CPU execution using NCCL, Server and Worker are instantiated in the same GPU. 
+ single_worker = federated.Worker(model=model, + data_path=data_path, + do_profiling=client_config.get("do_profiling", False), + val_clients=val_clients, + test_clients=test_clients, + val_dataset=val_dataset, + test_dataset=test_dataset, + config=config) + single_worker.run() + try: print_rank('Server data preparation') @@ -152,6 +167,7 @@ def run_worker(model_path, config, task, data_path, local_rank, backend): config=config, idx_val_clients=idx_val_clients, idx_test_clients=idx_test_clients, + single_worker=single_worker, ) log_run_properties(config) @@ -166,7 +182,6 @@ def run_worker(model_path, config, task, data_path, local_rank, backend): else: # Instantiate client-processing Worker on remaining threads print_rank("Worker on node {}: process started".format(WORLD_RANK)) - client_config = config["client_config"] worker = federated.Worker( model=model, data_path=data_path,
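Since a single-GPU run has no peer ranks to exchange point-to-point messages with, the patch coordinates the co-located Server and Worker through the module-level `GLOBAL_MESSAGE` slot and short-lived threads. The following is a minimal, self-contained sketch of that control flow only; `SingleWorker` and its placeholder body are hypothetical stand-ins for `federated.Worker.trigger_train`, and no real FLUTE training happens here.

```python
# Sketch of the single-GPU coordination pattern: the dispatcher writes a message
# into a shared slot, runs the worker's trigger method on a thread, joins it, and
# reads the result back from the same slot.
import threading

GLOBAL_MESSAGE = None  # shared slot, playing the role of the module global in core/federated.py


class SingleWorker:
    def trigger_train(self):
        global GLOBAL_MESSAGE
        lr, model_params, nround, client_id = GLOBAL_MESSAGE
        # A real Worker would build a Client here and run process_round();
        # this placeholder just echoes what it was handed.
        GLOBAL_MESSAGE = {"client": client_id, "round": nround, "lr": lr}


def dispatch_single_gpu(clients, lr, model_params, nround, worker):
    """Yield one result per client, processing them sequentially on one device."""
    global GLOBAL_MESSAGE
    for client_id in clients:
        GLOBAL_MESSAGE = (lr, model_params, nround, client_id)
        t = threading.Thread(target=worker.trigger_train)
        t.start()
        t.join()  # start/join per client keeps execution strictly sequential
        yield GLOBAL_MESSAGE


if __name__ == "__main__":
    for result in dispatch_single_gpu([3, 7], lr=0.1, model_params=[], nround=0, worker=SingleWorker()):
        print(result)
```

Because each thread is joined immediately after it starts, clients are still processed one at a time; the thread mainly keeps the single-GPU path structurally symmetric with the multi-worker dispatch loop rather than adding concurrency.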