Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

koBigBird Encoder와 Bart Decoder를 연결해보자!! #35

Open
j961224 opened this issue Dec 14, 2021 · 0 comments · Fixed by #36
Open

koBigBird Encoder와 Bart Decoder를 연결해보자!! #35

j961224 opened this issue Dec 14, 2021 · 0 comments · Fixed by #36
Assignees
Labels
enhancement New feature or request

Comments

@j961224
Copy link
Contributor

j961224 commented Dec 14, 2021

BigBird encoder와 Bart decoder 연결하기!

1. 개요

doc_type_ids를 추가하여 수정된 BigBirdModel class와 Bartdecoder class를 EncoderDecoderModel class로 encoder-decoder 모델을 만듭니다.

2. Model 흐름

EncoderDecoderModel class/
├── BigBirdModelWithDoctype class
│   ├── BigBirdEmbeddingsWithDoctype class
│   ├── BigBirdEncoder class
├── BartDecoderWithDoctype class
│   ├── BartLearnedPositionalEmbedding class
│   ├── BartDecoderLayer class

3. modeling_kobigbird_bart.py 구조

BigBirdConfigWithDoctype class

기존 BigBirdConfig를 상속받아 doc_type_size를 추가한 class입니다.

class BigBirdConfigWithDoctype(BigBirdConfig):
    def __init__(self, doc_type_size: int=None, **kwargs):
        super().__init__(**kwargs)
        self.doc_type_size = doc_type_size

BartConfigWithDoctype class

기존 BartConfig를 상속받아 doc_type_size를 추가한 class입니다.

class BartConfigWithDoctype(BartConfig):
    def __init__(self, doc_type_size: int=None, **kwargs):
        super().__init__(**kwargs)
        self.doc_type_size = doc_type_size

BigBirdEmbeddingsWithDoctype class

기존 BigBirdEmbeddings class를 상속받아 기존에 word_embedding, token_type_embedding, position embedding을 이용하여 embedding을 구성한 것에 doc_type_embeddings을 추가한 class입니다.

# init 함수에 추가한 코드!!!
if isinstance(config.doc_type_size, int):
    self.doc_type_embeddings = nn.Embedding(config.doc_type_size, config.hidden_size)
# doc_type_embedding을 추가로 계산하는 코드 추가!!
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings

if doc_type_ids is not None :
    doc_type_embeddings = self.doc_type_embeddings(doc_type_ids)
    embeddings += doc_type_embeddings

BigBirdModelWithDoctype class

기존 BigBirdPreTrainedModel class를 상속받아 BigBirdEmbeddings class 말고, BigBirdEmbeddingsWithDoctype class를 이용하기 위해 사용된 class입니다.
그리고 추가로 doc_type_ids를 추가해 사용합니다.

# BigBirdEmbeddingsWithDoctype class를 부른 코드 추가!!
self.embeddings = BigBirdEmbeddingsWithDoctype(config)
self.encoder = BigBirdEncoder(config)

BartDecoderWithDoctype class

기존 BartDecoder class에 doc_type_embedding을 추가하기 위한 class입니다.

# 기존 init함수에 doc_type_tokens라는 embedding 추가!
if embed_tokens is not None:
    self.embed_tokens = embed_tokens
else:
    self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
        
if doc_type_tokens is not None:
    self.doc_type_tokens = doc_type_tokens
else:
    if isinstance(config.doc_type_size, int) :
        self.doc_type_tokens = nn.Embedding(config.doc_type_size, config.d_model, self.padding_idx)
# doc_type_embeddings을 추가로 더하는 과정!!
if doc_type_ids is not None:
    doc_type_embeddings = self.doc_type_tokens(doc_type_ids)

hidden_states = inputs_embeds + positions + doc_type_embeddings if doc_type_ids is not None else inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)

EncoderDecoderModel class

기존 EncoderDecoderModel class를 응용한 class입니다. 저희 BigBirdModelWithDoctype encoder class와 BartDecoderWithDoctype class를 받기 위해서 수정했습니다.

# 아래의 사진과 같이 만들기 위해서 수정!!
if decoder is None:
    decoder = BartDecoderWithDoctype(config.decoder)

self.encoder = encoder
self.decoder = decoder

# Bart decoder를 쓰므로, Bart에 맞는 linear 정의!!
self.lm_head = nn.Linear(config.decoder.d_model, config.encoder.vocab_size, bias=False)
self.register_buffer("final_logits_bias", torch.zeros((1, config.encoder.vocab_size)))
## Bart decoder에 수행하는 loss 계산법으로 수행!!
lm_logits = self.lm_head(decoder_outputs[0]) + self.final_logits_bias

loss = None
if labels is not None:
    loss_fct = CrossEntropyLoss()
    # decoder의 d_model(vocab_size)가 encoder vocab_size로 대체(이유는 encoder word_embedding이 decoder word_embedding으로 대체)
    loss = loss_fct(lm_logits.view(-1, self.config.encoder.vocab_size), labels.view(-1))

4. processor.py - doc_type embedding 관련 부분

input_ids의 크기와 맞게, doc_type_ids을 넣고 나머지 부분은 pad_id를 추가적으로 붙입니다.

def preprocess_function(examples, tokenizer, data_args):
    ... # 생략
    if data_args.use_doc_type_ids:
        model_inputs = doc_type_marking(model_inputs, doc_type_id, pad_id)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    
