diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py
index dcf305245fbb..29642f29e1ba 100644
--- a/applications/ColossalChat/coati/distributed/inference_backend.py
+++ b/applications/ColossalChat/coati/distributed/inference_backend.py
@@ -24,7 +24,20 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
         pass

     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
-        pass
+        """Generate new tokens given input_ids and attention_mask.
+
+        Args:
+            input_ids (torch.Tensor): shape [B, S]
+            attention_mask (torch.Tensor): shape [B, S]
+
+        Returns:
+            Dict[str, torch.Tensor]: containing the
+                - input_ids (torch.Tensor): shape [B, S+N]
+                - attention_mask (torch.Tensor): shape [B, S+N]
+                - action_log_probs (torch.Tensor): shape [B, N]
+                - action_mask (torch.Tensor): shape [B, N]
+            where N is the number of generated tokens.
+        """

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
         pass
@@ -49,6 +62,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer

+    @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
@@ -99,6 +113,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
         self.tokenizer = tokenizer
         self.config = AutoConfig.from_pretrained(path)

+    @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         outputs = self.llm.generate(input_ids=input_ids.tolist(), sampling_params=self.generate_config)
         out_tokens = []
@@ -152,6 +167,7 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
         self.generate_config = SamplingParams(**generate_config)
         self.tokenizer = tokenizer

+    @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         outputs = self.llm.generate(
             prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
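
The new docstring pins down the contract every backend's `generate` must satisfy: the returned dict carries the full sequences (`input_ids`, `attention_mask` of shape [B, S+N]) plus per-generated-token tensors (`action_log_probs`, `action_mask` of shape [B, N]). A minimal sketch of how a downstream consumer might use that contract, assuming only the shapes named in the docstring; the `sum_action_log_probs` helper and the toy tensors below are hypothetical, not part of the patch:

```python
import torch

def sum_action_log_probs(out: dict) -> torch.Tensor:
    """Reduce per-token log-probs to one scalar per sequence.

    Expects the dict shape documented in the diff:
      - action_log_probs: [B, N] log-probs of generated tokens
      - action_mask:      [B, N] 1 for real tokens, 0 for padding
    """
    log_probs = out["action_log_probs"]
    mask = out["action_mask"].to(log_probs.dtype)
    # Zero out padded positions before summing over the N generated tokens.
    return (log_probs * mask).sum(dim=-1)  # shape [B]

# Toy tensors standing in for a real `backend.generate(...)` call.
B, S, N = 2, 5, 3
out = {
    "input_ids": torch.randint(0, 100, (B, S + N)),
    "attention_mask": torch.ones(B, S + N, dtype=torch.long),
    "action_log_probs": torch.randn(B, N),
    "action_mask": torch.tensor([[1, 1, 1], [1, 1, 0]]),  # second sample ends one token early
}
print(sum_action_log_probs(out))  # tensor of shape [B]
```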
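
The other change is decorating each backend's `generate` with `@torch.no_grad()`, so rollout generation never records autograd graphs. A tiny standalone sketch of the effect (the `Linear` module here is just an illustration, not code from the patch):

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(1, 4)

y = lin(x)
print(y.requires_grad)  # True: autograd tracks the op and keeps the graph alive

with torch.no_grad():
    y = lin(x)
print(y.requires_grad)  # False: no graph is recorded, cutting memory use during rollout
```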