From 862e5ba50e48fcb0f4e1370c46b763a4760a30cf Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 12 Apr 2022 17:31:01 -0700 Subject: [PATCH] more sketches to base dalle2 class --- dalle2_pytorch/dalle2_pytorch.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 20d4007..c47ecc7 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -470,13 +470,11 @@ class DALLE2(nn.Module): def __init__( self, *, - clip, prior, decoder, tokenizer = None ): super().__init__() - assert isinstance(clip), CLIP assert isinstance(prior), DiffusionPrior assert isinstance(decoder), Decoder self.tokenizer = tokenizer @@ -487,4 +485,10 @@ class DALLE2(nn.Module): *, text ): - return image + if isinstance(text, str): + assert exists(self.tokenizer), 'tokenizer must be passed in if you were to pass in the text as a string' + text = self.tokenizer.encode(text) + + image_embed = prior.sample(text, num_samples_per_batch = 2) + images = decoder.sample(image_embed) + return images