diff --git a/main.py b/main.py index 66b74d1..c916f51 100644 --- a/main.py +++ b/main.py @@ -469,9 +469,8 @@ class ImageLogger(Callback): self.log_img(pl_module, batch, batch_idx, split="train") @rank_zero_only - # def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): def on_validation_batch_end( - self, trainer, pl_module, outputs, batch, batch_idx, **kwargs + self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs ): if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split="val")