You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
기존의 Pretrained Model을 불러와서 쓰면 항상 고정된 classifier에 label의 개수만 바꿔주는 형식으로 동작.
ex) config.num_labels=30
Transformer에서 AutoModelForSequenceClassification을 살펴보시면 이미 자체적으로 classifier가 구성되어 있습니다.
이를 좀 더 유연하게 label의 개수뿐만 아니라 내부 layer들을 설정해준다면 더 깊은 model을 만들 수 있지 않을까 싶어서 시도했습니다.
2. Model을 어떻게 구성할까?
기존의 nn.Module을 상속받는 형태로 구성했습니다.
classMyModel(nn.Module):
def__init__(self, model_name, config):
super().__init__()
print("this is custom model!!!")
self.backbone=AutoModel.from_pretrained(model_name, config=config)
self.num_labels=config.num_labelsself.config=configself.dense=nn.Linear(config.hidden_size*4, config.hidden_size*4)
classifier_dropout= (
config.classifier_dropoutifconfig.classifier_dropoutisnotNoneelseconfig.hidden_dropout_prob
)
self.dropout=nn.Dropout(classifier_dropout)
self.out_proj=nn.Linear(config.hidden_size*4, config.num_labels)
defforward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
): # Roberta model의 input 형식과 동일합니다.outputs=self.backbone(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
) # Roberta model의 input 형식과 동일하기 때문에 기존 backbone 모델에 넣으면 원하는 output을 가져올 수 있습니다.# output[0] = last hidden state / output[1] = pooler_output / output[2] = all hidden states (if 'output_hidden_states': True) == past_key_value, tuple 형태로 반환 all_hidden_states=torch.stack(outputs[2]) # (layers , batch, seq_len, hidden_size)concat_pooling_layer=torch.cat(
(all_hidden_states[-1], all_hidden_states[-2], all_hidden_states[-3], all_hidden_states[-4]), -1) # (batch, seq_len, hidden_size * 4)concat_pooling_layer=concat_pooling_layer[:, 0, :] # 모든 layer의 CLS embedding vector / (batch, hidden_size * 4)logits=self.dropout(concat_pooling_layer) # (batch, hidden_size * 4)logits=self.dense(logits) # (batch, hidden_size * 4) -> (batch, hidden_size*4)logits=torch.tanh(logits) # (batch, hidden_size*4)logits=self.dropout(logits) # (batch, hidden_size*4)logits=self.out_proj(logits) # (batch, hidden_size*4) -> (batch, num_labels)loss=NoneiflabelsisnotNone:
ifself.config.problem_typeisNone:
ifself.num_labels==1:
self.config.problem_type="regression"elifself.num_labels>1and (labels.dtype==torch.longorlabels.dtype==torch.int):
self.config.problem_type="single_label_classification"else:
self.config.problem_type="multi_label_classification"ifself.config.problem_type=="regression":
loss_fct=nn.MSELoss()
ifself.num_labels==1:
loss=loss_fct(logits.squeeze(), labels.squeeze())
else:
loss=loss_fct(logits, labels)
elifself.config.problem_type=="single_label_classification":
loss_fct=nn.CrossEntropyLoss()
loss=loss_fct(
logits.view(-1, self.num_labels), labels.view(-1))
elifself.config.problem_type=="multi_label_classification":
loss_fct=nn.BCEWithLogitsLoss()
loss=loss_fct(logits, labels)
ifnotreturn_dict:
output= (logits,) +outputs[3:]
return ((loss,) +output) iflossisnotNoneelseoutput# transformers의 Trainer를 사용하기 위해서는 output을 원하는 모델 형태에 맞게 tuple형태로 반환해야 합니다.# AutoModelForSequenceClassification의 경우 loss, logits, outputs.hidden_states, outputs.attentions 입니다.# 자세한 내용은 Trainer docs 참고return (
loss,
logits,
outputs.hidden_states,
outputs.attentions,
)
실제 Trainer docs를 살펴보면
The Trainer class is optimized for 🤗 Transformers models and can have surprising behaviors when you use it on other models. When using it on your own model, make sure:
your model always return tuples or subclasses of ModelOutput.
your model can compute the loss if a labels argument is provided and that loss is returned as the first element of the tuple (if your model returns tuples)
your model can accept multiple label arguments (use the label_names in your TrainingArguments to indicate their name to the Trainer) but none of them should be named "label".
로 표시되어 있습니다. 그리고 이곳에서 SequenceClassifierOutput에 대한 설명을 살펴보면 위와 같은 output parameter로 loss, logits, outputs.hidden_states, outputs.attentions을 가지는 것을 확인하실 수 있습니다.
3. 사소하지만 치명적인 에러 해결
1. config.update({'output_hidden_states': True})를 이용해 output_hidden_states를 출력하도록 설정
이 코드에서 outputs[2:]를 outputs[3:]으로 바꾸니 해결 ,,, (이유를 모르겠습니다)
바꾼 이유는 기존 config.update({'output_hidden_states': True})를 해줌으로써 output[2](all hidden states)가 추가적으로 생성되었다고 생각해서 바꿔주었습니다. 이 부분에 대해서는 얘기가 필요할 것 같습니다.
wandb를 살펴보면 결과 f1_score나 loss가 좀 더 좋아진 것을 확인할 수 있었습니다.
실제 제출 결과는 65.487이고 k_fold 5한 결과는 70.535가 나왔습니다.
생각해봤을 때는 layer가 더 무거워져서 epoch 3으로도 과적합이 되지 않았나 생각합니다.
5. BONUS
추가적으로 pretrained model의 parameter들을 freeze하고 실험을 해봤습니다.
위 모델에서
1. 기존 Pretrained Model 문제점 (?) 아쉬운 점(?)
기존의 Pretrained Model을 불러와서 쓰면 항상 고정된 classifier에 label의 개수만 바꿔주는 형식으로 동작.
ex)
config.num_labels=30
Transformer에서 AutoModelForSequenceClassification을 살펴보시면 이미 자체적으로 classifier가 구성되어 있습니다.
이를 좀 더 유연하게 label의 개수뿐만 아니라 내부 layer들을 설정해준다면 더 깊은 model을 만들 수 있지 않을까 싶어서 시도했습니다.
2. Model을 어떻게 구성할까?
기존의 nn.Module을 상속받는 형태로 구성했습니다.
실제 Trainer docs를 살펴보면
The Trainer class is optimized for 🤗 Transformers models and can have surprising behaviors when you use it on other models. When using it on your own model, make sure:
your model always return tuples or subclasses of ModelOutput.
your model can compute the loss if a labels argument is provided and that loss is returned as the first element of the tuple (if your model returns tuples)
your model can accept multiple label arguments (use the label_names in your TrainingArguments to indicate their name to the Trainer) but none of them should be named "label".
로 표시되어 있습니다. 그리고 이곳에서 SequenceClassifierOutput에 대한 설명을 살펴보면 위와 같은 output parameter로 loss, logits, outputs.hidden_states, outputs.attentions을 가지는 것을 확인하실 수 있습니다.
3. 사소하지만 치명적인 에러 해결
1. config.update({'output_hidden_states': True})를 이용해 output_hidden_states를 출력하도록 설정
2. model.save_pretrained(model_save_pth)
기존의 model은 PretrainedModel을 기본적으로 상속 받기 때문에 save_pretrained method가 존재하지만 저희가 만든 모델은 없습니다. 그렇기 때문에 따로
torch.save()
를 이용해줍니다.3. CUDA OOM
이 코드에서
outputs[2:]
를outputs[3:]
으로 바꾸니 해결 ,,, (이유를 모르겠습니다)바꾼 이유는 기존
config.update({'output_hidden_states': True})
를 해줌으로써output[2](all hidden states)
가 추가적으로 생성되었다고 생각해서 바꿔주었습니다. 이 부분에 대해서는 얘기가 필요할 것 같습니다.3. Inference 과정
기존 inference.py에 존재하는
이 코드를
으로 바꿔줌으로써 해결
4. 결과 확인
wandb를 살펴보면 결과 f1_score나 loss가 좀 더 좋아진 것을 확인할 수 있었습니다.
실제 제출 결과는 65.487이고 k_fold 5한 결과는 70.535가 나왔습니다.
생각해봤을 때는 layer가 더 무거워져서 epoch 3으로도 과적합이 되지 않았나 생각합니다.
5. BONUS
추가적으로 pretrained model의 parameter들을 freeze하고 실험을 해봤습니다.
위 모델에서
아랫 부분만 추가해서 backbone의 parameter는 freeze하여 실험을 진행하였습니다.
wandb로 확인해본 결과
f1_score 40 이하 생략The text was updated successfully, but these errors were encountered: