enhancements and bug fixes
ming committed Aug 19, 2022
1 parent 437d929 commit 3f0e5ef
Showing 7 changed files with 36 additions and 22 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = bert-classifier
-version = 0.1.5
+version = 0.1.6
 author= Ming Gao
 author_email = [email protected]
 url = https://github.com/minggnim/nlp-classification-model
5 changes: 1 addition & 4 deletions src/bert_classifier/bert.py
@@ -38,7 +38,4 @@ def bert_encoder(content, tokenizer, max_len):
         return_token_type_ids=True,
         return_tensors='pt'
     )
-    ids = inputs['input_ids']
-    mask = inputs['attention_mask']
-    token_type_ids = inputs['token_type_ids']
-    return ids, mask, token_type_ids
+    return inputs
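
With this change, bert_encoder returns the tokenizer output whole instead of unpacking three tensors. A minimal sketch of the resulting function, assuming a Hugging Face tokenizer; only the last two encode_plus options are visible in the diff, the rest are assumptions:

def bert_encoder(content, tokenizer, max_len):
    # encode_plus returns a dict-like BatchEncoding containing
    # 'input_ids', 'attention_mask' and 'token_type_ids' tensors
    inputs = tokenizer.encode_plus(
        content,
        max_length=max_len,        # assumed option, not shown in the diff
        padding='max_length',      # assumed option
        truncation=True,           # assumed option
        return_token_type_ids=True,
        return_tensors='pt'
    )
    return inputs

Callers now index the dict by key or splat it with **, as the predict.py change below shows.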
22 changes: 18 additions & 4 deletions src/bert_classifier/data.py
@@ -23,12 +23,26 @@ def __len__(self):
     def __getitem__(self, index):
         content = str(self.content[index])
         content = " ".join(content.split())
-        ids, mask, token_type_ids = bert_encoder(content, self.tokenizer, self.max_len)
+        encoded_content = bert_encoder(content, self.tokenizer, self.max_len)

         return {
-            'ids': ids,
-            'mask': mask,
-            'type_ids': token_type_ids,
+            'input_ids': encoded_content['input_ids'],
+            'attention_mask': encoded_content['attention_mask'],
+            'token_type_ids': encoded_content['token_type_ids'],
             'label': torch.tensor(self.label[index], dtype=torch.long),
             # 'multi_label': torch.tensor(self.label[index], dtype=torch.float)
         }
+
+
+def create_label_dict(dataframe, label_col):
+    labels = dataframe.groupby(label_col).size().sort_values(ascending=False).index.tolist()
+    label_dict = dict([(d, i) for i, d in enumerate(labels)])
+    return label_dict
+
+
+def label2id(dataframe, label_col, label_dict, multi_class=False):
+    if multi_class:
+        dataframe['label'] = dataframe[label_col].apply(lambda c: [int(c==l) for l in label_dict.keys()])
+    else:
+        dataframe['label'] = dataframe[label_col].apply(lambda c: label_dict[c])
+    return dataframe
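
A small usage sketch of the two new helpers; the dataframe and column name are illustrative:

import pandas as pd

df = pd.DataFrame({'category': ['spam', 'spam', 'spam', 'ham', 'ham']})
label_dict = create_label_dict(df, 'category')
# {'spam': 0, 'ham': 1} -- ids assigned in descending class frequency
df = label2id(df, 'category', label_dict)
# adds an integer 'label' column: [0, 0, 0, 1, 1]
df = label2id(df, 'category', label_dict, multi_class=True)
# one-hot lists instead: [[1, 0], [1, 0], [1, 0], [0, 1], [0, 1]]

The multi_class=True output lines up with the commented-out multi_label float tensor in __getitem__.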
4 changes: 2 additions & 2 deletions src/bert_classifier/io.py
@@ -104,7 +104,7 @@ def save_label_dict(label_dict, dict_file=LABEL_DICT):
         json.dump(label_dict, file)


-def load_label_dict(label_dir):
+def load_label_dict(label_file=LABEL_DICT):
     '''load label dictionary'''
-    with open(label_dir, 'r') as file:
+    with open(label_file, 'r') as file:
         return json.load(file)
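
Since load_label_dict now defaults to the same LABEL_DICT path that save_label_dict writes to, a round trip needs no arguments:

label_dict = {'spam': 0, 'ham': 1}      # illustrative
save_label_dict(label_dict)             # writes JSON to the default LABEL_DICT path
assert load_label_dict() == label_dict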
9 changes: 6 additions & 3 deletions src/bert_classifier/metrics.py
@@ -11,13 +11,16 @@
 )


-def transform_outputs(outputs, targets):
+def transform_outputs(outputs, targets, multi_label=False):
     '''
     transform outputs to suitable format for calculating performance metrics
     '''
     preds = np.array(outputs).argmax(axis=-1)
-    # y = np.array(targets).argmax(axis=1) # for multilabel
-    return targets, preds
+    if multi_label:
+        y = np.array(targets).argmax(axis=1)
+    else:
+        y = targets
+    return y, preds


 def accuracy_metrics(outputs, targets):
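
The multilabel path is now selected by an argument rather than by commenting code in and out. A quick illustration of both paths, with made-up scores:

outputs = [[0.1, 0.9], [0.8, 0.2]]      # per-class scores for two samples
transform_outputs(outputs, [1, 0])
# -> ([1, 0], array([1, 0]))            targets pass through unchanged
transform_outputs(outputs, [[0, 1], [1, 0]], multi_label=True)
# -> (array([1, 0]), array([1, 0]))     one-hot targets reduced to class ids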
2 changes: 1 addition & 1 deletion src/bert_classifier/predict.py
@@ -25,7 +25,7 @@ def predict(self, inp: str):
         '''
         self.model.eval()
         enc = bert_encoder(inp, self.tokenizer, self.max_len)
-        out = self.model(*enc)[-1].detach().cpu()
+        out = self.model(**enc)[-1].detach().cpu()
         idx = out.argmax().item()
         label = self.labels[idx]
         proba = out.softmax(-1)[idx].item()
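
This one-character fix follows from the bert.py change: enc is now a dict-like BatchEncoding, so *enc would unpack only its string keys as positional arguments, whereas **enc passes the tensors as keyword arguments. Roughly, and assuming the model's forward accepts these parameter names:

out = self.model(**enc)
# equivalent to:
# self.model(input_ids=enc['input_ids'],
#            attention_mask=enc['attention_mask'],
#            token_type_ids=enc['token_type_ids'])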
14 changes: 7 additions & 7 deletions src/bert_classifier/train.py
@@ -26,7 +26,7 @@ def loss_fn(outputs, targets):
     return torch.nn.CrossEntropyLoss()(outputs, targets)


-def train(model, optimizer, train_dataloader, test_dataloader, epochs):
+def custom_trainer(model, optimizer, train_dataloader, test_dataloader, epochs):
     '''
     custom training module
     '''
@@ -40,9 +40,9 @@ def train(model, optimizer, train_dataloader, test_dataloader, epochs):
     for _, batch in tqdm(enumerate(train_dataloader)):
         # import pdb; pdb.set_trace();
         model.zero_grad(set_to_none=True)
-        ids = batch['ids'].squeeze(1).to(DEVICE)
-        mask = batch['mask'].squeeze(1).to(DEVICE)
-        type_ids = batch['type_ids'].squeeze(1).to(DEVICE)
+        ids = batch['input_ids'].squeeze(1).to(DEVICE)
+        mask = batch['attention_mask'].squeeze(1).to(DEVICE)
+        type_ids = batch['token_type_ids'].squeeze(1).to(DEVICE)
         label = batch['label'].to(DEVICE)
         output = model(ids, mask, type_ids)

@@ -84,9 +84,9 @@ def validate(model, test_dataloader):
     val_targets, val_outputs, val_loss = [], [], []
     with torch.no_grad():
         for _, batch in enumerate(test_dataloader):
-            ids = batch['ids'].squeeze(1).to(DEVICE)
-            mask = batch['mask'].squeeze(1).to(DEVICE)
-            type_ids = batch['type_ids'].squeeze(1).to(DEVICE)
+            ids = batch['input_ids'].squeeze(1).to(DEVICE)
+            mask = batch['attention_mask'].squeeze(1).to(DEVICE)
+            type_ids = batch['token_type_ids'].squeeze(1).to(DEVICE)
             label = batch['label'].to(DEVICE)
             outputs = model(ids, mask, type_ids)
             val_targets.extend(label.detach().cpu().numpy())
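
Renaming train to custom_trainer keeps the function from being confused with the train module itself. A hypothetical call site; the optimizer choice and hyperparameters are illustrative, not from the diff:

from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5)
custom_trainer(model, optimizer, train_dataloader, test_dataloader, epochs=3)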
