diff --git a/ltr/trainers/ltr_trainer.py b/ltr/trainers/ltr_trainer.py index 25566029..8c858f7b 100644 --- a/ltr/trainers/ltr_trainer.py +++ b/ltr/trainers/ltr_trainer.py @@ -94,7 +94,7 @@ def _update_stats(self, new_stats: OrderedDict, batch_size, loader): for name, val in new_stats.items(): if name not in self.stats[loader.name].keys(): self.stats[loader.name][name] = AverageMeter() - self.stats[loader.name][name].update(val, batch_size) + self.stats[loader.name][name].update(val.item(), batch_size) def _print_stats(self, i, loader, batch_size): self.num_frames += batch_size @@ -132,4 +132,4 @@ def _write_tensorboard(self): if self.epoch == 1: self.tensorboard_writer.write_info(self.settings.module_name, self.settings.script_name, self.settings.description) - self.tensorboard_writer.write_epoch(self.stats, self.epoch) \ No newline at end of file + self.tensorboard_writer.write_epoch(self.stats, self.epoch)