Commit

Sharding in tgi (#31)
* feat: reduce debug messages in distributed model

* fix(style): correct ruff lint entries in pyproject and fix style

* chore: bump version to 0.1.0a0

Versions devX with X > 0 are not supported. It would also be better to
keep the optimum-tpu and TGI versions in sync.

* chore(tgi): refactor generator to separate base class

* refactor: split distributed model and mailboxes

This will allow the reuse of the mailboxes in other parts of the code.

* chore(distributed): remove model config from Distributed model

It's not necessary and removing it makes the mailboxes more generic.

* refactor(multiprocessing): allow flexible parameters

- Allow command_data to contain an arbitrary number of parameters. This
  will allow more flexibility in upcoming uses.
- Remove the "model" mention from the mailbox variable names, so they
  stay as generic as possible.

* feat(generator): add sharding support for models

All models are now loaded in a sharded way, with each shard loaded in a
separate process that runs in parallel! 🔀

This allows for better memory management, in particular for larger
models that would not fit on a single TPU.

The `TpuGenerator` class is now renamed `TpuGeneratorSingleThread`, and
a new `TpuGenerator` class is introduced to handle the inter-process
communication.

* feat: reduce log messages in multiprocess generator

Only the ordinal with rank 0 now logs messages.

* test(tgi): add a test for gemma-7b in TGI

Note that this test only runs when the --runslow option is set.

* chore(MP): Add documentation and return type hints for mailboxes

* review(distributed_model): avoid checking len
tengomucho authored Apr 30, 2024
1 parent f9a6593 commit 377ae1b
Showing 20 changed files with 506 additions and 171 deletions.
8 changes: 5 additions & 3 deletions examples/text-generation/generation_gemma.py
@@ -1,15 +1,17 @@
#!/usr/bin/python

import torch
import time
import datetime
import os
import platform
import time
from typing import List

import torch
import torch_xla.core.xla_model as xm
from optimum.tpu.modeling import AutoModelForCausalLM
from transformers import AutoTokenizer, StaticCache

from optimum.tpu.modeling import AutoModelForCausalLM


os.environ["PJRT_DEVICE"] = "TPU"

4 changes: 2 additions & 2 deletions optimum/tpu/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .version import __version__, VERSION # noqa: F401
from .modeling import AutoModelForCausalLM # noqa: F401
from .modeling import AutoModelForCausalLM
from .version import VERSION, __version__
86 changes: 13 additions & 73 deletions optimum/tpu/distributed_model.py
@@ -1,18 +1,19 @@
# ruff: noqa: E402
import torch
import os
from enum import Enum
from typing import Dict

from loguru import logger


os.environ["PJRT_DEVICE"] = "TPU"

import torch.multiprocessing as mp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch.multiprocessing as mp

from optimum.tpu.modeling import AutoModelForCausalLM
from transformers import PretrainedConfig

from .xla_mp_comm import AgentMailbox, RootMailbox


class ModelCommand(Enum):
@@ -21,63 +22,6 @@ class ModelCommand(Enum):
DECODE = 2


class RootMailbox:
def __init__(self, manager: mp.Manager):
self.root_bell = manager.Event()
self.root_command = manager.list()
self.model_ready = manager.Event()
self.output_data = manager.Value(torch.Tensor, torch.tensor([]))
self.model_config = manager.Value(PretrainedConfig, None)

@property
def config(self):
while True:
config = self.model_config.get()
if config is not None:
return config

def send(self, command: ModelCommand, data: Dict = None):
# First wait until model is ready to receive commands
logger.debug(f" MM Command {command} waiting for model to be ready")
self.model_ready.wait()
self.model_ready.clear()

self.root_command[:] = [command, data]
self.root_bell.set()
logger.debug(f" MM Command {command} sent")
# wait again until model is ready, meaning command has been processed
self.model_ready.wait()
ret = self.output_data.get()
logger.debug(f" MM Command {command} output shape {ret.shape}")
return ret


class AgentMailbox:
def __init__(self, root_mailbox: RootMailbox):
self.root_bell = root_mailbox.root_bell
self.root_command = root_mailbox.root_command
self.model_ready = root_mailbox.model_ready
self.output_data = root_mailbox.output_data
self.model_config = root_mailbox.model_config

def receive(self):
self.root_bell.wait()
self.root_bell.clear()
return self.root_command

def send(self, data: torch.Tensor):
logger.debug(f" MM Enqueueing data {data.shape}")
# Data needs to be moved to CPU before setting it
self.output_data.set(data.cpu())
logger.debug(" MM Enqueueing data done")

@property
def command_data(self):
command = self.root_command[0]
data = self.root_command[1]
return command, data


def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
device = xm.xla_device()
world_size = xm.xrt_world_size()
@@ -93,8 +37,6 @@ def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.eval()
model.to(device)
if rank == 0:
mailbox.model_config.set(model.config)

def get_next_token(inputs):
# move inputs to device in a new dict to avoid conflicts
@@ -109,18 +51,20 @@ def get_next_token(inputs):
next_token = sample_fn(outputs)
xm.mark_step()
logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}")
mailbox.send(next_token)
# Data needs to be moved to CPU before setting it
mailbox.send(next_token.cpu())

while True:
if rank == 0:
mailbox.model_ready.set()
mailbox.agent_ready.set()
logger.debug(f"Rank {rank} waiting for commands")
mailbox.receive()
# Wait for rank 0 to receive command
xm.rendezvous("start")

logger.debug(f"Rank {rank} waiting for command at rendezvous")
command, inputs = mailbox.command_data
command, data = mailbox.command_data
inputs = data[0] if data else None
if command == ModelCommand.PREFILL:
logger.debug(f"Rank {rank} PREFILL")
get_next_token(inputs)
@@ -130,7 +74,7 @@ def get_next_token(inputs):
elif command == ModelCommand.LEAVE:
logger.debug(f"Rank {rank} LEAVE")
# Set model to ready
mailbox.model_ready.set()
mailbox.agent_ready.set()
break


@@ -149,11 +93,11 @@ def __init__(self, model_id: str, sample_fn: callable):

def prefill(self, **model_args):
assert self.mailbox is not None, "DistributedModel is not initialized"
return self.mailbox.send(ModelCommand.PREFILL, model_args)
return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]

def decode(self, **model_args):
assert self.mailbox is not None, "DistributedModel is not initialized"
return self.mailbox.send(ModelCommand.PREFILL, model_args)
return self.mailbox.send(ModelCommand.PREFILL, model_args)[0]

def leave(self):
if self.mailbox is None:
@@ -164,9 +108,5 @@ def leave(self):
logger.debug("Model loop finished")
self.mailbox = None

@property
def config(self):
return self.mailbox.config

def __del__(self):
self.leave()
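
For orientation, here is a minimal usage sketch of `DistributedModel` as it stands after this commit. The model id and the greedy sampling function are placeholder choices, and the exact keyword arguments expected by `prefill` depend on the model:

```python
import torch

from optimum.tpu.distributed_model import DistributedModel


def sample_greedy(outputs):
    # Placeholder sampling function: pick the most likely next token.
    return torch.argmax(outputs.logits[:, -1, :], dim=-1)


if __name__ == "__main__":
    # Hypothetical model id; each shard is loaded in its own process.
    model = DistributedModel("google/gemma-2b", sample_greedy)

    input_ids = torch.tensor([[2, 651, 6037, 576]])
    attention_mask = torch.ones_like(input_ids)
    # prefill forwards the keyword arguments to every shard and returns the
    # sampled next tokens collected from the agents.
    next_token = model.prefill(input_ids=input_ids, attention_mask=attention_mask)

    # Ask the agent processes to exit.
    model.leave()
```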
3 changes: 2 additions & 1 deletion optimum/tpu/modeling.py
@@ -19,7 +19,8 @@
from typing import Any

from loguru import logger
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM, AutoConfig
from transformers import AutoConfig
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM

from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM

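
The `optimum.tpu` wrapper keeps the familiar `transformers` loading API, as the distributed model code above already shows; a minimal sketch with a hypothetical model id:

```python
import torch_xla.core.xla_model as xm

from optimum.tpu.modeling import AutoModelForCausalLM


# Hypothetical model id; loading mirrors transformers' from_pretrained API.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
model = model.eval()
model.to(xm.xla_device())
```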
8 changes: 4 additions & 4 deletions optimum/tpu/modeling_gemma.py
@@ -16,16 +16,15 @@
""" PyTorch Gemma model."""

import math
import re
import warnings
from typing import List, Optional, Tuple, Union
import re

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import (
@@ -34,6 +33,7 @@
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gemma.configuration_gemma import GemmaConfig
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from transformers.utils import (
add_start_docstrings,
Expand All @@ -44,15 +44,15 @@
replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
from transformers.models.gemma.configuration_gemma import GemmaConfig

from optimum.tpu.xla_model_parallel import (
RowParallelLinear,
ColumnParallelLinear,
RowParallelLinear,
get_model_parallel_rank,
get_model_parallel_world_size,
)


if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
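
For context on the `ColumnParallelLinear` / `RowParallelLinear` imports above: in tensor parallelism the first projection of a block is split by columns and the second by rows, so each shard holds only a slice of the weights and a single all-reduce recombines the partial outputs. A minimal plain-PyTorch sketch of the idea, not the actual `optimum.tpu` classes (whose constructor signatures are not shown in this diff):

```python
import torch


def sharded_mlp(x, w_up, w_down, rank, world_size, all_reduce):
    """Compute relu(x @ w_up) @ w_down using only this rank's weight slices."""
    cols = w_up.shape[1] // world_size
    # Column-parallel: keep a slice of the up-projection's output dimension.
    w_up_shard = w_up[:, rank * cols:(rank + 1) * cols]
    # Row-parallel: keep the matching slice of the down-projection's input dimension.
    w_down_shard = w_down[rank * cols:(rank + 1) * cols, :]

    partial = torch.relu(x @ w_up_shard) @ w_down_shard
    # Summing the partial results across ranks (e.g. xm.all_reduce on TPU)
    # reproduces the unsharded computation.
    return all_reduce(partial)
```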
3 changes: 2 additions & 1 deletion optimum/tpu/version.py
@@ -14,5 +14,6 @@

from pkg_resources import parse_version

__version__ = "0.1.0.dev2"

__version__ = "0.1.0a0"
VERSION = parse_version(__version__)
24 changes: 24 additions & 0 deletions optimum/tpu/xla_logger.py
@@ -0,0 +1,24 @@
import torch_xla.core.xla_model as xm
from loguru import logger


"""
This is just a shallow wrapper to loguru's logger, to only log messages on the master ordinal and avoid repeating
messages on all the other ordinals threads.
"""

def warning(message: str):
if xm.get_ordinal() == 0:
logger.opt(depth=1).warning(message)

def info(message: str):
if xm.get_ordinal() == 0:
logger.opt(depth=1).info(message)

def debug(message: str):
if xm.get_ordinal() == 0:
logger.opt(depth=1).debug(message)

def error(message: str):
if xm.get_ordinal() == 0:
logger.opt(depth=1).error(message)
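
A short usage sketch of the new helper — every TPU process can call it, but only the master ordinal actually emits the line (the message text is illustrative):

```python
from optimum.tpu import xla_logger


# Safe to call from every spawned shard; only ordinal 0 writes the message.
xla_logger.info("Loading model shards...")
xla_logger.debug("This line is also emitted only once.")
```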
3 changes: 2 additions & 1 deletion optimum/tpu/xla_model_parallel.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from copy import deepcopy
from dataclasses import dataclass
import os
from typing import Callable, List, Optional, Tuple

import torch
@@ -26,6 +26,7 @@
import torch.nn.init as init
from torch.nn.parameter import Parameter


EPS = torch.finfo(torch.float32).eps

USE_CUDA = os.environ.get("USE_CUDA", False)
72 changes: 72 additions & 0 deletions optimum/tpu/xla_mp_comm.py
@@ -0,0 +1,72 @@
from multiprocessing.managers import ListProxy
from typing import List

import torch.multiprocessing as mp


class RootMailbox:
"""A simple multiprocessing mailbox to communicate between the root process and the agents."""
def __init__(self, manager: mp.Manager):
self.root_bell = manager.Event()
self.root_command = manager.list()
self.agent_ready = manager.Event()
self.output_data = manager.list()

def send(self, command: int, *args) -> ListProxy:
"""Send a command and arguments to the agents and wait for the response.
Args:
command (int): Command to send to the agents.
*args: Arguments to send to the agents.
Returns:
A list containing the response from the agents.
"""
# First wait until agent is ready to receive commands
self.agent_ready.wait()
self.agent_ready.clear()

self.root_command[:] = [command, *args]
self.root_bell.set()
# wait again until agent is ready, meaning command has been processed
self.agent_ready.wait()
ret = self.output_data
return ret


class AgentMailbox:
"""The agent mailbox to communicate with the root process."""
def __init__(self, root_mailbox: RootMailbox):
self.root_bell = root_mailbox.root_bell
self.root_command = root_mailbox.root_command
self.agent_ready = root_mailbox.agent_ready
self.output_data = root_mailbox.output_data

def receive(self) -> ListProxy:
"""Wait for a command from the root process and return it.
Returns:
A list containing the command and arguments from the root process.
"""
self.root_bell.wait()
self.root_bell.clear()
return self.root_command

def send(self, *data):
"""Send the response to the root process.
Args:
*data: Data to send to the root process.
"""
self.output_data[:] = [*data]

@property
def command_data(self) -> tuple[int, List]:
"""Property helper to split command and arguments sent by the root process.
Returns:
A tuple containing the command and arguments.
"""
command = self.root_command[0]
data = self.root_command[1:]
return command, data
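
To make the flexible `command_data` refactor concrete, here is a self-contained round-trip sketch driven by a plain process instead of torch_xla's `xmp.spawn`; the command value and payload are arbitrary:

```python
import torch.multiprocessing as mp

from optimum.tpu.xla_mp_comm import AgentMailbox, RootMailbox


def agent_loop(root_mailbox):
    mailbox = AgentMailbox(root_mailbox)
    mailbox.agent_ready.set()      # tell the root we can accept a command
    mailbox.receive()              # block until the root rings the bell
    command, data = mailbox.command_data
    # Echo the command back along with a value derived from the arguments.
    mailbox.send(command, sum(data))
    mailbox.agent_ready.set()      # command processed, unblock the root


if __name__ == "__main__":
    manager = mp.Manager()
    root = RootMailbox(manager)
    agent = mp.Process(target=agent_loop, args=(root,))
    agent.start()
    # Any number of extra arguments can follow the command.
    result = root.send(1, 10, 20, 30)
    print(list(result))  # [1, 60]
    agent.join()
```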
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -70,17 +70,17 @@ line-length = 119
target-version = ['py38']
extend-exclude = '.ipynb'

[lint.ruff]
[tool.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
lint.ignore = ["C901", "E501", "E741", "W605"]
lint.select = ["C", "E", "F", "I", "W"]
line-length = 119

# Ignore import violations in all `__init__.py` files.
[lint.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[lint.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["optimum.tpu"]

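
Because the removed and added lines are interleaved above, here is the resulting ruff configuration after this commit, reconstructed from the diff:

```toml
[tool.ruff]
# Never enforce `E501` (line length violations).
lint.ignore = ["C901", "E501", "E741", "W605"]
lint.select = ["C", "E", "F", "I", "W"]
line-length = 119

# Ignore import violations in all `__init__.py` files.
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["optimum.tpu"]
```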
3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -1,5 +1,6 @@
import pytest


# See https://stackoverflow.com/a/61193490/217945 for run_slow
def pytest_addoption(parser):
parser.addoption(
@@ -18,4 +19,4 @@ def pytest_collection_modifyitems(config, items):
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
item.add_marker(skip_slow)
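
For reference, a test gated behind --runslow only needs the slow marker; a hypothetical minimal sketch (the actual gemma-7b TGI test added by this commit is not shown in this view):

```python
import pytest


@pytest.mark.slow
def test_gemma_7b_generation():
    # Skipped unless pytest is invoked with --runslow (see conftest above).
    ...
```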