Skip to content

Commit

Permalink
[doc] add docstr
Browse files Browse the repository at this point in the history
  • Loading branch information
ver217 committed Feb 21, 2025
1 parent 17e99e8 commit 4a560d5
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion applications/ColossalChat/coati/distributed/inference_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,20 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
pass

def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
    """Generate new tokens given input_ids and attention_mask.

    Abstract stub on the backend base class: concrete backends override this.
    The base implementation only documents the contract and returns ``None``.

    Args:
        input_ids (torch.Tensor): Prompt token ids, shape [B, S].
        attention_mask (torch.Tensor): Prompt attention mask, shape [B, S].
        **kwargs: Backend-specific generation options.

    Returns:
        Dict[str, torch.Tensor]: containing the
            - input_ids (torch.Tensor): shape [B, S+N]
            - attention_mask (torch.Tensor): shape [B, S+N]
            - action_log_probs (torch.Tensor): shape [B, N]
            - action_mask (torch.Tensor): shape [B, N]
            where N is the number of generated tokens.
    """

def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
    """Load updated model weights into the inference backend.

    Abstract stub: the base implementation does nothing; concrete backends
    override it to refresh their weights in place.

    Args:
        state_dict (Dict[str, torch.Tensor]): Mapping from parameter name to
            weight tensor. Presumably produced by the trainer's
            ``state_dict()`` during distributed weight sync — confirm against
            the caller.
    """
    pass
Expand All @@ -49,6 +62,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer

@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
input_ids = input_ids.to(get_current_device())
attention_mask = attention_mask.to(get_current_device())
Expand Down Expand Up @@ -99,6 +113,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
self.tokenizer = tokenizer
self.config = AutoConfig.from_pretrained(path)

@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config)
out_tokens = []
Expand Down Expand Up @@ -152,6 +167,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer

@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
outputs = self.llm.generate(
prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
Expand Down

0 comments on commit 4a560d5

Please sign in to comment.