How can I use median as the reduction function for loss aggregation? #19778
dempsey-ryan asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
I want to log the median loss every epoch, and Lightning is fighting me. I originally did
`self.log(..., reduce_fx=torch.median)`
which gave me this error:
lightning_fabric.utilities.exceptions.MisconfigurationException: Only `self.log(..., reduce_fx={min,max,mean,sum})` are supported. If you need a custom reduction, please log a `torchmetrics.Metric` instance instead. Found: <function median at 0x7f53bbd59fc0>
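My guess at why only those four reductions are allowed (an assumption; the error message doesn't give a reason): min, max, sum and mean can all be combined exactly from small partial results, such as a running value kept across steps or one value per process, whereas a median needs every individual value. A plain-Python check with made-up numbers:

```python
from statistics import mean, median

# Made-up per-step losses split into two chunks (think: two halves of an epoch, or two GPUs).
chunk_a = [1.0, 2.0, 100.0]
chunk_b = [3.0, 4.0]
all_values = chunk_a + chunk_b

# min, max and sum combine exactly from per-chunk partial results.
combined_min = min(min(chunk_a), min(chunk_b))
combined_max = max(max(chunk_a), max(chunk_b))
combined_sum = sum(chunk_a) + sum(chunk_b)

# mean combines exactly from a running (sum, count) pair.
combined_mean = combined_sum / (len(chunk_a) + len(chunk_b))

# median does not: the median of the per-chunk medians is not the overall median.
median_of_medians = median([median(chunk_a), median(chunk_b)])
true_median = median(all_values)

print(combined_mean, median_of_medians, true_median)  # 22.0 2.75 3.0
```

So a median reduction has to keep all the logged values around until the end of the epoch, which is exactly the kind of state a custom `torchmetrics.Metric` can hold.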
This is my `torchmetrics.Metric` implementation:
And in my LightningModule:
Now I get
TypeError: MedianMetric.update() takes 2 positional arguments but 3 were given
I'm pretty sure it wants me to pass preds and inputs as the arguments to `update`, but then I'm calculating the loss twice: once in `training_step` and again in `MedianMetric.compute`. Why is this so difficult? All I want is to use median as the reduction function!