Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
allow CLIP to be optional in Decoder, and allow DecoderTrainer to train off pre-encoded image embeddings
@@ -1457,7 +1457,9 @@ class Decoder(BaseGaussianDiffusion):
         self,
         unet,
         *,
-        clip,
+        clip = None,
+        image_size = None,
+        channels = 3,
         vae = tuple(),
         timesteps = 1000,
         image_cond_drop_prob = 0.1,
@@ -1481,15 +1483,22 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )
 
-        if isinstance(clip, CLIP):
-            clip = XClipAdapter(clip)
+        assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
-        freeze_model_and_make_eval_(clip)
-        assert isinstance(clip, BaseClipAdapter)
+        self.clip = None
+        if exists(clip):
+            if isinstance(clip, CLIP):
+                clip = XClipAdapter(clip)
 
-        self.clip = clip
-        self.clip_image_size = clip.image_size
-        self.channels = clip.image_channels
+            freeze_model_and_make_eval_(clip)
+            assert isinstance(clip, BaseClipAdapter)
+
+            self.clip = clip
+            self.clip_image_size = clip.image_size
+            self.channels = clip.image_channels
+        else:
+            self.clip_image_size = image_size
+            self.channels = channels
 
         self.condition_on_text_encodings = condition_on_text_encodings
 
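The net effect of this hunk is that `clip` and `image_size` become mutually exclusive constructor arguments. A minimal sketch of the resulting contract (illustrative only, not part of the commit; assumes `unet` is a configured Unet and `clip` a CLIP instance):

    decoder = Decoder(unet, clip = clip)         # ok: image size and channels inferred from CLIP
    decoder = Decoder(unet, image_size = 256)    # ok: CLIP-free, channels defaults to 3
    decoder = Decoder(unet)                      # AssertionError: one of clip / image_size must be given
    decoder = Decoder(unet, clip = clip, image_size = 256)   # AssertionError: not both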
@@ -1522,7 +1531,7 @@ class Decoder(BaseGaussianDiffusion):
 
         # unet image sizes
 
-        image_sizes = default(image_sizes, (clip.image_size,))
+        image_sizes = default(image_sizes, (self.clip_image_size,))
         image_sizes = tuple(sorted(set(image_sizes)))
 
         assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
@@ -1727,10 +1736,12 @@ class Decoder(BaseGaussianDiffusion):
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
         if not exists(image_embed):
+            assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
             image_embed, _ = self.clip.embed_image(image)
 
         text_encodings = text_mask = None
         if exists(text) and not exists(text_encodings):
+            assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
             _, text_encodings, text_mask = self.clip.embed_text(text)
 
         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
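Together with the constructor change above, these guards let the decoder train entirely off pre-encoded image embeddings. A hedged usage sketch (the tensors `images` and `image_embeds` are assumed to be prepared by the caller; they are not defined in this commit):

    # CLIP-free decoder; embeddings must be precomputed
    decoder = Decoder(unet, image_size = 256, channels = 3, timesteps = 1000)

    loss = decoder(images, image_embed = image_embeds)   # skips self.clip.embed_image
    loss.backward()

    # passing raw text, or omitting image_embed, now trips the new asserts,
    # since there is no self.clip to encode with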