Skip to content

[T1219_최현진]attention.py

chjin0725 edited this page Nov 29, 2021 · 2 revisions

베이스라인 코드의 attention.py

  • attention.py는 ASTER에서 인코더에 있는 bidirectional LSTM만 제거한 구조이다.

디코더에서 attention을 계산하는 코드에 대한 의문.

image

  • 빨간 박스 부분에서 attention score를 계산하고 있다.
  • src_features는 encoder의 output들이고 prev_hidden_proj는 바로 전 time step의 hidden state이다.
  • 지금까지 봐왔던 attention의 계산방식에 따른 다면 src_features와 prev_hidden_proj를 내적하는 방식으로 계산했을 것이다.
  • 즉, attention_logit = torch.bmm(prev_hidden_proj, src_features.transpose(-1,-2))와 같은 방식으로 계산 하였을 것이다.
  • 근데 여기서는 이 둘을 더하고 tanh를 씌우고 linear layer(코드상에서는 self.score)에 통과시키는 식으로 되어 있다.
  • 논문을 확인해 보았으나 딱히 왜 이렇게 했는지에 대한 언급은 없었다.
  • 다만 이렇게 하더라도 논리적으로 틀린 점은 없는 것 같다. encoder output과 이전 time step의 hidden state를 고려하여 attention weight를 계산 하겠다는 점은 똑같다.