also consider accepting tokenizer, so dalle2 forward pass can just be invoked as DALLE2(<prompt string>)

This commit is contained in:
Phil Wang
2022-04-12 10:29:29 -07:00
parent de75a8af76
commit 771fe0d0d2

View File

@@ -180,12 +180,14 @@ class DALLE2(nn.Module):
*,
clip,
prior,
decoder
decoder,
tokenizer = None
):
super().__init__()
assert isinstance(clip), CLIP
assert isinstance(prior), DiffusionPrior
assert isinstance(decoder), Decoder
self.tokenizer = tokenizer
@torch.no_grad()
def forward(