mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
more sketches to base dalle2 class
This commit is contained in:
@@ -470,13 +470,11 @@ class DALLE2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
clip,
|
||||
prior,
|
||||
decoder,
|
||||
tokenizer = None
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip), CLIP
|
||||
assert isinstance(prior), DiffusionPrior
|
||||
assert isinstance(decoder), Decoder
|
||||
self.tokenizer = tokenizer
|
||||
@@ -487,4 +485,10 @@ class DALLE2(nn.Module):
|
||||
*,
|
||||
text
|
||||
):
|
||||
return image
|
||||
if isinstance(text, str):
|
||||
assert exists(self.tokenizer), 'tokenizer must be passed in if you were to pass in the text as a string'
|
||||
text = self.tokenizer.encode(text)
|
||||
|
||||
image_embed = prior.sample(text, num_samples_per_batch = 2)
|
||||
images = decoder.sample(image_embed)
|
||||
return images
|
||||
|
||||
Reference in New Issue
Block a user