-
Notifications
You must be signed in to change notification settings - Fork 0
[T1219_최현진]attention.py
chjin0725 edited this page Nov 29, 2021
·
2 revisions
- attention.py는 ASTER에서 인코더에 있는 bidirectional LSTM만 제거한 구조이다.
- 빨간 박스 부분에서 attention score를 계산하고 있다.
- src_features는 encoder의 output들이고 prev_hidden_proj는 바로 전 time step의 hidden state이다.
- 지금까지 봐왔던 attention의 계산방식에 따른 다면 src_features와 prev_hidden_proj를 내적하는 방식으로 계산했을 것이다.
- 즉,
attention_logit = torch.bmm(prev_hidden_proj, src_features.transpose(-1,-2))
와 같은 방식으로 계산 하였을 것이다. - 근데 여기서는 이 둘을 더하고 tanh를 씌우고 linear layer(코드상에서는 self.score)에 통과시키는 식으로 되어 있다.
- 논문을 확인해 보았으나 딱히 왜 이렇게 했는지에 대한 언급은 없었다.
- 다만 이렇게 하더라도 논리적으로 틀린 점은 없는 것 같다. encoder output과 이전 time step의 hidden state를 고려하여 attention weight를 계산 하겠다는 점은 똑같다.