use the eval decorator

This commit is contained in:
Phil Wang
2022-04-14 10:13:43 -07:00
parent 68e9883f59
commit 23c401a5d5
2 changed files with 4 additions and 3 deletions

View File

@@ -1011,11 +1011,12 @@ class DALLE2(nn.Module):
super().__init__()
assert isinstance(prior, DiffusionPrior)
assert isinstance(decoder, Decoder)
self.prior = prior.eval()
self.decoder = decoder.eval()
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
@torch.no_grad()
@eval_decorator
def forward(
self,
text,