-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: reduce debug messages in distributed model * fix(style): correct ruff lint entries in pyproject and fix style * chore: bump version to 0.1.0a0 Versions devX with X > 0 are not supported. Also, it would be better to synchronize optimum tpu and TGI versions. * chore(tgi): refactor generator to separate base class * refactor: split distributed model and mailboxes This will allow the reuse of the mailboxes in other parts of the code. * chore(distributed): remove model config from Distributed model It's not necessary and removing it makes the mailboxes more generic. * refactor(multiprocessing): allow flexible parameters - Allow command_data to contain an arbitrary amount of parameters. This will allow more flexibility in upcoming uses. - remove "model" mention from the mailboxes variables, to allow them to keep being as generic as possible. * feat(generator): add sharding support for models All models are now loaded in a sharded way, with each shard being loaded in a separate process that will run in parallel! 🔀 This allows for better memory management, in particular for larger models that would not fit in one TPU. `TpuGenerator` class is now renamed `TpuGeneratorSingleThread`, and a new `TpuGenerator` class is introduced that will handle the processes communication. * feat: reduce log messages in multiprocess generator Only ordinal with rank 0 will now log messages. * test(tgi): add a test for gemma-7b in TGI Note that this test will only run with --runslow parameter set. * chore(MP): Add documentation and return type hints for mailboxes * review(distributed_model): avoid checking len
- Loading branch information
1 parent
f9a6593
commit 377ae1b
Showing
20 changed files
with
506 additions
and
171 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import torch_xla.core.xla_model as xm | ||
from loguru import logger | ||
|
||
|
||
""" | ||
This is just a shallow wrapper around loguru's logger that only logs messages on the master ordinal, to avoid | ||
repeating the same messages on the threads of all the other ordinals. | ||
""" | ||
|
||
def warning(message: str):
    """Log ``message`` at WARNING level, but only when running on ordinal 0.

    Args:
        message: Text handed to loguru's ``warning`` method unchanged.
    """
    if xm.get_ordinal() != 0:
        # Every ordinal other than 0 stays silent so the message is not repeated.
        return
    # depth=1 asks loguru to attribute the record to this function's caller
    # instead of to this wrapper.
    caller_logger = logger.opt(depth=1)
    caller_logger.warning(message)
|
||
def info(message: str):
    """Log ``message`` at INFO level, but only when running on ordinal 0.

    Args:
        message: Text handed to loguru's ``info`` method unchanged.
    """
    if xm.get_ordinal() != 0:
        # Every ordinal other than 0 stays silent so the message is not repeated.
        return
    # depth=1 asks loguru to attribute the record to this function's caller
    # instead of to this wrapper.
    caller_logger = logger.opt(depth=1)
    caller_logger.info(message)
|
||
def debug(message: str):
    """Log ``message`` at DEBUG level, but only when running on ordinal 0.

    Args:
        message: Text handed to loguru's ``debug`` method unchanged.
    """
    if xm.get_ordinal() != 0:
        # Every ordinal other than 0 stays silent so the message is not repeated.
        return
    # depth=1 asks loguru to attribute the record to this function's caller
    # instead of to this wrapper.
    caller_logger = logger.opt(depth=1)
    caller_logger.debug(message)
|
||
def error(message: str):
    """Log ``message`` at ERROR level, but only when running on ordinal 0.

    Args:
        message: Text handed to loguru's ``error`` method unchanged.
    """
    if xm.get_ordinal() != 0:
        # Every ordinal other than 0 stays silent so the message is not repeated.
        return
    # depth=1 asks loguru to attribute the record to this function's caller
    # instead of to this wrapper.
    caller_logger = logger.opt(depth=1)
    caller_logger.error(message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from multiprocessing.managers import ListProxy | ||
from typing import List | ||
|
||
import torch.multiprocessing as mp | ||
|
||
|
||
class RootMailbox:
    """Root-side end of a simple multiprocessing mailbox shared with the agents.

    All state lives in objects created through the given manager: two events
    used for signalling and two lists used to carry the command and the reply.
    """

    def __init__(self, manager: mp.Manager):
        """Create the shared events and lists through ``manager``.

        Args:
            manager: Object providing ``Event()`` and ``list()`` factories.
        """
        # Set by the root to announce that a new command has been written.
        self.root_bell = manager.Event()
        # Holds the command followed by its arguments.
        self.root_command = manager.list()
        # Waited on by the root, both before sending and while awaiting the reply.
        self.agent_ready = manager.Event()
        # Holds the response handed back to the caller of ``send``.
        self.output_data = manager.list()

    def send(self, command: int, *args) -> ListProxy:
        """Post a command with its arguments and block until it has been handled.

        Args:
            command (int): Command to send to the agents.
            *args: Arguments to send along with the command.

        Returns:
            The shared output list holding the agents' response (the shared
            object itself, not a copy).
        """
        ready = self.agent_ready

        # Do not write anything before the agent signals it can take a command.
        ready.wait()
        ready.clear()

        # Overwrite the shared command list in place, then ring the bell.
        payload = [command]
        payload.extend(args)
        self.root_command[:] = payload
        self.root_bell.set()

        # The agent being ready again means the command has been processed.
        ready.wait()
        return self.output_data
|
||
|
||
class AgentMailbox:
    """Agent-side end of the mailbox, built on the root mailbox's shared objects."""

    def __init__(self, root_mailbox: RootMailbox):
        """Bind to the events and lists owned by ``root_mailbox``.

        Args:
            root_mailbox: Root mailbox whose shared objects are reused as-is.
        """
        # No new objects are created: both ends reference the same four attributes.
        self.root_bell = root_mailbox.root_bell
        self.root_command = root_mailbox.root_command
        self.agent_ready = root_mailbox.agent_ready
        self.output_data = root_mailbox.output_data

    def receive(self) -> ListProxy:
        """Block until the root rings the bell, then hand back the command list.

        Returns:
            The shared list containing the command and its arguments.
        """
        bell = self.root_bell
        bell.wait()
        # Reset the bell so that the next call blocks until a new command arrives.
        bell.clear()
        return self.root_command

    def send(self, *data):
        """Store the response for the root process.

        Only the shared output list is overwritten here; this method does not
        touch ``agent_ready``.

        Args:
            *data: Values that make up the response.
        """
        self.output_data[:] = list(data)

    @property
    def command_data(self) -> tuple[int, List]:
        """Split the current command list into the command and its arguments.

        Returns:
            A tuple ``(command, arguments)`` where ``command`` is the first
            item of the shared command list and ``arguments`` are the rest.
        """
        shared = self.root_command
        return shared[0], shared[1:]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.