mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
more sketches to base dalle2 class
This commit is contained in:
@@ -470,13 +470,11 @@ class DALLE2(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
clip,
|
|
||||||
prior,
|
prior,
|
||||||
decoder,
|
decoder,
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(clip), CLIP
|
|
||||||
assert isinstance(prior), DiffusionPrior
|
assert isinstance(prior), DiffusionPrior
|
||||||
assert isinstance(decoder), Decoder
|
assert isinstance(decoder), Decoder
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
@@ -487,4 +485,10 @@ class DALLE2(nn.Module):
|
|||||||
*,
|
*,
|
||||||
text
|
text
|
||||||
):
|
):
|
||||||
return image
|
if isinstance(text, str):
|
||||||
|
assert exists(self.tokenizer), 'tokenizer must be passed in if you were to pass in the text as a string'
|
||||||
|
text = self.tokenizer.encode(text)
|
||||||
|
|
||||||
|
image_embed = prior.sample(text, num_samples_per_batch = 2)
|
||||||
|
images = decoder.sample(image_embed)
|
||||||
|
return images
|
||||||
|
|||||||
Reference in New Issue
Block a user