diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index a1c5b39..7fe0a02 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -673,8 +673,14 @@ class DecoderTrainer(nn.Module):
     def sample(self, *args, **kwargs):
         distributed = self.accelerator.num_processes > 1
         base_decoder = self.accelerator.unwrap_model(self.decoder)
+
+        was_training = base_decoder.training
+        base_decoder.eval()
+
         if kwargs.pop('use_non_ema', False) or not self.use_ema:
-            return base_decoder.sample(*args, **kwargs, distributed = distributed)
+            out = base_decoder.sample(*args, **kwargs, distributed = distributed)
+            base_decoder.train(was_training)
+            return out

         trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
         base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
@@ -687,6 +693,7 @@ class DecoderTrainer(nn.Module):
         for ema in self.ema_unets:
             ema.restore_ema_model_device()

+        base_decoder.train(was_training)
         return output

     @torch.no_grad()
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 50690f9..caf9513 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.23.6'
+__version__ = '0.23.7'
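
For reference, a minimal sketch of the save-eval-restore pattern this patch applies: record the module's training flag, switch to eval mode for sampling (so dropout and batch-norm behave deterministically), then restore the caller's mode afterwards. The model and wrapper below (ToyModel, run_inference, and the try/finally guard) are hypothetical illustration, not code from the repository; the patch itself restores the mode without try/finally.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Hypothetical model whose output depends on train/eval mode (via Dropout).
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))

    def forward(self, x):
        return self.net(x)

@torch.no_grad()
def run_inference(model, x):
    # Record the current mode, mirroring `was_training = base_decoder.training`.
    was_training = model.training
    model.eval()  # disable dropout / use running stats during sampling
    try:
        return model(x)
    finally:
        # Restore whatever mode the caller had, as `base_decoder.train(was_training)` does.
        model.train(was_training)

model = ToyModel()
model.train()
out = run_inference(model, torch.randn(2, 4))
assert model.training  # training mode survives the sampling call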