# embedding에 필요한 doc_type_ids를 생성, input_ids와 길이 동일
def doc_type_marking(tokenizer_tmp_input, doc_type_id, pad_id):
    doc_type_input_ids=[]
    
    for i, input_ids_per_one_input in enumerate(tokenizer_tmp_input['input_ids']):
        if pad_id not in input_ids_per_one_input:
            marking = [doc_type_id[i]]*len(input_ids_per_one_input)
            doc_type_input_ids.append(marking)
        else:
            marking = [doc_type_id[i]]*len(input_ids_per_one_input[:list(input_ids_per_one_input).index(0)])
            tmp = [0] * len(input_ids_per_one_input[list(input_ids_per_one_input).index(0):])
            doc_type_input_ids.append(marking+tmp)
    tokenizer_tmp_input['doc_type_ids'] = doc_type_input_ids

    return tokenizer_tmp_input

5. data_collator.py - doc_type embedding 관련 부분

그냥 data_collator를 사용할 시, doc_type_ids로 인해 batch마다 padding이 안 됩니다! 그래서 아래와 같은 코드를 추가하여 pad가 되도록 조정했습니다.

doc_type_ids = [feature["doc_type_ids"] for feature in features] if "doc_type_ids" in features[0].keys() else None
if doc_type_ids is not None:
    max_label_length = max(len(l) for l in doc_type_ids)
    padding_side = self.tokenizer.padding_side
    for feature in features:
        remainder = [0] * (max_label_length - len(feature["doc_type_ids"]))
        if isinstance(feature["doc_type_ids"], list):
            feature["doc_type_ids"] = (
                feature["doc_type_ids"] + remainder if padding_side == "right" else remainder + feature["doc_type_ids"]
            )
        elif padding_side == "right":
            feature["doc_type_ids"] = np.concatenate([feature["doc_type_ids"], remainder]).astype(np.int64)
        else:
            feature["doc_type_ids"] = np.concatenate([remainder, feature["doc_type_ids"]]).astype(np.int64)

6. train.py - BigBirdBart model 적용 부분

BigBirdBart Model을 쓰면 if문에 들어가 아래와 같이 각각 config을 불러와 줍니다.

if model_args.use_kobigbird_bart:
    config_e = BigBirdConfigWithDoctype.from_pretrained("monologg/kobigbird-bert-base")
    config_d = BartConfigWithDoctype.from_pretrained("gogamza/kobart-base-v1")

    # encoder layer 6개 설정
    config_e.encoder_layers = 6

    # decoder config 설정
    config_d.vocab_size = config_e.vocab_size
    config_d.pad_token_id = config_e.pad_token_id
    config_d.max_position_embeddings = data_args.max_target_length

    # doc_type_embedding 사용할 경우
    if data_args.use_doc_type_ids :
        config_e.doc_type_size = 3
        config_d.doc_type_size = 3
else :
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir
    )

BigBirdBart Model을 쓰면, 아래와 같이 config를 이용해 weight가 있는 모델을 각각 불러옵니다.

 def model_init():
    if model_args.use_kobigbird_bart:
        encoder = BigBirdModelWithDoctype.from_pretrained("monologg/kobigbird-bert-base",config=config_e)
        decoder = BartDecoderWithDoctype.from_pretrained("gogamza/kobart-base-v1", config=config_d)

        for i in range(1,6):
            encoder.encoder.layer[i] = encoder.encoder.layer[2*i]

        encoder.encoder.layer = encoder.encoder.layer[:config_e.encoder_layers]
        # decoder shared imbedding 설정
        decoder.embed_tokens = encoder.embeddings.word_embeddings

        total_model = EncoderDecoderModel(encoder = encoder, decoder = decoder)
        return total_model
    else :
        return AutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config
        )

7. predict.py - BigBirdBart model 적용 부분

train.py를 통해 학습한 BigBirdBart model을 불러옵니다.

if model_args.use_kobigbird_bart :
    model = EncoderDecoderModel.from_pretrained(model_args.model_name_or_path)
    model.encoder.encoder.layer = model.encoder.encoder.layer[:model.config.encoder.encoder_layers]
else :
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
    )

8. Argument 추가 부분

ModelArguments

kobigbird encoder와 bart decoder를 이용할 껀지에 대한 여부를 결정하는 args입니다.

use_kobigbird_bart: bool = field(
        default=False,
        metadata={"help": "use kobigbird encoder and bart decoder"},
    )

DataTrainingArguments

tokenizer에 input이 들어갈 시, input_ids, attention_mask, token_type_ids만 나옵니다. 거기에 doc_type_ids를 쓰기 위해서, 따로 만들어 줄껀지에 대한 여부를 알려주는 args입니다.

use_doc_type_ids: bool = field(
        default=False,
        metadata={  
            "help": "Calculate the evaluation step relative to the size of the data set."
        },
    )

9. Shell script example

python train.py \
--model_name_or_path monologg/kobigbird-bert-base \
--use_kobigbird_bart True \
--use_doc_type_ids True \ # doc_type embedding을 사용할 겨우우
--do_train \
--output_dir model/kobigbirdbart \
--overwrite_output_dir \
--dataset_name paper,news,magazine \
--num_train_epochs 3 \
--learning_rate 3e-05 \
--max_source_length 4096 \
--max_target_length 128 \
--metric_for_best_model rougeLsum \
--relative_eval_steps 10 \
--es_patience 3 \
--load_best_model_at_end True \
--relative_sample_ratio 0.5 \
--project_name kobigbirdbart \
--wandb_unique_tag kobigbirdbart_base_epoch_3 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--is_part True \
python predict.py \
--model_name_or_path 저장된 모델 경로
--tokenizer_name monologg/kobigbird-bert-base \
--num_beams 3 \
--use_kobigbird_bart True

Written by 기성, 유석

@j961224 j961224 added the report Sharing information or results of analysis label Dec 14, 2021
@j961224 j961224 self-assigned this Dec 14, 2021
@gistarrr gistarrr added enhancement New feature or request and removed report Sharing information or results of analysis labels Dec 14, 2021
@gistarrr gistarrr linked a pull request Dec 14, 2021 that will close this issue
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants