Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pretrained Model을 불러와서 나만의 모델로 구성해보자!! #27

Open
gistarrr opened this issue Oct 4, 2021 · 3 comments
Open

Comments

@gistarrr
Copy link
Contributor

gistarrr commented Oct 4, 2021

1. 기존 Pretrained Model 문제점 (?) 아쉬운 점(?)

기존의 Pretrained Model을 불러와서 쓰면 항상 고정된 classifier에 label의 개수만 바꿔주는 형식으로 동작.
ex) config.num_labels=30
Transformer에서 AutoModelForSequenceClassification을 살펴보시면 이미 자체적으로 classifier가 구성되어 있습니다.
이를 좀 더 유연하게 label의 개수뿐만 아니라 내부 layer들을 설정해준다면 더 깊은 model을 만들 수 있지 않을까 싶어서 시도했습니다.

2. Model을 어떻게 구성할까?

기존의 nn.Module을 상속받는 형태로 구성했습니다.

class MyModel(nn.Module):
    def __init__(self, model_name, config):
        super().__init__()
        print("this is custom model!!!")
        self.backbone = AutoModel.from_pretrained(model_name, config=config)
        self.num_labels = config.num_labels
        self.config = config

        self.dense = nn.Linear(config.hidden_size*4, config.hidden_size*4)

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size*4, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ): # Roberta model의 input 형식과 동일합니다.
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        ) # Roberta model의 input 형식과 동일하기 때문에 기존 backbone 모델에 넣으면 원하는 output을 가져올 수 있습니다.

        # output[0] = last hidden state / output[1] = pooler_output / output[2] = all hidden states (if 'output_hidden_states': True) == past_key_value, tuple 형태로 반환    
        all_hidden_states = torch.stack(outputs[2]) # (layers , batch, seq_len, hidden_size)
        concat_pooling_layer = torch.cat(
            (all_hidden_states[-1], all_hidden_states[-2], all_hidden_states[-3], all_hidden_states[-4]), -1)  # (batch, seq_len, hidden_size * 4)
        
        concat_pooling_layer = concat_pooling_layer[:, 0, :] # 모든 layer의 CLS embedding vector / (batch, hidden_size * 4)

        logits = self.dropout(concat_pooling_layer)  # (batch, hidden_size * 4)
        logits = self.dense(logits) # (batch, hidden_size * 4) -> (batch, hidden_size*4)
        logits = torch.tanh(logits)  # (batch, hidden_size*4)
        logits = self.dropout(logits)   # (batch, hidden_size*4)
        logits = self.out_proj(logits)  # (batch, hidden_size*4) -> (batch, num_labels)

        loss = None

        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[3:]
            return ((loss,) + output) if loss is not None else output
        # transformers의 Trainer를 사용하기 위해서는 output을 원하는 모델 형태에 맞게 tuple형태로 반환해야 합니다.
        # AutoModelForSequenceClassification의 경우 loss, logits, outputs.hidden_states, outputs.attentions 입니다.
        # 자세한 내용은 Trainer docs 참고
        return (
            loss,
            logits,
            outputs.hidden_states,
            outputs.attentions,
        )

실제 Trainer docs를 살펴보면

The Trainer class is optimized for 🤗 Transformers models and can have surprising behaviors when you use it on other models. When using it on your own model, make sure:

  1. your model always return tuples or subclasses of ModelOutput.

  2. your model can compute the loss if a labels argument is provided and that loss is returned as the first element of the tuple (if your model returns tuples)

  3. your model can accept multiple label arguments (use the label_names in your TrainingArguments to indicate their name to the Trainer) but none of them should be named "label".

로 표시되어 있습니다. 그리고 이곳에서 SequenceClassifierOutput에 대한 설명을 살펴보면 위와 같은 output parameter로 loss, logits, outputs.hidden_states, outputs.attentions을 가지는 것을 확인하실 수 있습니다.

3. 사소하지만 치명적인 에러 해결

1. config.update({'output_hidden_states': True})를 이용해 output_hidden_states를 출력하도록 설정

model_config = AutoConfig.from_pretrained(args.PLM)
model_config.num_labels = 30
model_config.update({'output_hidden_states': True})

model = MyModel(args.PLM, config=model_config)

2. model.save_pretrained(model_save_pth)

기존의 model은 PretrainedModel을 기본적으로 상속 받기 때문에 save_pretrained method가 존재하지만 저희가 만든 모델은 없습니다. 그렇기 때문에 따로 torch.save()를 이용해줍니다.

torch.save(model.state_dict(), os.path.join(model_save_pth, pytorch_model.pt'))

3. CUDA OOM

if not return_dict:
            output = (logits,) + outputs[3:]
            return ((loss,) + output) if loss is not None else output

이 코드에서 outputs[2:]outputs[3:]으로 바꾸니 해결 ,,, (이유를 모르겠습니다)
바꾼 이유는 기존 config.update({'output_hidden_states': True})를 해줌으로써 output[2](all hidden states)가 추가적으로 생성되었다고 생각해서 바꿔주었습니다. 이 부분에 대해서는 얘기가 필요할 것 같습니다.

3. Inference 과정

기존 inference.py에 존재하는

model = AutoModelForSequenceClassification.from_pretrained(model_dir)
model.parameters
model.to(device)

이 코드를

model_config = AutoConfig.from_pretrained(args.PLM)
model_config.num_labels = 30
model_config.update({'output_hidden_states': True})
model = MyModel(args.PLM, config=model_config)
model.load_state_dict(torch.load(
    os.path.join(model_dir, 'pytorch_model.pt')))

으로 바꿔줌으로써 해결

4. 결과 확인

1
2
wandb를 살펴보면 결과 f1_score나 loss가 좀 더 좋아진 것을 확인할 수 있었습니다.
실제 제출 결과는 65.487이고 k_fold 5한 결과는 70.535가 나왔습니다.
생각해봤을 때는 layer가 더 무거워져서 epoch 3으로도 과적합이 되지 않았나 생각합니다.

5. BONUS

추가적으로 pretrained model의 parameter들을 freeze하고 실험을 해봤습니다.
위 모델에서

def __init__(self, model_name, config):
            super().__init__()
            print("this is custom model!!!")
            self.backbone = AutoModel.from_pretrained(model_name, config=config)
            self.num_labels = config.num_labels
            self.config = config

            self.dense = nn.Linear(config.hidden_size*4, config.hidden_size*4)

            classifier_dropout = (
                config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
            )
            self.dropout = nn.Dropout(classifier_dropout)
            self.out_proj = nn.Linear(config.hidden_size*4, config.num_labels)
############################################
            for param in tqdm(self.backbone.parameters()):
                param.requires_grad = False
############################################

아랫 부분만 추가해서 backbone의 parameter는 freeze하여 실험을 진행하였습니다.

wandb로 확인해본 결과 f1_score 40 이하 생략

@j961224
Copy link
Contributor

j961224 commented Oct 4, 2021

pool output 대신, all hidden states를 사용해주셨군요! 좋은 실험 감사합니다!! 더불어 kaggle 링크 공유해 주신 것도 살펴봐야겠군요!

@presto105
Copy link
Contributor

제가 한거랑 비슷한부분들이 많아 참고할 부분이 많네요!! 많이 배워갑니다

@gistarrr
Copy link
Contributor Author

gistarrr commented Oct 4, 2021

현재 dev에서 cuda OOM 에러 발생 확인

dev와 branch를 merge하고 실행한 결과 OOM 에러가 뜨는 것을 확인하였습니다.
기존 코드에서는 잘 돌아가는데 리펙토링한 dev 코드에서는 돌아가지 않네요 ,,, ㅠ

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants