Decoder.__init__: rework how the image size, channels and CLIP are resolved.

  * `clip` is rejected (assert) when `unconditional` is set.
  * The image size now comes from `image_size` / `image_sizes` first (exactly
    one of the two may be given), falling back to `clip.image_size`.
  * `self.channels` is always taken from the `channels` argument.
  * The resolved `image_size` local replaces `self.clip_image_size` when
    building the default `image_sizes` tuple.
  * Package version bumped from 0.4.10 to 0.4.11.

Review note: the fallback branch raises ValueError. `Error` is not a Python
builtin, so unless the target module defines it elsewhere, `raise Error(...)`
fails with NameError instead of reporting the intended message.
Indentation below is reconstructed from block structure -- confirm against
the target file before applying.

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index c94f8d7..23c3a7a 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1710,12 +1710,18 @@ class Decoder(BaseGaussianDiffusion):
         )
 
         self.unconditional = unconditional
-        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
 
-        assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+        # text conditioning
+
+        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
+        self.condition_on_text_encodings = condition_on_text_encodings
+
+        # clip
 
         self.clip = None
         if exists(clip):
+            assert not unconditional, 'clip must not be given if doing unconditional image training'
+
             if isinstance(clip, CLIP):
                 clip = XClipAdapter(clip, **clip_adapter_overrides)
             elif isinstance(clip, CoCa):
@@ -1725,13 +1731,20 @@ class Decoder(BaseGaussianDiffusion):
             assert isinstance(clip, BaseClipAdapter)
 
             self.clip = clip
-            self.clip_image_size = clip.image_size
-            self.channels = clip.image_channels
-        else:
-            self.clip_image_size = default(image_size, lambda: image_sizes[-1])
-            self.channels = channels
 
-        self.condition_on_text_encodings = condition_on_text_encodings
+        # determine image size, with image_size and image_sizes taking precedence
+
+        if exists(image_size) or exists(image_sizes):
+            assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
+            image_size = default(image_size, lambda: image_sizes[-1])
+        elif exists(clip):
+            image_size = clip.image_size
+        else:
+            raise ValueError('either image_size, image_sizes, or clip must be given to decoder')
+
+        # channels
+
+        self.channels = channels
 
         # automatically take care of ensuring that first unet is unconditional
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
@@ -1773,7 +1786,7 @@ class Decoder(BaseGaussianDiffusion):
 
         # unet image sizes
 
-        image_sizes = default(image_sizes, (self.clip_image_size,))
+        image_sizes = default(image_sizes, (image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))
 
         assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1811,6 +1824,7 @@ class Decoder(BaseGaussianDiffusion):
         self.clip_x_start = clip_x_start
 
         # normalize and unnormalize image functions
+
         self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
         self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
 
diff --git a/setup.py b/setup.py
index 35ded9c..b83e048 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.4.10',
+  version = '0.4.11',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',