KL loss specifications #8

Open
JoakimEdin opened this issue Jan 31, 2024 · 1 comment

JoakimEdin commented Jan 31, 2024

Hi again! I am having trouble reproducing your results, and I think my loss function is the issue. Below is the code I used, with my questions as comments. I would appreciate some guidance 🙏

import torch

kl_div_func = torch.nn.KLDivLoss(reduction='batchmean')  # is this correct, or should I choose mean or sum instead?

y_prob, attention = model(input)

evidence_token_ids = torch.softmax(evidence_token_ids, dim=-1)  # is this correct, or should I remove this line?
attention = torch.log(attention)

binary_cross_loss = torch.nn.functional.binary_cross_entropy_with_logits(y_prob, y)
kl_div = kl_div_func(attention, evidence_token_ids)  # or do you pass the arguments as (evidence_token_ids, attention) instead?

loss = binary_cross_loss + lambda_1 * kl_div
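On the argument-order question in the last code comment: per the PyTorch documentation, `KLDivLoss(input, target)` (with the default `log_target=False`) expects `input` to be log-probabilities and `target` to be probabilities, and computes the pointwise term `target * (log(target) - input)`. Passing `(log(attention), evidence)` therefore computes KL(evidence || attention). The minimal sketch below checks that against a hand-written sum; the tensors are made-up toy values, not data from this project.

```python
import torch

# Toy values (hypothetical): one example, four tokens.
attention = torch.tensor([[0.4, 0.1, 0.1, 0.4]])  # model attention, rows sum to 1
evidence = torch.tensor([[0.5, 0.0, 0.0, 0.5]])   # normalized ground-truth evidence

kl_div_func = torch.nn.KLDivLoss(reduction="sum")

# input = log-probabilities from the model, target = ground-truth probabilities
library_kl = kl_div_func(torch.log(attention), evidence)

# Manual KL(evidence || attention); positions where evidence == 0 contribute nothing
nonzero = evidence > 0
manual_kl = (
    evidence[nonzero] * (torch.log(evidence[nonzero]) - torch.log(attention[nonzero]))
).sum()

print(library_kl.item(), manual_kl.item())
```

Swapping the arguments would instead treat the evidence as log-probabilities, which is not what the loss is defined for.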

Did you calculate the KL divergence between the attention and the boolean ground-truth evidence, or did you apply a softmax to the ground truth (I have seen others do this)? Furthermore, which reduction did you use for the KL divergence? torch.nn.KLDivLoss lets you choose between mean, batchmean, and sum (see the documentation here: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html). Finally, how did you handle examples without annotated evidence?

Contributor

ranajafari commented Feb 1, 2024

Hi Joakim! The softmax function turns the zero values in evidence_token_ids into non-zero values. To make sure the ground-truth evidence is a valid probability distribution, you could use the following:
evidence_token_ids = torch.nn.functional.normalize(evidence_token_ids, p=1.0, dim=-1)
where evidence_token_ids is the target in the KLD loss.
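To illustrate the difference: applied to a binary evidence mask, softmax gives every non-evidence token a positive probability, while L1 normalization keeps those tokens at exactly zero and spreads the mass uniformly over the evidence tokens. A minimal sketch with a made-up mask (the mask values are hypothetical):

```python
import torch
import torch.nn.functional as F

# Hypothetical boolean evidence mask for one example (True = evidence token).
evidence_mask = torch.tensor([[True, False, False, True]])
# Cast to float before normalizing; norms are defined for floating-point tensors.
evidence_token_ids = evidence_mask.float()

softmaxed = torch.softmax(evidence_token_ids, dim=-1)
normalized = F.normalize(evidence_token_ids, p=1.0, dim=-1)

print(softmaxed)   # non-evidence tokens get 1 / (2e + 2), roughly 0.134 each
print(normalized)  # evidence tokens get 0.5 each, the others exactly 0
```

Note that an all-zero mask row stays all zero after `F.normalize` (it divides by a clamped `eps` rather than by zero), so such rows are not valid distributions and need separate handling.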

For the reduction in the loss function, the "mean" value was used.
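For context on what "mean" does: according to the PyTorch documentation, `reduction='mean'` divides the summed pointwise loss by the total number of elements (batch size × sequence length), whereas `'batchmean'` divides only by the batch size, which matches the mathematical definition of KL divergence averaged over examples. The two differ by a constant factor of the sequence length, which effectively rescales `lambda_1`. The sketch below compares them on random toy tensors; the shapes are assumptions, and PyTorch may emit a UserWarning for `'mean'`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len = 4, 10  # hypothetical shapes

# Model attention as log-probabilities, and a target distribution per example
log_attention = torch.log_softmax(torch.randn(batch_size, seq_len), dim=-1)
evidence = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)

kl_sum = F.kl_div(log_attention, evidence, reduction="sum")
kl_batchmean = F.kl_div(log_attention, evidence, reduction="batchmean")
kl_mean = F.kl_div(log_attention, evidence, reduction="mean")  # may warn about 'mean'

print(kl_sum / batch_size, kl_batchmean)         # the same value
print(kl_sum / (batch_size * seq_len), kl_mean)  # the same value
```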

The attention scores for examples without evidence were not passed to the loss.
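Putting the three answers together, a combined loss could look roughly like the following. This is a sketch, not the authors' code: the function name `combined_loss`, its argument names and shapes, the default `lambda_1`, the `eps` clamp before the log, and returning only the BCE term when no example in the batch has evidence are all assumptions. It keeps the question's convention of passing the model's first output to `binary_cross_entropy_with_logits`.

```python
import torch
import torch.nn.functional as F


def combined_loss(logits, y, attention, evidence_mask, lambda_1=1.0, eps=1e-12):
    """Hypothetical sketch: BCE + lambda_1 * KL(evidence || attention).

    logits:        (batch, num_labels) raw scores
    y:             (batch, num_labels) float targets in {0, 1}
    attention:     (batch, seq_len) attention probabilities (rows sum to 1)
    evidence_mask: (batch, seq_len) binary mask of annotated evidence tokens
    """
    bce = F.binary_cross_entropy_with_logits(logits, y)

    # Only examples with at least one evidence token enter the KL term
    has_evidence = evidence_mask.sum(dim=-1) > 0
    if not has_evidence.any():
        return bce

    # L1-normalize the binary mask into a target distribution (zeros stay zero)
    target = F.normalize(evidence_mask[has_evidence].float(), p=1.0, dim=-1)
    # KL divergence expects log-probabilities as input; the clamp avoids log(0)
    log_attention = torch.log(attention[has_evidence].clamp_min(eps))

    # "mean" reduction, as stated in the reply above
    kl = F.kl_div(log_attention, target, reduction="mean")
    return bce + lambda_1 * kl


# Usage with toy tensors (hypothetical shapes and values)
torch.manual_seed(0)
logits = torch.randn(3, 2)
y = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
attention = torch.softmax(torch.randn(3, 5), dim=-1)
evidence_mask = torch.tensor([[1, 1, 0, 0, 0],
                              [0, 0, 0, 0, 0],   # no annotated evidence: excluded from KL
                              [0, 0, 1, 0, 1]])
loss = combined_loss(logits, y, attention, evidence_mask, lambda_1=0.5)
print(loss)
```

Selecting rows with a boolean mask keeps examples without evidence out of the KL term entirely, so they contribute only through the classification loss.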
