Skip to content

Commit

Permalink
CU-8697x7y9x: Fix issue with transformers 4.47+ affecting DeID (#517)
Browse files Browse the repository at this point in the history
* CU-8697x7y9x: Fix issue with transformers 4.47+ affecting DeID

* CU-8697x7y9x: Add type-ignore to module unrelated to current change
  • Loading branch information
mart-r committed Feb 12, 2025
1 parent 710842c commit 3fd620b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
16 changes: 15 additions & 1 deletion medcat/ner/transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,22 @@ def create_eval_pipeline(self):
self.ner_pipe.tokenizer._in_target_context_manager = False
if not hasattr(self.ner_pipe.tokenizer, 'split_special_tokens'):
# NOTE: this will fix the DeID model(s) created with transformers before 4.42
# and allow them to run with later transforemrs
# and allow them to run with later transformers
self.ner_pipe.tokenizer.split_special_tokens = False
if not hasattr(self.ner_pipe.tokenizer, 'pad_token') and hasattr(self.ner_pipe.tokenizer, '_pad_token'):
# NOTE: This will fix the DeID model(s) created with transformers before 4.47
            # and allow them to run with later transformers versions
# In 4.47 the special tokens started to be used differently, yet our saved model
# is not aware of that. So we need to explicitly fix that.
special_tokens_map = self.ner_pipe.tokenizer.__dict__.get('_special_tokens_map', {})
for name in self.ner_pipe.tokenizer.SPECIAL_TOKENS_ATTRIBUTES:
                # previously saved in (e.g.) _pad_token
prev_val = getattr(self.ner_pipe.tokenizer, f"_{name}")
# now saved in the special tokens map by its name
special_tokens_map[name] = prev_val
            # The map is stored explicitly in __dict__; the base class's __getattr__ reads it from there later.
self.ner_pipe.tokenizer.__dict__['_special_tokens_map'] = special_tokens_map

self.ner_pipe.device = self.model.device
self._consecutive_identical_failures = 0
self._last_exception: Optional[Tuple[str, Type[Exception]]] = None
Expand Down
2 changes: 1 addition & 1 deletion medcat/utils/relation_extraction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def load_state(model: BertModel_RelationExtraction, optimizer, scheduler, path="

if optimizer is None:
optimizer = torch.optim.Adam(
[{"params": model.module.parameters(), "lr": config.train.lr}])
[{"params": model.module.parameters(), "lr": config.train.lr}]) # type: ignore

if scheduler is None:
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
Expand Down

0 comments on commit 3fd620b

Please sign in to comment.