mirror of https://github.com/lucidrains/DALLE2-pytorch.git
allow CLIP to be optional in Decoder, and allow DecoderTrainer to train off pre-encoded image embeddings
@@ -1457,7 +1457,9 @@ class Decoder(BaseGaussianDiffusion):
         self,
         unet,
         *,
-        clip,
+        clip = None,
+        image_size = None,
+        channels = 3,
         vae = tuple(),
         timesteps = 1000,
         image_cond_drop_prob = 0.1,
@@ -1481,15 +1483,22 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )

-        if isinstance(clip, CLIP):
-            clip = XClipAdapter(clip)
+        assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'

-        freeze_model_and_make_eval_(clip)
-        assert isinstance(clip, BaseClipAdapter)
+        self.clip = None
+        if exists(clip):
+            if isinstance(clip, CLIP):
+                clip = XClipAdapter(clip)

-        self.clip = clip
-        self.clip_image_size = clip.image_size
-        self.channels = clip.image_channels
+            freeze_model_and_make_eval_(clip)
+            assert isinstance(clip, BaseClipAdapter)
+
+            self.clip = clip
+            self.clip_image_size = clip.image_size
+            self.channels = clip.image_channels
+        else:
+            self.clip_image_size = image_size
+            self.channels = channels

         self.condition_on_text_encodings = condition_on_text_encodings

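In effect, the constructor now supports two mutually exclusive paths, enforced by the XOR assert above: pass a CLIP model, or pass the image geometry directly. A minimal sketch of both follows; the hyperparameters are illustrative placeholders in the style of the repository README, not values taken from this commit.

from dalle2_pytorch import Unet, Decoder, CLIP

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

# path 1: supply CLIP - image size and channels are read off the frozen adapter

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

decoder = Decoder(unet, clip = clip, timesteps = 100)

# path 2: omit CLIP - the new assert then requires image_size
# (channels defaults to 3 for RGB)

decoder = Decoder(unet, image_size = 256, channels = 3, timesteps = 100)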
@@ -1522,7 +1531,7 @@ class Decoder(BaseGaussianDiffusion):

         # unet image sizes

-        image_sizes = default(image_sizes, (clip.image_size,))
+        image_sizes = default(image_sizes, (self.clip_image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))

         assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1727,10 +1736,12 @@ class Decoder(BaseGaussianDiffusion):
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)

         if not exists(image_embed):
+            assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
             image_embed, _ = self.clip.embed_image(image)

         text_encodings = text_mask = None
         if exists(text) and not exists(text_encodings):
+            assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
             _, text_encodings, text_mask = self.clip.embed_text(text)

         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
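With these guards in place, a decoder built without CLIP can still be trained, provided the image embeddings are precomputed and passed in so neither assert is reached. A sketch under that assumption, reusing the CLIP-free decoder from the earlier example; the image_embed argument is taken from the hunk above, and the tensor shapes are illustrative:

import torch

images = torch.randn(4, 3, 256, 256)    # training images
image_embed = torch.randn(4, 512)        # embeddings precomputed offline, e.g. by a separate CLIP

# image_embed is supplied, so the `exists(self.clip)` assert is never hit
loss = decoder(images, image_embed = image_embed)
loss.backward()

This is presumably the same code path the commit message refers to for DecoderTrainer: the trainer change is not shown in these hunks, but training off pre-encoded image embeddings reduces to forwarding image_embed through to this loss computation.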