diff --git a/ovdino/detrex/modeling/language_backbone/bert.py b/ovdino/detrex/modeling/language_backbone/bert.py index 09af772..4330a01 100644 --- a/ovdino/detrex/modeling/language_backbone/bert.py +++ b/ovdino/detrex/modeling/language_backbone/bert.py @@ -121,6 +121,10 @@ def __init__( for param in self.lang_model.parameters(): param.requires_grad = False + @property + def device(self): + return next(self.parameters()).device + def forward(self, x, *args, **kwargs): """Forward function of text_encoder. Args: @@ -131,8 +135,8 @@ def forward(self, x, *args, **kwargs): """ if self.post_tokenize: tokenized_batch = self.tokenizer(x, return_mask=True) - input_ids = tokenized_batch["input_ids"].cuda() - attention_mask = tokenized_batch["attention_mask"].cuda() + input_ids = tokenized_batch["input_ids"].to(self.device) + attention_mask = tokenized_batch["attention_mask"].to(self.device) else: assert self.context_length == x.shape[-1] input_ids = x.reshape(-1, self.context_length)