allow for CLIP to be optional in Decoder, and allow DecoderTrainer to work off training pre-encoded image embeddings

Phil Wang
2022-05-05 08:11:01 -07:00
parent 79fabc4341
commit c76a964fd6
2 changed files with 21 additions and 10 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -1457,7 +1457,9 @@ class Decoder(BaseGaussianDiffusion):
         self,
         unet,
         *,
-        clip,
+        clip = None,
+        image_size = None,
+        channels = 3,
         vae = tuple(),
         timesteps = 1000,
         image_cond_drop_prob = 0.1,
@@ -1481,15 +1483,22 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )
 
-        if isinstance(clip, CLIP):
-            clip = XClipAdapter(clip)
+        assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
+
+        self.clip = None
+        if exists(clip):
+            if isinstance(clip, CLIP):
+                clip = XClipAdapter(clip)
 
-        freeze_model_and_make_eval_(clip)
-        assert isinstance(clip, BaseClipAdapter)
+            freeze_model_and_make_eval_(clip)
+            assert isinstance(clip, BaseClipAdapter)
 
-        self.clip = clip
-        self.clip_image_size = clip.image_size
-        self.channels = clip.image_channels
+            self.clip = clip
+            self.clip_image_size = clip.image_size
+            self.channels = clip.image_channels
+        else:
+            self.clip_image_size = image_size
+            self.channels = channels
 
         self.condition_on_text_encodings = condition_on_text_encodings
@@ -1522,7 +1531,7 @@ class Decoder(BaseGaussianDiffusion):
         # unet image sizes
 
-        image_sizes = default(image_sizes, (clip.image_size,))
+        image_sizes = default(image_sizes, (self.clip_image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))
 
         assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1727,10 +1736,12 @@ class Decoder(BaseGaussianDiffusion):
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
         if not exists(image_embed):
+            assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
             image_embed, _ = self.clip.embed_image(image)
 
         text_encodings = text_mask = None
         if exists(text) and not exists(text_encodings):
+            assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
             _, text_encodings, text_mask = self.clip.embed_text(text)
 
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
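In effect, a Decoder can now be built without CLIP, as long as image_size and channels are given explicitly (the new xor assert enforces exactly one of the two configurations), and it can be trained straight off pre-encoded image embeddings rather than deriving them from a CLIP forward pass. The sketch below is illustrative usage, not code from this commit; the Unet hyperparameters, the 512-dimensional embeddings, and the random tensors are assumptions modeled on the repository README.

import torch
from dalle2_pytorch import Unet, Decoder

# hypothetical unet hyperparameters, following the README examples
unet = Unet(
    dim = 128,
    image_embed_dim = 512,   # must match the dimension of the pre-encoded image embeddings
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

# with clip = None, image_size and channels are supplied directly
decoder = Decoder(
    unet = unet,
    clip = None,
    image_size = 256,
    channels = 3,
    timesteps = 1000
)

images = torch.randn(4, 3, 256, 256)
image_embed = torch.randn(4, 512)   # embeddings precomputed offline, e.g. with CLIP

# passing image_embed directly skips the CLIP forward pass inside the decoder
loss = decoder(images, image_embed = image_embed)
loss.backward()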

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.105',
+  version = '0.0.106',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',