use the eval decorator

This commit is contained in:
Phil Wang
2022-04-14 10:13:43 -07:00
parent 68e9883f59
commit 23c401a5d5
2 changed files with 4 additions and 3 deletions

View File

@@ -1011,11 +1011,12 @@ class DALLE2(nn.Module):
super().__init__() super().__init__()
assert isinstance(prior, DiffusionPrior) assert isinstance(prior, DiffusionPrior)
assert isinstance(decoder, Decoder) assert isinstance(decoder, Decoder)
self.prior = prior.eval() self.prior = prior
self.decoder = decoder.eval() self.decoder = decoder
self.prior_num_samples = prior_num_samples self.prior_num_samples = prior_num_samples
@torch.no_grad() @torch.no_grad()
@eval_decorator
def forward( def forward(
self, self,
text, text,

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.11', version = '0.0.12',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',