diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index b2c4fbd..eb03cd7 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1011,11 +1011,12 @@ class DALLE2(nn.Module): super().__init__() assert isinstance(prior, DiffusionPrior) assert isinstance(decoder, Decoder) - self.prior = prior.eval() - self.decoder = decoder.eval() + self.prior = prior + self.decoder = decoder self.prior_num_samples = prior_num_samples @torch.no_grad() + @eval_decorator def forward( self, text, diff --git a/setup.py b/setup.py index ae2d847..0f15d48 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.11', + version = '0.0.12', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',