diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 9ab9566..2c17908 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1432,7 +1432,7 @@ class DiffusionPrior(nn.Module): **kwargs ): assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied' - assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied' + assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied' assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization' if exists(image):