From 771fe0d0d27701ad9f4c1d56ae7749ef21272ed4 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Apr 2022 10:29:29 -0700 Subject: [PATCH] also consider accepting tokenizer, so dalle2 forward pass can just be invoked as DALLE2() --- dalle2_pytorch/dalle2_pytorch.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c47915a..b9aef83 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -180,12 +180,14 @@ class DALLE2(nn.Module): *, clip, prior, - decoder + decoder, + tokenizer = None ): super().__init__() assert isinstance(clip), CLIP assert isinstance(prior), DiffusionPrior assert isinstance(decoder), Decoder + self.tokenizer = tokenizer @torch.no_grad() def forward(