mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
fix and cleanup image size determination logic in decoder
This commit is contained in:
@@ -1710,12 +1710,18 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.unconditional = unconditional
|
self.unconditional = unconditional
|
||||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
|
||||||
|
|
||||||
assert self.unconditional or (exists(clip) or exists(image_size) or exists(image_sizes)), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
# text conditioning
|
||||||
|
|
||||||
|
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||||
|
self.condition_on_text_encodings = condition_on_text_encodings
|
||||||
|
|
||||||
|
# clip
|
||||||
|
|
||||||
self.clip = None
|
self.clip = None
|
||||||
if exists(clip):
|
if exists(clip):
|
||||||
|
assert not unconditional, 'clip must not be given if doing unconditional image training'
|
||||||
|
|
||||||
if isinstance(clip, CLIP):
|
if isinstance(clip, CLIP):
|
||||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||||
elif isinstance(clip, CoCa):
|
elif isinstance(clip, CoCa):
|
||||||
@@ -1725,13 +1731,20 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
assert isinstance(clip, BaseClipAdapter)
|
assert isinstance(clip, BaseClipAdapter)
|
||||||
|
|
||||||
self.clip = clip
|
self.clip = clip
|
||||||
self.clip_image_size = clip.image_size
|
|
||||||
self.channels = clip.image_channels
|
|
||||||
else:
|
|
||||||
self.clip_image_size = default(image_size, lambda: image_sizes[-1])
|
|
||||||
self.channels = channels
|
|
||||||
|
|
||||||
self.condition_on_text_encodings = condition_on_text_encodings
|
# determine image size, with image_size and image_sizes taking precedence
|
||||||
|
|
||||||
|
if exists(image_size) or exists(image_sizes):
|
||||||
|
assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
|
||||||
|
image_size = default(image_size, lambda: image_sizes[-1])
|
||||||
|
elif exists(clip):
|
||||||
|
image_size = clip.image_size
|
||||||
|
else:
|
||||||
|
raise Error('either image_size, image_sizes, or clip must be given to decoder')
|
||||||
|
|
||||||
|
# channels
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
# automatically take care of ensuring that first unet is unconditional
|
# automatically take care of ensuring that first unet is unconditional
|
||||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||||
@@ -1773,7 +1786,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
# unet image sizes
|
# unet image sizes
|
||||||
|
|
||||||
image_sizes = default(image_sizes, (self.clip_image_size,))
|
image_sizes = default(image_sizes, (image_size,))
|
||||||
image_sizes = tuple(sorted(set(image_sizes)))
|
image_sizes = tuple(sorted(set(image_sizes)))
|
||||||
|
|
||||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||||
@@ -1811,6 +1824,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip_x_start = clip_x_start
|
self.clip_x_start = clip_x_start
|
||||||
|
|
||||||
# normalize and unnormalize image functions
|
# normalize and unnormalize image functions
|
||||||
|
|
||||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||||
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user