Relation extraction (#173)
* Added files.

* More additions to rel extraction.

* Rel base.

* Update.

* Updates.

* Dependency parsing.

* Updates.

* Added pre-training steps.

* Added training & model utils.

* Cleanup & fixes.

* Update.

* Evaluation updates for pretraining.

* Removed duplicate relation storage.

* Moved RE model file location.

* Structure revisions.

* Added custom config for RE.

* Implemented custom dataset loader for RE.

* More changes.

* Small fix.

* Latest additions to RelCAT (pipe + predictions)

* Setup.py fix.

* RE utils update.

* rel model update.

* rel dataset + tokenizer improvements.

* RelCAT updates.

* RelCAT saving/loading improvements.

* RelCAT saving/loading improvements.

* RelCAT model fixes.

* Attempted gpu learning fix. Dataset label generation fixes.

* Minor train dataset gen fix.

* Minor train dataset gen fix No.2.

* Config updates.

* Gpu support fixes. Added label stats.

* Evaluation stat fixes.

* Cleaned stat output mode during training.

* Build fix.

* removed unused dependencies and fixed code formatting

* Mypy compliance.

* Fixed linting.

* More Gpu mode train fixes.

* Fixed model saving/loading issues when using other base models.

* More fixes to stat evaluation. Added proper CAT integration of RelCAT.

* Setup.py typo fix.

* RelCAT loading fix.

* RelCAT Config changes.

* Type fix. Minor additions to RelCAT model.

* Type fixes.

* Type corrections.

* RelCAT update.

* Type fixes.

* Fixed type issue.

* RelCATConfig: added seed param.

* Adaptations to the new codebase + type fixes.

* Doc/type fixes.

* Fixed input size issue for model.

* Fixed issue(s) with model size and config.

* RelCAT: updated configs to new style.

* RelCAT: removed old refs to logging.

* Fixed GPU training + added extra stat print for train set.

* Type fixes.

* Updated dev requirements.

* Linting.

* Fixed pin_memory issue when training on CPU.

* Updated RelCAT dataset get + default config.

* Updated RelDS generator + default config

* Linting.

* Updated RelDataset + config.

* Pushing updates to model

Made changes to:
1) Extract a given number of context tokens to the left and right of the entities
2) Extract the hidden states from BERT for all tokens of each entity and max-pool over them (sketched just below the change summary)

* Fixing formatting

* Update rel_dataset.py

* Update rel_dataset.py

* Update rel_dataset.py

* RelCAT: added test resource files.

* RelCAT: Fixed model load/checkpointing.

* RelCAT: updated to pipe spacy doc call.

* RelCAT: added tests.

* Fixed lint/type issues & added rel tag to test DS.

* Fixed ann id to token issue.

* RelCAT: updated test dataset + tests.

* RelCAT: updates to requested changes + dataset improvements.

* RelCAT: updated docs/logs according to comments.

* RelCAT: type fix.

* RelCAT: mct export dataset updates.

* RelCAT: test updates + requested changes p2.

* RelCAT: log for MCT export train.

* Updated docs + split train_test & dataset for benchmarks.

* type fixes.

---------

Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com>
Co-authored-by: mart-r <mart.ratas@gmail.com>
3 people authored May 1, 2024
1 parent 1caa187 commit abc97fb
Showing 17 changed files with 6,776 additions and 6 deletions.
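The entity max-pooling described in commit-message item 2 operates on the BERT hidden states of an entity's token span. A minimal sketch, assuming a sequence_output of shape (batch, seq_len, hidden) and a known token span; the names here are illustrative, not the merged API:

import torch

# Max pooling over one entity's token hidden states.
# sequence_output: (batch, seq_len, hidden) BERT hidden states.
# ent_start/ent_end: token index span of the entity (end exclusive).
def pool_entity(sequence_output: torch.Tensor, batch_idx: int,
                ent_start: int, ent_end: int) -> torch.Tensor:
    entity_states = sequence_output[batch_idx, ent_start:ent_end]  # (n_tokens, hidden)
    pooled, _ = torch.max(entity_states, dim=0)                    # (hidden,)
    return pooled

seq_out = torch.randn(1, 10, 8)             # toy batch: 1 sequence, 10 tokens, hidden 8
print(pool_entity(seq_out, 0, 3, 6).shape)  # torch.Size([8])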
3 changes: 3 additions & 0 deletions .gitignore
@@ -18,6 +18,9 @@ venv
db.sqlite3
.ipynb_checkpoints

# vscode
.vscode

#tmp and similar files
.nfs*
*.log
31 changes: 27 additions & 4 deletions medcat/cat.py
@@ -33,6 +33,7 @@
from medcat.linking.context_based_linker import Linker
from medcat.preprocessing.cleaners import prepare_name
from medcat.meta_cat import MetaCAT
from medcat.rel_cat import RelCAT
from medcat.utils.meta_cat.data_utils import json_to_fake_spacy
from medcat.config import Config
from medcat.vocab import Vocab
@@ -64,6 +65,8 @@ class CAT(object):
meta_cats (list of medcat.meta_cat.MetaCAT, optional):
A list of models that will be applied sequentially on each
detected annotation.
rel_cats (list of medcat.rel_cat.RelCAT, optional):
A list of models that will be applied sequentially on all detected annotations.
Attributes (limited):
cdb (medcat.cdb.CDB):
@@ -89,6 +92,7 @@ def __init__(self,
vocab: Union[Vocab, None] = None,
config: Optional[Config] = None,
meta_cats: List[MetaCAT] = [],
rel_cats: List[RelCAT] = [],
addl_ner: Union[TransformersNER, List[TransformersNER]] = []) -> None:
self.cdb = cdb
self.vocab = vocab
@@ -100,6 +104,7 @@ def __init__(self,
self.config = config
self.cdb.config = config
self._meta_cats = meta_cats
self._rel_cats = rel_cats
self._addl_ner = addl_ner if isinstance(addl_ner, list) else [addl_ner]
self._create_pipeline(self.config)

@@ -133,6 +138,9 @@ def _create_pipeline(self, config: Config):
for meta_cat in self._meta_cats:
self.pipe.add_meta_cat(meta_cat, meta_cat.config.general.category_name)

for rel_cat in self._rel_cats:
self.pipe.add_rel_cat(rel_cat, "_".join(list(rel_cat.config.general["labels2idx"].keys())))

# Set max document length
self.pipe.spacy_nlp.max_length = config.preprocessing.max_document_length

@@ -297,6 +305,10 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
name = comp[0]
meta_path = os.path.join(save_dir_path, "meta_" + name)
comp[1].save(meta_path)
if isinstance(comp[1], RelCAT):
name = comp[0]
rel_path = os.path.join(save_dir_path, "rel_" + name)
comp[1].save(rel_path)

# Add a model card also, why not
model_card_path = os.path.join(save_dir_path, "model_card.json")
@@ -341,7 +353,8 @@ def load_model_pack(cls,
meta_cat_config_dict: Optional[Dict] = None,
ner_config_dict: Optional[Dict] = None,
load_meta_models: bool = True,
load_addl_ner: bool = True) -> "CAT":
load_addl_ner: bool = True,
load_rel_models: bool = True) -> "CAT":
"""Load everything within the 'model pack', i.e. the CDB, config, vocab and any MetaCAT models
(if present)
@@ -360,13 +373,16 @@ def load_model_pack(cls,
Whether to load MetaCAT models if present (Default value True).
load_addl_ner (bool):
Whether to load additional NER models if present (Default value True).
load_rel_models (bool):
Whether to load RelCAT models if present (Default value True).
Returns:
CAT: The resulting CAT object.
"""
from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.meta_cat import MetaCAT
from medcat.rel_cat import RelCAT

model_pack_path = cls.attempt_unpack(zip_path)

@@ -409,8 +425,15 @@ def load_model_pack(cls,
meta_cats.append(MetaCAT.load(save_dir_path=meta_path,
config_dict=meta_cat_config_dict))

cat = cls(cdb=cdb, config=cdb.config, vocab=vocab, meta_cats=meta_cats, addl_ner=addl_ner)
# Find Rel models in model_pack
rel_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('rel_')] if load_rel_models else []
rel_cats = []
for rel_path in rel_paths:
rel_cats.append(RelCAT.load(load_path=rel_path))

cat = cls(cdb=cdb, config=cdb.config, vocab=vocab, meta_cats=meta_cats, addl_ner=addl_ner, rel_cats=rel_cats)
logger.info(cat.get_model_card()) # Print the model card

return cat

def __call__(self, text: Optional[str], do_train: bool = False) -> Optional[Doc]:
@@ -1092,8 +1115,8 @@ def get_entities_multi_texts(self,
elif out[i].get('text', '') != text:
out.insert(i, self._doc_to_out(None, only_cui, addl_info)) # type: ignore

cnf_annotation_output = self.config.annotation_output
if not cnf_annotation_output.include_text_in_output:
cnf_annotation_output = getattr(self.config, 'annotation_output', {})
if not (cnf_annotation_output.get('include_text_in_output', False)):
for o in out:
if o is not None:
o.pop('text', None)
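Taken together, the cat.py changes save RelCAT models into 'rel_'-prefixed directories inside the model pack, reload them via load_rel_models, and run them as pipe components. A usage sketch; the model pack name is hypothetical and assumes the pack bundles a RelCAT model:

from medcat.cat import CAT

# Hypothetical model pack containing a rel_* directory; load_rel_models
# defaults to True, so the RelCAT model is attached automatically.
cat = CAT.load_model_pack("medcat_model_pack_with_relcat.zip")

doc = cat("The patient presented to the emergency room due to severe pain.")
# The RelCAT pipe component appends each detected relation to the custom
# Doc extension registered in pipe.add_rel_cat().
for relation in doc._.relations:
    print(relation)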
98 changes: 98 additions & 0 deletions medcat/config_rel_cat.py
@@ -0,0 +1,98 @@
import logging
from typing import Dict, Any, List
from medcat.config import MixingConfig, BaseModel, Optional, Extra


class General(MixingConfig, BaseModel):
"""The General part of the RelCAT config"""
device: str = "cpu"
relation_type_filter_pairs: List = []
"""Map from category values to ID, if empty it will be autocalculated during training"""
vocab_size: Optional[int] = None
lowercase: bool = True
"""If true all input text will be lowercased"""
cntx_left: int = 15
"""Number of tokens to take from the left of the concept"""
cntx_right: int = 15
"""Number of tokens to take from the right of the concept"""
window_size: int = 300
"""Max acceptable dinstance between entities (in characters), care when using this as it can produce sentences that are over 512 tokens (limit is given by tokenizer)"""

mct_export_max_non_rel_sample_size: int = 200
"""Limit the number of 'Other' samples selected for training/test. This is applied per encountered MedCAT project, i.e. sample_size/num_projects."""
mct_export_create_addl_rels: bool = False
"""When processing relations from a MedCAT export, 'Other' relations are created from all the available annotation pairs"""

tokenizer_name: str = "bert"
model_name: str = "bert-base-uncased"
log_level: int = logging.INFO
max_seq_length: int = 512
tokenizer_special_tokens: bool = False
annotation_schema_tag_ids: List = []
"""If a foreign non-MCAT trainer dataset is used, you can insert your own Rel entity token delimiters into the tokenizer, \
copy those token IDs here, and also resize your tokenizer embeddings and adjust the hidden_size of the model, this will depend on the number of tokens you introduce"""
labels2idx: Dict = {}
idx2labels: Dict = {}
pin_memory: bool = True
seed: int = 13
task: str = "train"


class Model(MixingConfig, BaseModel):
"""The model part of the RelCAT config"""
input_size: int = 300
hidden_size: int = 768
hidden_layers: int = 3
model_size: int = 5120
"""hidden_size * 5, with 5 being the number of pooled tokens: default (s1, s2, e1, e2 + CLS)"""
dropout: float = 0.2
num_directions: int = 2
"""2 - bidirectional model, 1 - unidirectional"""

padding_idx: int = -1
emb_grad: bool = True
"""If True the embeddings will also be trained"""
ignore_cpos: bool = False
"""If set to True, center positions will be ignored when calculating the representation"""

class Config:
extra = Extra.allow
validate_assignment = True


class Train(MixingConfig, BaseModel):
"""The train part of the RelCAT config"""
nclasses: int = 2
"""Number of classes that this model will output"""
batch_size: int = 25
nepochs: int = 1
lr: float = 1e-4
adam_epsilon: float = 1e-4
test_size: float = 0.2
gradient_acc_steps: int = 1
multistep_milestones: List[int] = [
2, 4, 6, 8, 12, 15, 18, 20, 22, 24, 26, 30]
multistep_lr_gamma: float = 0.8
max_grad_norm: float = 1.0
shuffle_data: bool = True
"""Used only during training, if set the dataset will be shuffled before train/test split"""
class_weights: Optional[Any] = None
score_average: str = "weighted"
"""What to use for averaging F1/P/R across labels"""
auto_save_model: bool = True
"""Should the model be saved during training for best results"""

class Config:
extra = Extra.allow
validate_assignment = True


class ConfigRelCAT(MixingConfig, BaseModel):
"""The RelCAT part of the config"""
general: General = General()
model: Model = Model()
train: Train = Train()

class Config:
extra = Extra.allow
validate_assignment = True
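Since the config classes enable validate_assignment, fields can be overridden directly after construction and invalid values are rejected. A small sketch:

from medcat.config_rel_cat import ConfigRelCAT

config = ConfigRelCAT()
config.general.device = "cuda"   # train on GPU if available
config.general.cntx_left = 20    # widen the left context window
config.train.nclasses = 3        # e.g. two relation types + 'Other'
config.train.batch_size = 16
# Invalid assignments raise a pydantic ValidationError because
# validate_assignment = True, e.g.:
# config.train.lr = "fast"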
9 changes: 9 additions & 0 deletions medcat/pipe.py
@@ -13,6 +13,7 @@
from medcat.linking.context_based_linker import Linker
from medcat.meta_cat import MetaCAT
from medcat.ner.vocab_based_ner import NER
from medcat.rel_cat import RelCAT
from medcat.utils.normalizers import TokenNormalizer, BasicSpellChecker
from medcat.config import Config
from medcat.pipeline.pipe_runner import PipeRunner
@@ -161,6 +162,13 @@ def add_meta_cat(self, meta_cat: MetaCAT, name: Optional[str] = None) -> None:
# Used for sharing pre-processed data/tokens
Doc.set_extension('share_tokens', default=None, force=True)

def add_rel_cat(self, rel_cat: RelCAT, name: Optional[str] = None) -> None:
component_name = spacy.util.get_object_name(rel_cat)
name = name if name is not None else component_name
Language.component(name=component_name, func=rel_cat)
self._nlp.add_pipe(component_name, name=name, last=True)
# Doc._.relations holds the detected relations as a list of dicts
Doc.set_extension("relations", default=[], force=True)

def add_addl_ner(self, addl_ner: TransformersNER, name: Optional[str] = None) -> None:
component_name = spacy.util.get_object_name(addl_ner)
@@ -169,6 +177,7 @@ def add_addl_ner(self, addl_ner: TransformersNER, name: Optional[str] = None) ->
self._nlp.add_pipe(component_name, name=name, last=True)

Doc.set_extension('ents', default=[], force=True)
Doc.set_extension('relations', default=[], force=True)
Span.set_extension('confidence', default=-1, force=True)
Span.set_extension('id', default=0, force=True)
Span.set_extension('cui', default=-1, force=True)
666 changes: 666 additions & 0 deletions medcat/rel_cat.py

Large diffs are not rendered by default.

Empty file.
219 changes: 219 additions & 0 deletions medcat/utils/relation_extraction/models.py
@@ -0,0 +1,219 @@
import logging
from typing import Any, List, Optional, Tuple
import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertPreTrainingHeads, BertModel
from transformers.models.bert.configuration_bert import BertConfig
from medcat.config_rel_cat import ConfigRelCAT


class BertModel_RelationExtraction(nn.Module):
""" BertModel class for RelCAT
"""

name = "bertmodel_relcat"

log = logging.getLogger(__name__)

def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, model_config: BertConfig):
""" Class to hold the BERT model + model_config
Args:
pretrained_model_name_or_path (str): path to load the model from;
this can be a HF model, e.g. "bert-base-uncased". If left empty, it is normally assumed that a model is loaded from 'model.dat'
using the RelCAT.load() method. So if you are initializing/training a model from scratch, be sure to base it on some existing model.
relcat_config (ConfigRelCAT): relcat config.
model_config (BertConfig): HF bert config for model.
"""
super(BertModel_RelationExtraction, self).__init__()

self.relcat_config: ConfigRelCAT = relcat_config
self.model_config: BertConfig = model_config

self.bert_model: BertModel = BertModel(config=model_config)

if pretrained_model_name_or_path != "":
self.bert_model = BertModel.from_pretrained(pretrained_model_name_or_path, config=model_config)

for param in self.bert_model.parameters():
param.requires_grad = False

self.drop_out = nn.Dropout(self.model_config.hidden_dropout_prob)

if self.relcat_config.general.task == "pretrain":
self.activation = nn.Tanh()
self.cls = BertPreTrainingHeads(self.model_config)

self.relu = nn.ReLU()

# dense layers
self.fc1 = nn.Linear(self.relcat_config.model.model_size, self.relcat_config.model.hidden_size)
self.fc2 = nn.Linear(self.relcat_config.model.hidden_size, int(self.relcat_config.model.hidden_size / 2))
self.fc3 = nn.Linear(int(self.relcat_config.model.hidden_size / 2), self.relcat_config.train.nclasses)

self.log.info("RelCAT BertConfig: " + str(self.model_config))

def get_annotation_schema_tag(self, sequence_output: torch.Tensor, input_ids: torch.Tensor, special_tag: List) -> torch.Tensor:
""" Gets to token sequences from the sequence_ouput for the specific token
tag ids in self.relcat_config.general.annotation_schema_tag_ids.
Args:
sequence_output (torch.Tensor): hidden states/embeddings for each token in the input text
input_ids (torch.Tensor): input token ids
special_tag (List): special annotation token id pairs
Returns:
torch.Tensor: new seq_tags
"""

idx_start = torch.where(input_ids == special_tag[0]) # returns: row ids, idx of token[0]/start token in each row
idx_end = torch.where(input_ids == special_tag[1]) # returns: row ids, idx of token[1]/end token in each row

seen = [] # List to store seen elements and their indices
duplicate_indices = []

for i in range(len(idx_start[0])):
if idx_start[0][i] in seen:
duplicate_indices.append(i)
else:
seen.append(idx_start[0][i])

if len(duplicate_indices) > 0:
self.log.info("Duplicate entities found, removing them...")
for idx_remove in duplicate_indices:
idx_start_0 = torch.cat((idx_start[0][:idx_remove], idx_start[0][idx_remove + 1:]))
idx_start_1 = torch.cat((idx_start[1][:idx_remove], idx_start[1][idx_remove + 1:]))
idx_start = (idx_start_0, idx_start_1) # type: ignore

seen = []
duplicate_indices = []

for i in range(len(idx_end[0])):
if idx_end[0][i] in seen:
duplicate_indices.append(i)
else:
seen.append(idx_end[0][i])

if len(duplicate_indices) > 0:
self.log.info("Duplicate entities found, removing them...")
for idx_remove in duplicate_indices:
idx_end_0 = torch.cat((idx_end[0][:idx_remove], idx_end[0][idx_remove + 1:]))
idx_end_1 = torch.cat((idx_end[1][:idx_remove], idx_end[1][idx_remove + 1:]))
idx_end = (idx_end_0, idx_end_1) # type: ignore

assert len(idx_start[0]) == input_ids.shape[0]
assert len(idx_start[0]) == len(idx_end[0])
sequence_output_entities = []

for i in range(len(idx_start[0])):
to_append = sequence_output[i, idx_start[1][i] + 1:idx_end[1][i], ]

# to_append = torch.sum(to_append, dim=0)
to_append, _ = torch.max(to_append, axis=0) # type: ignore

sequence_output_entities.append(to_append)
sequence_output_entities = torch.stack(sequence_output_entities)

return sequence_output_entities

def output2logits(self, pooled_output: torch.Tensor, sequence_output: torch.Tensor, input_ids: torch.Tensor, e1_e2_start: torch.Tensor) -> torch.Tensor:
"""
Args:
pooled_output (torch.Tensor): embedding of the CLS token
sequence_output (torch.Tensor): hidden states/embeddings for each token in the input text
input_ids (torch.Tensor): input token ids.
e1_e2_start (torch.Tensor): annotation tags token position
Returns:
torch.Tensor: classification logits for each sample.
"""

new_pooled_output = pooled_output

if self.relcat_config.general.annotation_schema_tag_ids:
annotation_schema_tag_ids_ = [self.relcat_config.general.annotation_schema_tag_ids[i:i + 2] for i in
range(0, len(self.relcat_config.general.annotation_schema_tag_ids), 2)]
seq_tags = []

# for each pair of tags (e1,s1) and (e2,s2)
for each_tags in annotation_schema_tag_ids_:
seq_tags.append(self.get_annotation_schema_tag(
sequence_output, input_ids, each_tags))

seq_tags = torch.stack(seq_tags, dim=0)

new_pooled_output = torch.cat((pooled_output, *seq_tags), dim=1)
else:
e1e2_output = []
temp_e1 = []
temp_e2 = []

for i, seq in enumerate(sequence_output):
# e1e2 token sequences
temp_e1.append(seq[e1_e2_start[i][0]])
temp_e2.append(seq[e1_e2_start[i][1]])

e1e2_output.append(torch.stack(temp_e1, dim=0))
e1e2_output.append(torch.stack(temp_e2, dim=0))

new_pooled_output = torch.cat((pooled_output, *e1e2_output), dim=1)

del e1e2_output
del temp_e2
del temp_e1

x = self.drop_out(new_pooled_output)
x = self.fc1(x)
x = self.drop_out(x)
x = self.fc2(x)
classification_logits = self.fc3(x)
return classification_logits.to(self.relcat_config.general.device)

def forward(self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Any = None,
head_mask: Any = None,
encoder_hidden_states: Any = None,
encoder_attention_mask: Any = None,
Q: Any = None,
e1_e2_start: Any = None,
pooled_output: Any = None) -> Tuple[torch.Tensor, torch.Tensor]:

if input_ids is not None:
input_shape = input_ids.size()
else:
raise ValueError("You have to specify input_ids")

if attention_mask is None:
attention_mask = torch.ones(
input_shape, device=self.relcat_config.general.device)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(
input_shape, device=self.relcat_config.general.device)
if token_type_ids is None:
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=self.relcat_config.general.device)

input_ids = input_ids.to(self.relcat_config.general.device)
attention_mask = attention_mask.to(self.relcat_config.general.device)
encoder_attention_mask = encoder_attention_mask.to(
self.relcat_config.general.device)

self.bert_model = self.bert_model.to(self.relcat_config.general.device)

model_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask)

# (batch_size, sequence_length, hidden_size)
sequence_output = model_output[0]
pooled_output = model_output[1]

classification_logits = self.output2logits(
pooled_output, sequence_output, input_ids, e1_e2_start)

return model_output, classification_logits.to(self.relcat_config.general.device)
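Note how output2logits fixes the width that fc1 (model.model_size) must accept: the pooled CLS vector concatenated with one max-pooled vector per annotation tag pair (or, without tags, with the e1/e2 start-token vectors). A sketch of the arithmetic for bert-base with two tag pairs:

bert_hidden = 768   # bert-base-uncased hidden size
n_tag_pairs = 2     # e.g. (s1, e1) and (s2, e2) from annotation_schema_tag_ids

# output2logits concatenates the pooled CLS output with one pooled vector
# per tag pair, so fc1 must accept (1 + n_tag_pairs) * bert_hidden inputs.
model_size = (1 + n_tag_pairs) * bert_hidden
print(model_size)  # 2304, the value used in tests/test_rel_cat.py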
53 changes: 53 additions & 0 deletions medcat/utils/relation_extraction/pad_seq.py
@@ -0,0 +1,53 @@
from typing import List, Tuple
import torch
from torch import Tensor, LongTensor
from torch.nn.utils.rnn import pad_sequence


class Pad_Sequence():

def __init__(self, seq_pad_value: int, label_pad_value: int = -1):
""" Used in rel_cat.py in RelCAT to create DataLoaders for train/test datasets.
collate_fn for dataloader to collate sequences of different input_ids, ent1/ent2, and label
lengths into a fixed length batch.
This is applied per batch and not on the whole DataLoader data,
padded x sequence, y sequence, x lengths and y lengths of batch.
Args:
seq_pad_value (int): pad value for input_ids.
label_pad_value (int): pad value for labels. Defaults to -1.
"""
self.seq_pad_value: int = seq_pad_value
self.label_pad_value: int = label_pad_value

def __call__(self, batch: List[torch.Tensor]) -> Tuple[Tensor, List, Tensor, LongTensor, LongTensor]:
""" Pads a batch of input_ids.
Args:
batch (List[torch.Tensor]): gets the batch of Tensors from
RelData.dataset (check __getitem__() method for data returned)
and pads the token sequence + labels as needed
See https://pytorch.org/docs/stable/_modules/torch/nn/utils/rnn.html#pad_sequence
for extra info.
Returns:
Tuple[Tensor, List, Tensor, LongTensor, LongTensor]: padded data:
padded input ids, ent1 & ent2 start token positions, padded labels, input_id lengths, label lengths
"""
sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)

# input ids
seqs = [x[0] for x in sorted_batch]
seqs_padded = pad_sequence(
seqs, batch_first=True, padding_value=self.seq_pad_value)
x_lengths = torch.LongTensor([len(x) for x in seqs])

# label_ids
labels = list(map(lambda x: x[2], sorted_batch))
labels_padded = pad_sequence(
labels, batch_first=True, padding_value=self.label_pad_value)
y_lengths = torch.LongTensor([len(x) for x in labels])

ent1_ent2_start_pos = list(map(lambda x: x[1], sorted_batch))

return seqs_padded, ent1_ent2_start_pos, labels_padded, x_lengths, y_lengths
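A sketch of how Pad_Sequence plugs into a torch DataLoader as its collate_fn, using toy (input_ids, e1_e2_start, label) items of different lengths:

import torch
from torch.utils.data import DataLoader
from medcat.utils.relation_extraction.pad_seq import Pad_Sequence

# Toy items: (input_ids, e1_e2_start positions, label), with varying lengths.
data = [
    (torch.tensor([101, 7, 8, 9, 102]), torch.tensor([1, 3]), torch.tensor([1])),
    (torch.tensor([101, 7, 102]), torch.tensor([1, 2]), torch.tensor([0])),
]

pad_fn = Pad_Sequence(seq_pad_value=0, label_pad_value=-1)
loader = DataLoader(data, batch_size=2, collate_fn=pad_fn)

seqs, e1_e2_start, labels, x_lens, y_lens = next(iter(loader))
print(seqs.shape)   # torch.Size([2, 5]): both sequences padded to the longest
print(x_lens)       # tensor([5, 3]): original lengths, longest first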
687 changes: 687 additions & 0 deletions medcat/utils/relation_extraction/rel_dataset.py

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions medcat/utils/relation_extraction/tokenizer.py
@@ -0,0 +1,71 @@
import os
from typing import Optional
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast


class TokenizerWrapperBERT(BertTokenizerFast):
''' Wrapper around a huggingface BERT tokenizer so that it works with the
RelCAT models.
Args:
hf_tokenizers (`transformers.models.bert.tokenization_bert_fast.BertTokenizerFast`):
A huggingface fast BERT tokenizer.
'''
name = 'bert-tokenizer'

def __init__(self, hf_tokenizers=None, max_seq_length: Optional[int] = None, add_special_tokens: Optional[bool] = False):
self.hf_tokenizers = hf_tokenizers
self.max_seq_length = max_seq_length
self.add_special_tokens = add_special_tokens

def __call__(self, text, truncation: Optional[bool] = True):
if isinstance(text, str):
result = self.hf_tokenizers.encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True, return_attention_mask=True,
add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length, padding="longest", truncation=truncation)

return {'offset_mapping': result['offset_mapping'],
'input_ids': result['input_ids'],
'tokens': self.hf_tokenizers.convert_ids_to_tokens(result['input_ids']),
'token_type_ids': result['token_type_ids'],
'attention_mask': result['attention_mask'],
'length': result['length']
}
elif isinstance(text, list):
results = self.hf_tokenizers._batch_encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True,
add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length, truncation=truncation)
output = []
for ind in range(len(results['input_ids'])):
output.append({
'offset_mapping': results['offset_mapping'][ind],
'input_ids': results['input_ids'][ind],
'tokens': self.hf_tokenizers.convert_ids_to_tokens(results['input_ids'][ind]),
'token_type_ids': results['token_type_ids'][ind],
'attention_mask': results['attention_mask'][ind],
'length': results['length'][ind]
})
return output
else:
raise Exception(
"Unsupported input type, supported: text/list, but got: {}".format(type(text)))

def save(self, dir_path):
path = os.path.join(dir_path, self.name)
self.hf_tokenizers.save_pretrained(path)

@classmethod
def load(cls, dir_path, **kwargs):
tokenizer = cls()
path = os.path.join(dir_path, cls.name)
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(
path, **kwargs)

return tokenizer

def get_size(self):
return len(self.hf_tokenizers.vocab)

def token_to_id(self, token):
return self.hf_tokenizers.convert_tokens_to_ids(token)

def get_pad_id(self):
return self.hf_tokenizers.pad_token_id
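A minimal usage sketch for the wrapper (downloads bert-base-uncased on first run):

from transformers import AutoTokenizer
from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT

tokenizer = TokenizerWrapperBERT(
    hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased"),
    max_seq_length=512)

out = tokenizer("Patient presented with severe pain.")
print(out["tokens"][:3], out["input_ids"][:3])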
277 changes: 277 additions & 0 deletions medcat/utils/relation_extraction/utils.py
@@ -0,0 +1,277 @@
import os
import pickle
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
import logging
import random

from pandas.core.series import Series
from medcat.config_rel_cat import ConfigRelCAT

from medcat.preprocessing.tokenizers import TokenizerWrapperBERT
from medcat.utils.relation_extraction.models import BertModel_RelationExtraction


def split_list_train_test_by_class(data: List, test_size: float = 0.2, shuffle: bool = True) -> Tuple[List, List]:
"""
Args:
data (List): relation instances ("output_relations"); see create_base_relations_from_doc/csv
for the data columns.
test_size (float): Defaults to 0.2.
shuffle (bool): shuffle data randomly. Defaults to True.
Returns:
Tuple[List, List]: train and test datasets
"""

if shuffle:
random.shuffle(data)

train_data = []
test_data = []

row_id_labels = {row_idx: data[row_idx][5] for row_idx in range(len(data))}
count_per_label = {lbl: list(row_id_labels.values()).count(
lbl) for lbl in set(row_id_labels.values())}

for lbl_id, count in count_per_label.items():
_test_records_size = int(count * test_size)
tmp_count = 0
if _test_records_size not in [0, 1]:
for row_idx, _lbl_id in row_id_labels.items():
if _lbl_id == lbl_id:
if tmp_count < _test_records_size:
test_data.append(data[row_idx])
tmp_count += 1
else:
train_data.append(data[row_idx])
else:
for row_idx, _lbl_id in row_id_labels.items():
if _lbl_id == lbl_id:
train_data.append(data[row_idx])
test_data.append(data[row_idx])

return train_data, test_data


def load_bin_file(file_name, path="./") -> Any:
with open(os.path.join(path, file_name), 'rb') as f:
data = pickle.load(f)
return data


def save_bin_file(file_name, data, path="./"):
with open(os.path.join(path, file_name), "wb") as f:
pickle.dump(data, f)


def save_state(model: BertModel_RelationExtraction, optimizer: torch.optim.Adam, scheduler: torch.optim.lr_scheduler.MultiStepLR, epoch: int = 1, best_f1: float = 0.0, path: str = "./", model_name: str = "BERT", task: str = "train", is_checkpoint=False, final_export=False) -> None:
""" Used by RelCAT.save() and RelCAT.train()
Saves the RelCAT model state.
For checkpointing, multiple files are created, storing the best_f1 score, loss, etc.
If you want to export the model after training, set final_export=True and leave is_checkpoint=False.
Args:
model (BertModel_RelationExtraction): model
optimizer (torch.optim.Adam, optional): Defaults to None.
scheduler (torch.optim.lr_scheduler.MultiStepLR, optional): Defaults to None.
epoch (int): Defaults to 1.
best_f1 (float): Defaults to 0.0.
path (str): Defaults to "./".
model_name (str): Defaults to "BERT". Used for checkpointing only.
task (str): Defaults to "train". Used for checkpointing only.
is_checkpoint (bool): Defaults to False.
final_export (bool): Defaults to False; if True then is_checkpoint must also be False. Exports model.state_dict() to "model.dat".
"""

model_name = model_name.replace("/", "_")
file_name = "%s_checkpoint_%s.dat" % (task, model_name)

if not is_checkpoint:
file_name = "%s_best_%s.dat" % (task, model_name)
if final_export:
file_name = "model.dat"
torch.save(model.state_dict(), os.path.join(path, file_name))

if is_checkpoint:
torch.save({
'epoch': epoch,
'state_dict': model.state_dict(),
'best_f1': best_f1,
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict()
}, os.path.join(path, file_name))


def load_state(model: BertModel_RelationExtraction, optimizer, scheduler, path="./", model_name="BERT", file_prefix="train", load_best=False, device: torch.device = torch.device("cpu"), config: ConfigRelCAT = ConfigRelCAT()) -> Tuple[int, int]:
""" Used by RelCAT.load() and RelCAT.train()
Args:
model (BertModel_RelationExtraction): model, it has to be initialized before calling this method via BertModel_RelationExtraction(...)
optimizer (_type_): optimizer
scheduler (_type_): scheduler
path (str, optional): Defaults to "./".
model_name (str, optional): Defaults to "BERT".
file_prefix (str, optional): Defaults to "train".
load_best (bool, optional): Defaults to False.
device (torch.device, optional): Defaults to torch.device("cpu").
config (ConfigRelCAT): Defaults to ConfigRelCAT().
Returns:
Tuple (int, int): last epoch and f1 score.
"""

model_name = model_name.replace("/", "_")
logging.info("Attempting to load RelCAT model on device: " + str(device))
checkpoint_path = os.path.join(
path, file_prefix + "_checkpoint_%s.dat" % model_name)
best_path = os.path.join(
path, file_prefix + "_best_%s.dat" % model_name)
start_epoch, best_f1, checkpoint = 0, 0, None

if load_best is True and os.path.isfile(best_path):
checkpoint = torch.load(best_path, map_location=device)
logging.info("Loaded best model.")
elif os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=device)
logging.info("Loaded checkpoint model.")

if checkpoint is not None:
start_epoch = checkpoint['epoch']
best_f1 = checkpoint['best_f1']
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

if optimizer is None:
optimizer = torch.optim.Adam(
[{"params": model.module.parameters(), "lr": config.train.lr}])

if scheduler is None:
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=config.train.multistep_milestones,
gamma=config.train.multistep_lr_gamma)
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
logging.info("Loaded model and optimizer.")

return start_epoch, best_f1


def save_results(data, model_name: str = "BERT", path: str = "./", file_prefix: str = "train"):
save_bin_file(file_prefix + "_losses_accuracy_f1_per_epoch_%s.dat" %
model_name, data, path)


def load_results(path, model_name: str = "BERT", file_prefix: str = "train") -> Tuple[List, List, List]:
data_dict_path = os.path.join(
path, file_prefix + "_losses_accuracy_f1_per_epoch_%s.dat" % model_name)

data_dict: Dict = {"losses_per_epoch": [],
"accuracy_per_epoch": [], "f1_per_epoch": []}
if os.path.isfile(data_dict_path):
data_dict = load_bin_file(data_dict_path)

return data_dict["losses_per_epoch"], data_dict["accuracy_per_epoch"], data_dict["f1_per_epoch"]


def put_blanks(relation_data: List, blanking_threshold: float = 0.5) -> List:
"""
Args:
relation_data (List): tuple containing token (sentence_token_span , ent1 , ent2)
Puts blanks randomly in the relation. Used for pre-training.
blanking_threshold (float): % threshold to blank token ids. Defaults to 0.5.
Returns:
List: data
"""

blank_ent1 = np.random.uniform()
blank_ent2 = np.random.uniform()

blanked_relation = relation_data

sentence_token_span, ent1, ent2, label, label_id, ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id = (
*relation_data, )

if blank_ent1 >= blanking_threshold:
blanked_relation = [sentence_token_span, "[BLANK]", ent2, label, label_id,
ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id]

if blank_ent2 >= blanking_threshold:
blanked_relation = [sentence_token_span, ent1, "[BLANK]", label, label_id,
ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id]

return blanked_relation


def create_tokenizer_pretrain(tokenizer: TokenizerWrapperBERT, tokenizer_path: str):
"""
This method simply adds the special tokens that we encounter during pre-training.
Args:
tokenizer (TokenizerWrapperBERT): BERT tokenizer.
tokenizer_path (str): path where tokenizer is to be saved.
"""

tokenizer.hf_tokenizers.add_tokens(
["[BLANK]", "[ENT1]", "[ENT2]", "[/ENT1]", "[/ENT2]"], special_tokens=True)
tokenizer.hf_tokenizers.add_tokens(
["[s1]", "[e1]", "[s2]", "[e2]"], special_tokens=True)
tokenizer.save(tokenizer_path)


# Used for creating data sets for pretraining
def tokenize(relations_dataset: Series, tokenizer: TokenizerWrapperBERT, mask_probability: float = 0.5) -> Tuple:
(tokens, span_1_pos, span_2_pos), ent1_text, ent2_text, label, label_id, ent1_types, ent2_types, ent1_id, ent2_id, ent1_cui, ent2_cui, doc_id = relations_dataset

cls_token = tokenizer.hf_tokenizers.cls_token
sep_token = tokenizer.hf_tokenizers.sep_token

tokens = [token.lower() if token != '[BLANK]' else token for token in tokens]

forbidden_indices = [i for i in range(
span_1_pos[0], span_1_pos[1])] + [i for i in range(span_2_pos[0], span_2_pos[1])]

pool_indices = [i for i in range(
len(tokens)) if i not in forbidden_indices]

masked_indices = np.random.choice(pool_indices,
size=round(mask_probability *
len(pool_indices)),
replace=False)

masked_for_pred = [token.lower() for idx, token in enumerate(
tokens) if (idx in masked_indices)]

tokens = [token if (idx not in masked_indices)
else tokenizer.hf_tokenizers.mask_token for idx, token in enumerate(tokens)]

if (ent1_text == "[BLANK]") and (ent2_text != "[BLANK]"):
tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]", "[BLANK]", "[/ENT1]"] + \
tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]"] + tokens[span_2_pos[0]:span_2_pos[1]] + ["[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token]

elif (ent1_text == "[BLANK]") and (ent2_text == "[BLANK]"):
tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]", "[BLANK]", "[/ENT1]"] + \
tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]", "[BLANK]",
"[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token]

elif (ent1_text != "[BLANK]") and (ent2_text == "[BLANK]"):
tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]"] + tokens[span_1_pos[0]:span_1_pos[1]] + ["[/ENT1]"] + \
tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]", "[BLANK]",
"[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token]

elif (ent1_text != "[BLANK]") and (ent2_text != "[BLANK]"):
tokens = [cls_token] + tokens[:span_1_pos[0]] + ["[ENT1]"] + tokens[span_1_pos[0]:span_1_pos[1]] + ["[/ENT1]"] + \
tokens[span_1_pos[1]:span_2_pos[0]] + ["[ENT2]"] + tokens[span_2_pos[0]:span_2_pos[1]] + ["[/ENT2]"] + tokens[span_2_pos[1]:] + [sep_token]

ent1_ent2_start = ([i for i, e in enumerate(tokens) if e == "[ENT1]"][0], [
i for i, e in enumerate(tokens) if e == "[ENT2]"][0])

token_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(tokens)
masked_for_pred = tokenizer.hf_tokenizers.convert_tokens_to_ids(
masked_for_pred)

return token_ids, masked_for_pred, ent1_ent2_start
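split_list_train_test_by_class performs a stratified, per-label split: labels with enough rows are divided according to test_size, while labels whose prospective test share would be 0 or 1 rows are copied into both sets. A small sketch, using toy rows with the label id at column index 5 as in the relation rows above:

from medcat.utils.relation_extraction.utils import split_list_train_test_by_class

# Toy rows in the relation-instance layout; only column 5 (label_id) matters here.
rows = [[None, "ent1", "ent2", "label", None, label_id]
        for label_id in [0] * 10 + [1] * 10 + [2]]

train, test = split_list_train_test_by_class(rows, test_size=0.2)
# ~20% of the label-0 and label-1 rows go to test; the single label-2 row
# is copied into both sets because it is too rare to split.
print(len(train), len(test))  # 17 5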
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@
long_description_content_type="text/markdown",
url="https://github.com/CogStack/MedCAT",
packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets',
'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner',
'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', 'medcat.utils.relation_extraction',
'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'],
install_requires=[
'numpy>=1.22.0,<1.26.0', # 1.22.0 is first to support python 3.11; post 1.26.0 there's issues with scipy
2 changes: 2 additions & 0 deletions tests/resources/medcat_rel_test.csv
@@ -0,0 +1,2 @@
relation_token_span_ids ent1_ent2_start ent1 ent2 label label_id ent1_type ent2_type ent1_id ent2_id ent1_cui ent2_cui doc_id text
(59, 69) degeneration discrete disease/disability_procedure 1 disease procedure 956 977 33359002 263869007 0 EXAM:,MRI LEFT KNEE WITHOUT CONTRAST,CLINICAL:,This is a 53-year-old female with left knee pain being evaluated for ACL tear.,FINDINGS:,This examination was performed on 10-14-05.,Normal medial meniscus without intrasubstance [s1] degeneration [e1], surface fraying or [s2] discrete [e2] meniscal tear.,There is a discoid lateral meniscus and although there may be minimal superficial fraying along the inner edge of the body, there is no discrete tear (series #6 images #7-12).,There is a near-complete or complete tear of the femoral attachment of the anterior cruciate ligament. The ligament has a balled-up appearance consistent with at least partial retraction of most of the fibers of the ligament. There may be a few fibers still intact (series #4 images #12-14 series #5 images #12-14). The tibial fibers are normal.,Normal posterior cruciate ligament.,There is a sprain of the medial collateral ligament, with mild separation of the deep and superficial fibers at the femoral attachment (series #7 images #6-12). There is no complete tear or discontinuity and there is no meniscocapsular separation.,There is a sprain of the lateral ligament complex without focal tear or discontinuity of any of the intraarticular components.,Normal iliotibial band.,Normal quadriceps and patellar tendons.,There is contusion within the posterolateral corner of the tibia. There is also contusion within the patella at the midline patellar ridge where there is an area of focal chondral flattening (series #8 images #10-13). The medial and lateral patellar facets are otherwise normal as is the femoral trochlea in the there is no patellar subluxation.,There is a mild strain of the vastus medialis oblique muscle extending into the medial patellofemoral ligament and medial patellar retinaculum but there is no complete tear or discontinuity.,Normal lateral patellar retinaculum. There is a joint effusion and plica.,IMPRESSION:, Discoid lateral meniscus without a tear although there may be minimal superficial fraying along the inner edge of the body. Near-complete if not complete tear of the femoral attachment of the anterior cruciate ligament. Medial capsule sprain with associated strain of the vastus medialis oblique muscle. There is focal contusion within the patella at the midline patella ridge. Joint effusion and plica.
4 changes: 4 additions & 0 deletions tests/resources/medcat_rel_train.csv
@@ -0,0 +1,4 @@
relation_token_span_ids ent1_ent2_start ent1 ent2 label label_id ent1_type ent2_type ent1_id ent2_id ent1_cui ent2_cui doc_id text
(41, 47) severe pain emergency room disease/disability_procedure 1 disease procedure 1060 1039 76948002 225728007 1 REASON FOR CONSULTATION: , Left hip fracture.,HISTORY OF PRESENT ILLNESS: , The patient is a pleasant 53-year-old female with a known history of sciatica, apparently presented to the [s1] emergency room [e1] due to [s2] severe pain [e2] in the left lower extremity and unable to bear weight. History was obtained from the patient. As per the history, she reported that she has been having back pain with left leg pain since past 4 weeks. She has been using a walker for ambulation due to disabling pain in her left thigh and lower back. She was seen by her primary care physician and was scheduled to go for MRI yesterday. However, she was walking and her right foot got caught on some type of rug leading to place excessive weight on her left lower extremity to prevent her fall. Since then, she was unable to ambulate. The patient called paramedics and was brought to the emergency room. She denied any history of fall. She reported that she stepped the wrong way causing the pain to become worse. She is complaining of severe pain in her lower extremity and back pain. Denies any tingling or numbness. Denies any neurological symptoms. Denies any bowel or bladder incontinence.,X-rays were obtained which were remarkable for left hip fracture. Orthopedic consultation was called for further evaluation and management. On further interview with the patient, it is noted that she has a history of malignant melanoma, which was diagnosed approximately 4 to 5 years ago. She underwent surgery at that time and subsequently, she was noted to have a spread to the lymphatic system and lymph nodes for which she underwent surgery in 3/2008.,PAST MEDICAL HISTORY: , Sciatica and melanoma.,PAST SURGICAL HISTORY: ,As discussed above, surgery for melanoma and hysterectomy.,ALLERGIES: , NONE.,SOCIAL HISTORY: , Denies any tobacco or alcohol use. She is divorced with 2 children. She lives with her son.,PHYSICAL EXAMINATION:,GENERAL: The patient is well developed, well nourished in mild distress secondary to left lower extremity and back pain.,MUSCULOSKELETAL: Examination of the left lower extremity, there is presence of apparent shortening and external rotation deformity. Tenderness to palpation is present. Leg rolling is positive for severe pain in the left proximal hip. Further examination of the spine is incomplete secondary to severe leg pain. She is unable to perform a straight leg raising. EHL/EDL 5/5. 2+ pulses are present distally. Calf is soft and nontender. Homans sign is negative. Sensation to light touch is intact.,IMAGING:, AP view of the hip is reviewed. Only 1 limited view is obtained. This is a poor quality x-ray with a lot of soft tissue shadow. This x-ray is significant for basicervical-type femoral neck fracture. Lesser trochanter is intact. This is a high intertrochanteric fracture/basicervical. There is presence of lytic lesion around the femoral neck, which is not well delineated on this particular x-ray. We need to order repeat x-rays including AP pelvis, femur, and knee.,LABS:, Have been reviewed.,ASSESSMENT: , The patient is a 53-year-old female with probable pathological fracture of the left proximal femur.,DISCUSSION AND PLAN: , Nature and course of the diagnosis has been discussed with the patient. Based on her presentation without any history of obvious fall or trauma and past history of malignant melanoma, this appears to be a pathological fracture of the left proximal hip. 
At the present time, I would recommend obtaining a bone scan and repeat x-rays, which will include AP pelvis, femur, hip including knee. She denies any pain elsewhere. She does have a past history of back pain and sciatica, but at the present time, this appears to be a metastatic bone lesion with pathological fracture. I have discussed the case with Dr. X and recommended oncology consultation.,With the above fracture and presentation, she needs a left hip hemiarthroplasty versus calcar hemiarthroplasty, cemented type. Indication, risk, and benefits of left hip hemiarthroplasty has been discussed with the patient, which includes, but not limited to bleeding, infection, nerve injury, blood vessel injury, dislocation early and late, persistent pain, leg length discrepancy, myositis ossificans, intraoperative fracture, prosthetic fracture, need for conversion to total hip replacement surgery, revision surgery, DVT, pulmonary embolism, risk of anesthesia, need for blood transfusion, and cardiac arrest. She understands above and is willing to undergo further procedure. The goal and the functional outcome have been explained. Further plan will be discussed with her once we obtain the bone scan and the radiographic studies. We will also await for the oncology feedback and clearance.,Thank you very much for allowing me to participate in the care of this patient. I will continue to follow up.
(991, 1010) pulmonary embolism cardiac arrest non_relation 0 disease procedure 1029 1044 59282003 410429000 1 REASON FOR CONSULTATION: , Left hip fracture.,HISTORY OF PRESENT ILLNESS: , The patient is a pleasant 53-year-old female with a known history of sciatica, apparently presented to the emergency room due to severe pain in the left lower extremity and unable to bear weight. History was obtained from the patient. As per the history, she reported that she has been having back pain with left leg pain since past 4 weeks. She has been using a walker for ambulation due to disabling pain in her left thigh and lower back. She was seen by her primary care physician and was scheduled to go for MRI yesterday. However, she was walking and her right foot got caught on some type of rug leading to place excessive weight on her left lower extremity to prevent her fall. Since then, she was unable to ambulate. The patient called paramedics and was brought to the emergency room. She denied any history of fall. She reported that she stepped the wrong way causing the pain to become worse. She is complaining of severe pain in her lower extremity and back pain. Denies any tingling or numbness. Denies any neurological symptoms. Denies any bowel or bladder incontinence.,X-rays were obtained which were remarkable for left hip fracture. Orthopedic consultation was called for further evaluation and management. On further interview with the patient, it is noted that she has a history of malignant melanoma, which was diagnosed approximately 4 to 5 years ago. She underwent surgery at that time and subsequently, she was noted to have a spread to the lymphatic system and lymph nodes for which she underwent surgery in 3/2008.,PAST MEDICAL HISTORY: , Sciatica and melanoma.,PAST SURGICAL HISTORY: ,As discussed above, surgery for melanoma and hysterectomy.,ALLERGIES: , NONE.,SOCIAL HISTORY: , Denies any tobacco or alcohol use. She is divorced with 2 children. She lives with her son.,PHYSICAL EXAMINATION:,GENERAL: The patient is well developed, well nourished in mild distress secondary to left lower extremity and back pain.,MUSCULOSKELETAL: Examination of the left lower extremity, there is presence of apparent shortening and external rotation deformity. Tenderness to palpation is present. Leg rolling is positive for severe pain in the left proximal hip. Further examination of the spine is incomplete secondary to severe leg pain. She is unable to perform a straight leg raising. EHL/EDL 5/5. 2+ pulses are present distally. Calf is soft and nontender. Homans sign is negative. Sensation to light touch is intact.,IMAGING:, AP view of the hip is reviewed. Only 1 limited view is obtained. This is a poor quality x-ray with a lot of soft tissue shadow. This x-ray is significant for basicervical-type femoral neck fracture. Lesser trochanter is intact. This is a high intertrochanteric fracture/basicervical. There is presence of lytic lesion around the femoral neck, which is not well delineated on this particular x-ray. We need to order repeat x-rays including AP pelvis, femur, and knee.,LABS:, Have been reviewed.,ASSESSMENT: , The patient is a 53-year-old female with probable pathological fracture of the left proximal femur.,DISCUSSION AND PLAN: , Nature and course of the diagnosis has been discussed with the patient. Based on her presentation without any history of obvious fall or trauma and past history of malignant melanoma, this appears to be a pathological fracture of the left proximal hip. 
At the present time, I would recommend obtaining a bone scan and repeat x-rays, which will include AP pelvis, femur, hip including knee. She denies any pain elsewhere. She does have a past history of back pain and sciatica, but at the present time, this appears to be a metastatic bone lesion with pathological fracture. I have discussed the case with Dr. X and recommended oncology consultation.,With the above fracture and presentation, she needs a left hip hemiarthroplasty versus calcar hemiarthroplasty, cemented type. Indication, risk, and benefits of left hip hemiarthroplasty has been discussed with the patient, which includes, but not limited to bleeding, infection, nerve injury, blood vessel injury, dislocation early and late, persistent pain, leg length discrepancy, myositis ossificans, intraoperative fracture, prosthetic fracture, need for conversion to total hip replacement surgery, revision surgery, DVT, [s1] pulmonary embolism [e1], risk of anesthesia, need for blood transfusion, and [s2]cardiac arrest [e2]. She understands above and is willing to undergo further procedure. The goal and the functional outcome have been explained. Further plan will be discussed with her once we obtain the bone scan and the radiographic studies. We will also await for the oncology feedback and clearance.,Thank you very much for allowing me to participate in the care of this patient. I will continue to follow up.
(41, 53) emergency room left lower extremity disease/disability_procedure 1 disease procedure 1021 1039 225728007 32153003 1 REASON FOR CONSULTATION: , Left hip fracture.,HISTORY OF PRESENT ILLNESS: , The patient is a pleasant 53-year-old female with a known history of sciatica, apparently presented to the [s1] emergency room [e1] due to severe pain in the [s2] left lower extremity [e2] and unable to bear weight. History was obtained from the patient. As per the history, she reported that she has been having back pain with left leg pain since past 4 weeks. She has been using a walker for ambulation due to disabling pain in her left thigh and lower back. She was seen by her primary care physician and was scheduled to go for MRI yesterday. However, she was walking and her right foot got caught on some type of rug leading to place excessive weight on her left lower extremity to prevent her fall. Since then, she was unable to ambulate. The patient called paramedics and was brought to the emergency room. She denied any history of fall. She reported that she stepped the wrong way causing the pain to become worse. She is complaining of severe pain in her lower extremity and back pain. Denies any tingling or numbness. Denies any neurological symptoms. Denies any bowel or bladder incontinence.,X-rays were obtained which were remarkable for left hip fracture. Orthopedic consultation was called for further evaluation and management. On further interview with the patient, it is noted that she has a history of malignant melanoma, which was diagnosed approximately 4 to 5 years ago. She underwent surgery at that time and subsequently, she was noted to have a spread to the lymphatic system and lymph nodes for which she underwent surgery in 3/2008.,PAST MEDICAL HISTORY: , Sciatica and melanoma.,PAST SURGICAL HISTORY: ,As discussed above, surgery for melanoma and hysterectomy.,ALLERGIES: , NONE.,SOCIAL HISTORY: , Denies any tobacco or alcohol use. She is divorced with 2 children. She lives with her son.,PHYSICAL EXAMINATION:,GENERAL: The patient is well developed, well nourished in mild distress secondary to left lower extremity and back pain.,MUSCULOSKELETAL: Examination of the left lower extremity, there is presence of apparent shortening and external rotation deformity. Tenderness to palpation is present. Leg rolling is positive for severe pain in the left proximal hip. Further examination of the spine is incomplete secondary to severe leg pain. She is unable to perform a straight leg raising. EHL/EDL 5/5. 2+ pulses are present distally. Calf is soft and nontender. Homans sign is negative. Sensation to light touch is intact.,IMAGING:, AP view of the hip is reviewed. Only 1 limited view is obtained. This is a poor quality x-ray with a lot of soft tissue shadow. This x-ray is significant for basicervical-type femoral neck fracture. Lesser trochanter is intact. This is a high intertrochanteric fracture/basicervical. There is presence of lytic lesion around the femoral neck, which is not well delineated on this particular x-ray. We need to order repeat x-rays including AP pelvis, femur, and knee.,LABS:, Have been reviewed.,ASSESSMENT: , The patient is a 53-year-old female with probable pathological fracture of the left proximal femur.,DISCUSSION AND PLAN: , Nature and course of the diagnosis has been discussed with the patient. Based on her presentation without any history of obvious fall or trauma and past history of malignant melanoma, this appears to be a pathological fracture of the left proximal hip. 
At the present time, I would recommend obtaining a bone scan and repeat x-rays, which will include AP pelvis, femur, hip including knee. She denies any pain elsewhere. She does have a past history of back pain and sciatica, but at the present time, this appears to be a metastatic bone lesion with pathological fracture. I have discussed the case with Dr. X and recommended oncology consultation.,With the above fracture and presentation, she needs a left hip hemiarthroplasty versus calcar hemiarthroplasty, cemented type. Indication, risk, and benefits of left hip hemiarthroplasty has been discussed with the patient, which includes, but not limited to bleeding, infection, nerve injury, blood vessel injury, dislocation early and late, persistent pain, leg length discrepancy, myositis ossificans, intraoperative fracture, prosthetic fracture, need for conversion to total hip replacement surgery, revision surgery, DVT, pulmonary embolism, risk of anesthesia, need for blood transfusion, and cardiac arrest. She understands above and is willing to undergo further procedure. The goal and the functional outcome have been explained. Further plan will be discussed with her once we obtain the bone scan and the radiographic studies. We will also await for the oncology feedback and clearance.,Thank you very much for allowing me to participate in the care of this patient. I will continue to follow up.
4,533 changes: 4,533 additions & 0 deletions tests/resources/medcat_trainer_export_relations.json

Large diffs are not rendered by default.

16 changes: 15 additions & 1 deletion tests/test_pipe.py
@@ -6,12 +6,14 @@
from medcat.config import Config
from medcat.pipe import Pipe
from medcat.meta_cat import MetaCAT
from medcat.rel_cat import RelCAT
from medcat.preprocessing.taggers import tag_skip_and_punct
from medcat.preprocessing.tokenizers import spacy_split_all
from medcat.utils.normalizers import BasicSpellChecker, TokenNormalizer
from medcat.ner.vocab_based_ner import NER
from medcat.linking.context_based_linker import Linker
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT as RelTokenizerWrapperBERT
from transformers import AutoTokenizer


@@ -41,7 +43,9 @@ def setUpClass(cls) -> None:
cls.linker = Linker(cls.cdb, cls.vocab, cls.config)

_tokenizer = TokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased"))
_tokenizer_rel = RelTokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased"))
cls.meta_cat = MetaCAT(tokenizer=_tokenizer)
cls.rel_cat = RelCAT(cls.cdb, tokenizer=_tokenizer_rel, init_model=True)

cls.text = "stop of CDB - I was running and then Movar Virus attacked and CDb"
cls.undertest = Pipe(tokenizer=spacy_split_all, config=cls.config)
@@ -56,6 +60,7 @@ def setUp(self) -> None:
PipeTests.undertest.force_remove(PipeTests.ner.name)
PipeTests.undertest.force_remove(PipeTests.linker.name)
PipeTests.undertest.force_remove(PipeTests.meta_cat.name)
PipeTests.undertest.force_remove(PipeTests.rel_cat.name)

def test_add_tagger(self):
PipeTests.undertest.add_tagger(tagger=tag_skip_and_punct, name=tag_skip_and_punct.name, additional_fields=["is_punct"])
@@ -82,7 +87,12 @@ def test_add_meta_cat(self):
PipeTests.undertest.add_meta_cat(PipeTests.meta_cat)

self.assertEqual(PipeTests.meta_cat.name, Language.get_factory_meta(PipeTests.meta_cat.name).factory)


def test_add_rel_cat(self):
PipeTests.undertest.add_rel_cat(PipeTests.rel_cat)

self.assertEqual(PipeTests.rel_cat.name, Language.get_factory_meta(PipeTests.rel_cat.name).factory)

def test_stopwords_loading(self):
self.assertEqual(PipeTests.undertest._nlp.Defaults.stop_words, PipeTests.config.preprocessing.stopwords)
doc = PipeTests.undertest(PipeTests.text)
@@ -95,6 +105,7 @@ def test_batch_multi_process(self):
PipeTests.undertest.add_ner(PipeTests.ner)
PipeTests.undertest.add_linker(PipeTests.linker)
PipeTests.undertest.add_meta_cat(PipeTests.meta_cat)
PipeTests.undertest.add_rel_cat(PipeTests.rel_cat)

PipeTests.undertest.set_error_handler(_error_handler)
docs = list(self.undertest.batch_multi_process([PipeTests.text, PipeTests.text, PipeTests.text], n_process=1, batch_size=1))
@@ -114,6 +125,7 @@ def _generate_texts(texts):
PipeTests.undertest.add_ner(PipeTests.ner)
PipeTests.undertest.add_linker(PipeTests.linker)
PipeTests.undertest.add_meta_cat(PipeTests.meta_cat)
PipeTests.undertest.add_rel_cat(PipeTests.rel_cat)

docs = list(self.undertest(_generate_texts([PipeTests.text, None, PipeTests.text])))

@@ -128,6 +140,7 @@ def test_callable_with_single_text(self):
PipeTests.undertest.add_ner(PipeTests.ner)
PipeTests.undertest.add_linker(PipeTests.linker)
PipeTests.undertest.add_meta_cat(PipeTests.meta_cat)
PipeTests.undertest.add_rel_cat(PipeTests.rel_cat)

doc = self.undertest(PipeTests.text)

@@ -139,6 +152,7 @@ def test_callable_with_multi_texts(self):
PipeTests.undertest.add_ner(PipeTests.ner)
PipeTests.undertest.add_linker(PipeTests.linker)
PipeTests.undertest.add_meta_cat(PipeTests.meta_cat)
PipeTests.undertest.add_rel_cat(PipeTests.rel_cat)

docs = list(self.undertest([PipeTests.text, None, PipeTests.text]))

111 changes: 111 additions & 0 deletions tests/test_rel_cat.py
@@ -0,0 +1,111 @@
import os
import shutil
import unittest
import json

from medcat.cdb import CDB
from medcat.config_rel_cat import ConfigRelCAT
from medcat.rel_cat import RelCAT
from medcat.utils.relation_extraction.rel_dataset import RelData
from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT
from medcat.utils.relation_extraction.models import BertModel_RelationExtraction

from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.bert.configuration_bert import BertConfig

import spacy
from spacy.tokens import Span, Doc

class RelCATTests(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
config = ConfigRelCAT()
config.general.device = "cpu"
config.general.model_name = "bert-base-uncased"
config.train.batch_size = 1
config.train.nclasses = 3
config.model.hidden_size = 256
config.model.model_size = 2304

tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=config.general.model_name,
config=config), add_special_tokens=True)

SPEC_TAGS = ["[s1]", "[e1]", "[s2]", "[e2]"]

tokenizer.hf_tokenizers.add_tokens(SPEC_TAGS, special_tokens=True)
config.general.annotation_schema_tag_ids = tokenizer.hf_tokenizers.convert_tokens_to_ids(SPEC_TAGS)

cls.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp")
os.makedirs(cls.tmp_dir, exist_ok=True)

cls.save_model_path = os.path.join(cls.tmp_dir, "test_model")
os.makedirs(cls.save_model_path, exist_ok=True)

cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat"))

cls.medcat_export_with_rels_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "medcat_trainer_export_relations.json")
cls.medcat_rels_csv_path_train = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "medcat_rel_train.csv")
cls.medcat_rels_csv_path_test = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "medcat_rel_test.csv")

cls.mct_file_test = {}
with open(cls.medcat_export_with_rels_path, "r+") as f:
cls.mct_file_test = json.loads(f.read())["projects"][0]["documents"][1]

cls.config_rel_cat: ConfigRelCAT = config
cls.rel_cat: RelCAT = RelCAT(cdb, tokenizer=tokenizer, config=config, init_model=True)

cls.rel_cat.model.bert_model.resize_token_embeddings(len(tokenizer.hf_tokenizers))

cls.finished = False
cls.tokenizer = tokenizer

def test_train_csv_no_tags(self) -> None:
self.rel_cat.config.train.nepochs = 2
self.rel_cat.train(train_csv_path=self.medcat_rels_csv_path_train, test_csv_path=self.medcat_rels_csv_path_test, checkpoint_path=self.tmp_dir)
self.rel_cat.save(self.save_model_path)

def test_train_mctrainer(self) -> None:
self.rel_cat = RelCAT.load(self.save_model_path)
self.rel_cat.config.general.mct_export_create_addl_rels = True
self.rel_cat.config.general.mct_export_max_non_rel_sample_size = 10
self.rel_cat.config.train.test_size = 0.1
self.rel_cat.config.train.nclasses = 3
self.rel_cat.model.relcat_config.train.nclasses = 3
self.rel_cat.model.bert_model.resize_token_embeddings(len(self.tokenizer.hf_tokenizers))

self.rel_cat.train(export_data_path=self.medcat_export_with_rels_path, checkpoint_path=self.tmp_dir)

def test_train_predict(self) -> None:
Span.set_extension('id', default=0, force=True)
Span.set_extension('cui', default=None, force=True)
Doc.set_extension('ents', default=[], force=True)
Doc.set_extension('relations', default=[], force=True)
nlp = spacy.blank("en")
doc = nlp(self.mct_file_test["text"])

for ann in self.mct_file_test["annotations"]:
tkn_idx = []
for ind, word in enumerate(doc):
end_char = word.idx + len(word.text)
if end_char <= ann['end'] and end_char > ann['start']:
tkn_idx.append(ind)
entity = Span(doc, min(tkn_idx), max(tkn_idx) + 1, label=ann["value"])
entity._.cui = ann["cui"]
doc._.ents.append(entity)

self.rel_cat.model.bert_model.resize_token_embeddings(len(self.tokenizer.hf_tokenizers))

doc = self.rel_cat(doc)
self.finished = True

assert len(doc._.relations) > 0

def tearDown(self) -> None:
if self.finished:
if os.path.exists(self.tmp_dir):
shutil.rmtree(self.tmp_dir)

if __name__ == '__main__':
unittest.main()
