diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 29642f29e1ba..95b7d1e80308 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -36,7 +36,7 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar - attention_mask (torch.Tensor): shape [B, S+N] - action_log_probs (torch.Tensor): shape [B, N] - action_mask (torch.Tensor): shape [B, N] - where N is the number of generated tokens. + where N is the number of generated tokens. All tensors must reside on a CUDA device. """ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: