more sketches to base dalle2 class

This commit is contained in:
Phil Wang
2022-04-12 17:31:01 -07:00
parent 25d980ebbf
commit 862e5ba50e

View File

@@ -470,13 +470,11 @@ class DALLE2(nn.Module):
def __init__( def __init__(
self, self,
*, *,
clip,
prior, prior,
decoder, decoder,
tokenizer = None tokenizer = None
): ):
super().__init__() super().__init__()
assert isinstance(clip), CLIP
assert isinstance(prior), DiffusionPrior assert isinstance(prior), DiffusionPrior
assert isinstance(decoder), Decoder assert isinstance(decoder), Decoder
self.tokenizer = tokenizer self.tokenizer = tokenizer
@@ -487,4 +485,10 @@ class DALLE2(nn.Module):
*, *,
text text
): ):
return image if isinstance(text, str):
assert exists(self.tokenizer), 'tokenizer must be passed in if you were to pass in the text as a string'
text = self.tokenizer.encode(text)
image_embed = prior.sample(text, num_samples_per_batch = 2)
images = decoder.sample(image_embed)
return images