
Need a strategy to gather and compute validation loss on the whole validation dataset #20557

Open
JohnHerry opened this issue Jan 22, 2025 · 0 comments
Labels: feature (Is an improvement or enhancement), needs triage (Waiting to be triaged by maintainers)

JohnHerry commented Jan 22, 2025

Description & Motivation

Since Lightning 2.0, the `validation_epoch_end(self, outputs)` hook has been removed. Instead, we are asked to use the `on_validation_epoch_end(self)` hook, but this hook takes no parameters, so we have to save the validation outputs ourselves in an attribute such as `self.validation_outputs` inside the LightningModule.
But in DDP training, `self.validation_outputs` only holds the validation mini-batch results of the current rank. I want to get ALL validation results somewhere, but there is no such convenience.
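To illustrate why the list only holds a shard of the data: in DDP, Lightning by default (`Trainer(use_distributed_sampler=True)`) wraps the validation dataloader in a `DistributedSampler`, so each rank iterates over its own slice of the dataset. The sketch below builds such samplers directly with PyTorch's `torch.utils.data.DistributedSampler`, assuming no shuffling (as is usual for validation); `shard_indices` is a hypothetical helper, and the exact sampler Lightning injects may differ between versions.

```python
# Sketch: how DistributedSampler splits a validation set across DDP ranks.
# Assumption: Lightning injects an equivalent sampler for the val dataloader
# when Trainer(use_distributed_sampler=True), the default.
from torch.utils.data import DistributedSampler


def shard_indices(dataset_len, world_size):
    """Return the sample indices each rank would iterate over (no shuffling)."""
    dataset = list(range(dataset_len))  # stand-in dataset; only len() matters here
    shards = []
    for rank in range(world_size):
        # Passing num_replicas and rank explicitly means no process group is needed
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        shards.append(list(sampler))
    return shards
```

With 5 samples on 2 ranks, `shard_indices(5, 2)` gives `[0, 2, 4]` for rank 0 and `[1, 3, 0]` for rank 1: each rank sees only its own slice, and sample 0 appears twice because the sampler pads the dataset so that every rank gets the same number of samples.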

Pitch

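Example code:

```python
import torch
import lightning.pytorch as pl


class My(pl.LightningModule):
    def __init__(self, *args, **kwargs):  # original constructor arguments elided
        super().__init__()
        ...  # model construction (backbone, decoder, ...) elided
        self.all_validation_step_outputs = []

    def validation_step(self, batch, *args, **kwargs):
        with torch.no_grad():  # Lightning already disables gradients here; harmless
            gt_data, infer_inputs = batch
            gen_feature = self.backbone(infer_inputs)
            gen_data = self.decoder(gen_feature)
            batch_output = {
                "gt_data": gt_data,
                "gen_feature": gen_feature,
                "gen_data": gen_data,
            }
            self.all_validation_step_outputs.append(batch_output)

    def on_validation_epoch_end(self):
        # Here I want the results for the WHOLE validation dataset,
        # but under DDP this list only holds the current rank's batches!
        outputs = self.all_validation_step_outputs  # a list attribute, not a method
        gt_data = [x["gt_data"] for x in outputs]
        gen_data = [x["gen_data"] for x in outputs]
        if self.global_rank == 0:
            # compute_val_loss and plot_spectorgram_to_img are my own helpers
            val_loss = compute_val_loss(gt_data, gen_data).mean()
            # rank_zero_only=True because only rank 0 reaches this call
            self.log("val_loss", val_loss, rank_zero_only=True)
            for idx, item in enumerate(outputs):
                # assumes TensorBoardLogger, whose .experiment is a SummaryWriter
                self.logger.experiment.add_image(
                    f"val/feature_{idx}",
                    plot_spectorgram_to_img(item["gen_feature"]),
                    global_step=self.global_step,
                )
        self.all_validation_step_outputs.clear()  # free memory before the next epoch
```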
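One possible workaround, sketched with plain `torch.distributed` collectives rather than any dedicated Lightning API: gather every rank's list of step outputs into one list, or, when the loss is a per-sample average, reduce only a loss sum and a sample count. `gather_all_outputs` and `global_mean_loss` are hypothetical helper names; Lightning's own `self.all_gather` is another option when every rank holds tensors of the same shape.

```python
# Sketch: gather per-rank validation outputs and compute a dataset-wide mean loss.
# Assumptions: plain torch.distributed collectives; helper names are hypothetical.
import torch
import torch.distributed as dist


def _dist_ready():
    return dist.is_available() and dist.is_initialized()


def gather_all_outputs(local_outputs):
    """Collect every rank's list of dicts into one flat list, on every rank.

    This is a collective call: all ranks must reach it, so it must NOT sit
    inside an `if global_rank == 0:` branch.
    """
    # Move tensors to CPU so the pickled objects can be unpickled on any rank
    local_cpu = [
        {k: v.detach().cpu() if torch.is_tensor(v) else v for k, v in out.items()}
        for out in local_outputs
    ]
    if not _dist_ready():
        return local_cpu  # single-process run: nothing to gather
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_cpu)
    # Flatten [rank0_list, rank1_list, ...] into one list, in rank order
    return [out for rank_outputs in gathered for out in rank_outputs]


def global_mean_loss(local_loss_sum, local_count):
    """Mean loss over all samples on all ranks, weighting ranks by sample count.

    Assumes the loss is a per-sample average, so sums and counts can be
    reduced instead of gathering every raw output.
    """
    totals = torch.tensor([float(local_loss_sum), float(local_count)], dtype=torch.float64)
    if _dist_ready():
        # With the NCCL backend, `totals` would have to live on the current CUDA device
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    return (totals[0] / totals[1]).item()
```

In `on_validation_epoch_end`, every rank would call `gather_all_outputs(...)` on its list of step outputs, and only afterwards would rank 0 compute the loss and log images. Gathering raw tensors copies every prediction to every rank, which can be memory-heavy for a large validation set (`torch.distributed.gather_object` with `dst=0` collects on rank 0 only); `global_mean_loss` communicates just two numbers per rank. In both cases, samples padded in by `DistributedSampler` are still counted twice.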
           

Alternatives

No response

Additional context

No response

cc @lantiga @Borda
