feat: add implementation for parallel model proxy
This will allow executing models in parallel and interacting with
them from the caller thread.
To see how it works with debug output enabled, you can launch the test
this way:

DEBUG=1 pytest -s tests/test_parallel_proxy.py
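
As a rough usage sketch of the new API (based on the test added below; the gpt2 model id and greedy sampling are just example choices), the proxy is meant to be driven from the caller thread roughly like this:

import torch
from transformers import AutoTokenizer
from optimum.tpu.xla_parallel_proxy import ModelProxy, sample_greedy

model_id = "openai-community/gpt2"  # example model, same as in the test
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(["Running something in parallel means"], return_tensors="pt")

# Creating the proxy spawns the model processes on the TPU devices.
proxy = ModelProxy(model_id, sample_greedy)
# prefill forwards its keyword arguments to the model and returns the sampled next tokens.
next_tokens = proxy.prefill(**inputs)
tokens = torch.cat([inputs["input_ids"], next_tokens], dim=-1)
# Ask the model loop to exit and join the spawned process.
proxy.leave()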
tengomucho committed Apr 5, 2024
1 parent 1784fd9 commit 2a98430
Showing 2 changed files with 210 additions and 0 deletions.
170 changes: 170 additions & 0 deletions optimum/tpu/xla_parallel_proxy.py
@@ -0,0 +1,170 @@
import torch
import os
from enum import Enum

os.environ["PJRT_DEVICE"] = "TPU"

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch.multiprocessing as mp

from optimum.tpu.modeling import TpuModelForCausalLM
from typing import Dict


DEBUG = False
if os.environ.get("DEBUG", "0") == "1":
    DEBUG = True


def debug(*args):
    if DEBUG:
        print(*args)


class ModelCommand(Enum):
    LEAVE = 0
    PREFILL = 1
    DECODE = 2


class RootMailbox:
    def __init__(self, manager: mp.Manager):
        self.root_bell = manager.Event()
        self.root_command = manager.list()
        self.model_ready = manager.Event()
        self.output_data = manager.Value(torch.Tensor, torch.tensor([]))

    def send(self, command: ModelCommand, data: Dict = None):
        # First wait until model is ready to receive commands
        debug(f" MM Command {command} waiting for model to be ready")
        self.model_ready.wait()
        self.model_ready.clear()

        self.root_command[:] = [command, data]
        self.root_bell.set()
        debug(f" MM Command {command} sent")
        # wait again until model is ready, meaning command has been processed
        self.model_ready.wait()
        ret = self.output_data.get()
        debug(f" MM Command {command} output shape {ret.shape}")
        return ret


class AgentMailbox:
    def __init__(self, root_mailbox: RootMailbox):
        self.root_bell = root_mailbox.root_bell
        self.root_command = root_mailbox.root_command
        self.model_ready = root_mailbox.model_ready
        self.output_data = root_mailbox.output_data

    def receive(self):
        self.root_bell.wait()
        self.root_bell.clear()
        return self.root_command

    def send(self, data: torch.Tensor):
        debug(f" MM Enqueueing data {data.shape}")
        # Data needs to be moved to CPU before setting it
        self.output_data.set(data.cpu())
        debug(" MM Enqueueing data done")

    @property
    def command_data(self):
        command = self.root_command[0]
        data = self.root_command[1]
        return command, data


def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    # create agent mailbox out of root's one
    mailbox = AgentMailbox(root_mailbox)

    debug(
        f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
        + f"world size {world_size}"
    )

    # Model loading and sharding should happen here
    model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    model = model.eval()
    model.to(device)

    def get_next_token(inputs):
        # move inputs to device in a new dict to avoid conflicts
        model_inputs = {}
        for key, value in inputs.items():
            model_inputs[key] = value.to(device)
        outputs = model(**model_inputs, return_dict=False)[0]
        xm.mark_step()
        # consider adding a rendezvous here
        if rank == 0:
            debug(f"Rank {rank} getting tokens")
            next_token = sample_fn(outputs)
            xm.mark_step()
            debug(f"Rank {rank} sending next_tokens {next_token.shape}")
            mailbox.send(next_token)

    while True:
        if rank == 0:
            mailbox.model_ready.set()
            debug(f"Rank {rank} waiting for commands")
            mailbox.receive()
        # Wait for rank 0 to receive command
        xm.rendezvous("start")

        debug(f"Rank {rank} waiting for command at rendezvous")
        command, inputs = mailbox.command_data
        if command == ModelCommand.PREFILL:
            debug(f"Rank {rank} PREFILL")
            get_next_token(inputs)
        elif command == ModelCommand.DECODE:
            debug(f"Rank {rank} DECODE")
            get_next_token(inputs)
        elif command == ModelCommand.LEAVE:
            debug(f"Rank {rank} LEAVE")
            # Set model to ready
            mailbox.model_ready.set()
            break


def model_loop_fn(*args):
"""Spawn processes in the TPUs forwarding arguments"""
xmp.spawn(_mp_fn, args=(args), join=True, daemon=False)


class ModelProxy:
    def __init__(self, model_id: str, sample_fn: callable):
        manager = mp.Manager()
        self.mailbox = RootMailbox(manager)

        self.model_loop = mp.Process(target=model_loop_fn, args=(model_id, self.mailbox, sample_fn))
        self.model_loop.start()

    def prefill(self, **model_args):
        assert self.mailbox is not None, "ModelProxy is not initialized"
        return self.mailbox.send(ModelCommand.PREFILL, model_args)

    def decode(self, **model_args):
        assert self.mailbox is not None, "ModelProxy is not initialized"
        return self.mailbox.send(ModelCommand.DECODE, model_args)

    def leave(self):
        if self.mailbox is None:
            return
        self.mailbox.send(ModelCommand.LEAVE)
        debug("Joining...")
        self.model_loop.join()
        debug("Model loop finished")
        self.mailbox = None

    def __del__(self):
        self.leave()


def sample_greedy(logits):
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id
40 changes: 40 additions & 0 deletions tests/test_parallel_proxy.py
@@ -0,0 +1,40 @@
import os
from optimum.tpu.xla_parallel_proxy import ModelProxy
from transformers import AutoTokenizer
import torch


def sample_greedy(logits):
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id


def test_parallel_proxy_prefill():
    # This test will not actually shard gpt2, but it ensures the model can be loaded in a parallel way and
    # that the proxy can be used to prefill the model.
    # NOTE: if the environment variable DEBUG=1 is set, the test will be much more verbose.
    model_id = "openai-community/gpt2"
    # Disable tokenizers parallelism to avoid deadlocks
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    text = ["Running something in parallel means"]
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    tokens = input_ids.clone()

    model = ModelProxy(model_id, sample_greedy)
    next_tokens = model.prefill(**inputs, position_ids=pos_ids)
    tokens = torch.cat([tokens, next_tokens], dim=-1)

    # Data can be decoded even before leaving
    decoded_texts = tokenizer.batch_decode(tokens)
    print()
    print("------------------------------------------")
    print("Decoded texts:")
    print(decoded_texts[0])
    print("------------------------------------------")
    expected_text = "Running something in parallel means that"
    assert expected_text == decoded_texts[0]
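
Although the test only exercises prefill, the proxy also exposes decode. A hypothetical continuation of the test body (a sketch, not part of this commit, reusing the tensors defined above) could keep generating a few more tokens:

    # Hypothetical continuation: grow the inputs and keep sampling through the DECODE command.
    for _ in range(3):
        attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, :1])], dim=-1)
        pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
        next_tokens = model.decode(input_ids=tokens, attention_mask=attention_mask, position_ids=pos_ids)
        tokens = torch.cat([tokens, next_tokens], dim=-1)
    # Explicitly stop the model loop (otherwise __del__ will do it when the proxy is garbage collected).
    model.leave()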
