Alternate "prediction" loops #20318
-
I am hoping to implement an additional loop through the Trainer in order to leverage Lightning's automagic handling of dataloaders and GPUs. Specifically, I want to run batches through the attribution methods from Captum. My first attempt was to hijack `predict_step`:

```python
def predict_step(self, batch, batch_idx):
    if self.calculate_attributes:
        return self.attribution_step(batch, batch_idx)
    else:
        data, target = batch
        return self.model(data)

def attribution_step(self, batch, batch_idx):
    data, target = batch
    batch_size = data.shape[0]
    baselines = torch.zeros_like(data)
    attribution = self.explainer.attribute(
        data, baselines, target=target, internal_batch_size=batch_size
    )
    return attribution, target
```

But this has run into issues because gradients are required and, I believe, the prediction loop disables them. I tried to get around it with … Is there a proper way to implement this? Any suggestions would be appreciated.
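A minimal plain-PyTorch sketch (no Lightning; the tensors are illustrative) of why re-enabling gradients inside the prediction loop can fail: `torch.enable_grad()` can locally override `torch.no_grad()`, but tensors created under `torch.inference_mode()` are inference tensors that autograd cannot track at all.

```python
import torch

x = torch.ones(3)

# Under no_grad, gradient tracking can be re-enabled locally:
with torch.no_grad():
    with torch.enable_grad():
        y = (x.clone().requires_grad_(True) ** 2).sum()
assert y.requires_grad  # autograd recorded the squaring op

# Under inference_mode, new tensors are "inference tensors";
# autograd cannot be re-enabled for them afterwards:
with torch.inference_mode():
    z = x * 2
assert z.is_inference()
```

So if the predict loop wraps steps in `torch.inference_mode()` rather than `torch.no_grad()`, wrapping the attribution code in `torch.enable_grad()` is not enough; this is consistent with the `Trainer(inference_mode=False)` suggestion later in the thread.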
Replies: 2 comments
-
I have also tested calling …
-
Setting `Trainer(inference_mode=False)` was the answer. #15925 #15765
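A sketch of how the fix might be wired up. `MyLightningModule` and `predict_loader` are placeholder names, and the import path varies by version (`pytorch_lightning` vs. the newer `lightning.pytorch`):

```python
# Sketch only: MyLightningModule and predict_loader are placeholders.
import pytorch_lightning as pl

model = MyLightningModule(calculate_attributes=True)

trainer = pl.Trainer(
    accelerator="auto",
    # Run predict under torch.no_grad() instead of torch.inference_mode(),
    # so gradient computation can be re-enabled inside predict_step
    # (e.g. by Captum's internals or an explicit torch.enable_grad()).
    inference_mode=False,
)
results = trainer.predict(model, dataloaders=predict_loader)
```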