Allow text encodings and the text mask to be passed in to the Decoder class, both on forward (training) and when sampling

This commit is contained in:
Phil Wang
2022-05-16 10:40:32 -07:00
parent dab106d4e5
commit 1212f7058d
2 changed files with 11 additions and 6 deletions

View File

@@ -61,6 +61,9 @@ def default(val, d):
def cast_tuple(val, length = 1): def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length) return val if isinstance(val, tuple) else ((val,) * length)
def module_device(module):
    """Return the device that `module` currently lives on.

    The device is taken from the module's first parameter. If the module
    has no parameters, the first registered buffer is used instead. If it
    has neither, CPU is returned, since such a module holds no tensors
    that could be placed anywhere else.

    Args:
        module: object exposing `parameters()` and `buffers()` iterators
            (e.g. a `torch.nn.Module`).

    Returns:
        The `torch.device` of the first parameter or buffer found, or
        `torch.device('cpu')` when the module holds no tensors.
    """
    # next(..., None) avoids leaking a bare StopIteration to the caller
    # when the module is parameter-less (e.g. nn.Identity).
    tensor = next(module.parameters(), None)

    if tensor is None:
        tensor = next(module.buffers(), None)

    if tensor is None:
        # local import keeps this helper self-contained
        import torch
        return torch.device('cpu')

    return tensor.device
@contextmanager @contextmanager
def null_context(*args, **kwargs): def null_context(*args, **kwargs):
yield yield
@@ -1817,7 +1820,7 @@ class Decoder(BaseGaussianDiffusion):
self.cuda() self.cuda()
devices = [next(unet.parameters()).device for unet in self.unets] devices = [module_device(unet) for unet in self.unets]
self.unets.cpu() self.unets.cpu()
unet.cuda() unet.cuda()
@@ -1965,6 +1968,8 @@ class Decoder(BaseGaussianDiffusion):
self, self,
image_embed = None, image_embed = None,
text = None, text = None,
text_mask = None,
text_encodings = None,
batch_size = 1, batch_size = 1,
cond_scale = 1., cond_scale = 1.,
stop_at_unet_number = None stop_at_unet_number = None
@@ -1974,8 +1979,8 @@ class Decoder(BaseGaussianDiffusion):
if not self.unconditional: if not self.unconditional:
batch_size = image_embed.shape[0] batch_size = image_embed.shape[0]
text_encodings = text_mask = None if exists(text) and not exists(text_encodings) and not self.unconditional:
if exists(text): assert exists(self.clip)
_, text_encodings, text_mask = self.clip.embed_text(text) _, text_encodings, text_mask = self.clip.embed_text(text)
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
@@ -2027,6 +2032,7 @@ class Decoder(BaseGaussianDiffusion):
text = None, text = None,
image_embed = None, image_embed = None,
text_encodings = None, text_encodings = None,
text_mask = None,
unet_number = None unet_number = None
): ):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
@@ -2051,7 +2057,6 @@ class Decoder(BaseGaussianDiffusion):
assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init' assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
image_embed, _ = self.clip.embed_image(image) image_embed, _ = self.clip.embed_image(image)
text_encodings = text_mask = None
if exists(text) and not exists(text_encodings) and not self.unconditional: if exists(text) and not exists(text_encodings) and not self.unconditional:
assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder' assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
_, text_encodings, text_mask = self.clip.embed_text(text) _, text_encodings, text_mask = self.clip.embed_text(text)
@@ -2107,7 +2112,7 @@ class DALLE2(nn.Module):
prior_cond_scale = 1., prior_cond_scale = 1.,
return_pil_images = False return_pil_images = False
): ):
device = next(self.parameters()).device device = module_device(self)
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1) one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text): if isinstance(text, str) or is_list_str(text):

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.40', version = '0.2.42',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',