diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c47915a..b9aef83 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -180,12 +180,14 @@ class DALLE2(nn.Module): *, clip, prior, - decoder + decoder, + tokenizer = None ): super().__init__() assert isinstance(clip), CLIP assert isinstance(prior), DiffusionPrior assert isinstance(decoder), Decoder + self.tokenizer = tokenizer @torch.no_grad() def forward(