Mirror of https://github.com/lucidrains/DALLE2-pytorch.git
allow text encodings and text mask to be passed in on forward and sampling for Decoder class
@@ -61,6 +61,9 @@ def default(val, d):
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
 
+def module_device(module):
+    return next(module.parameters()).device
+
 @contextmanager
 def null_context(*args, **kwargs):
     yield
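
The added helper simply reads the device of a module's first parameter, so call sites below can stop repeating the inline next(module.parameters()).device expression. A minimal, self-contained illustration (the nn.Linear layer is just an example, not part of the commit):

    import torch
    from torch import nn

    def module_device(module):
        # device of the first parameter found in the module
        return next(module.parameters()).device

    layer = nn.Linear(4, 4)
    print(module_device(layer))      # cpu

    if torch.cuda.is_available():
        layer.cuda()
        print(module_device(layer))  # cuda:0
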
@@ -1817,7 +1820,7 @@ class Decoder(BaseGaussianDiffusion):
 
         self.cuda()
 
-        devices = [next(unet.parameters()).device for unet in self.unets]
+        devices = [module_device(unet) for unet in self.unets]
         self.unets.cpu()
         unet.cuda()
 
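
This hunk sits inside the context manager that moves a single unet onto the GPU while the rest of the cascade stays on the CPU; the only change is that the device bookkeeping now goes through module_device. A self-contained sketch of the same pattern, written as a generic hypothetical helper rather than the actual Decoder method:

    from contextlib import contextmanager
    from torch import nn

    def module_device(module):
        return next(module.parameters()).device

    @contextmanager
    def one_module_on_gpu(modules: nn.ModuleList, index: int):
        # remember where every submodule currently lives
        devices = [module_device(m) for m in modules]

        # park everything on the CPU, then move only the module we need onto the GPU
        modules.cpu()
        modules[index].cuda()

        yield modules[index]

        # restore the original placement once the block exits
        for module, device in zip(modules, devices):
            module.to(device)
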
@@ -1965,6 +1968,8 @@ class Decoder(BaseGaussianDiffusion):
         self,
         image_embed = None,
         text = None,
+        text_mask = None,
+        text_encodings = None,
         batch_size = 1,
         cond_scale = 1.,
         stop_at_unet_number = None
@@ -1974,8 +1979,8 @@ class Decoder(BaseGaussianDiffusion):
         if not self.unconditional:
             batch_size = image_embed.shape[0]
 
-        text_encodings = text_mask = None
-        if exists(text):
+        if exists(text) and not exists(text_encodings) and not self.unconditional:
             assert exists(self.clip)
             _, text_encodings, text_mask = self.clip.embed_text(text)
 
+        assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
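
The net effect on sampling: text no longer has to be re-embedded through CLIP on every call, and precomputed encodings are used instead of being reset to None. A hedged usage sketch, assuming decoder, clip, image_embed and text_tokens already exist (e.g. built as in the repository README) and that the decoder was constructed with condition_on_text_encodings = True; only the keyword names come from the diff above:

    # embed the text once with the same CLIP the decoder was given
    _, text_encodings, text_mask = clip.embed_text(text_tokens)

    # the cached tensors can now be handed straight to sample()
    images = decoder.sample(
        image_embed = image_embed,
        text_encodings = text_encodings,
        text_mask = text_mask,
        cond_scale = 2.
    )
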
@@ -2027,6 +2032,7 @@ class Decoder(BaseGaussianDiffusion):
         text = None,
         image_embed = None,
         text_encodings = None,
+        text_mask = None,
         unet_number = None
     ):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
@@ -2051,7 +2057,6 @@ class Decoder(BaseGaussianDiffusion):
             assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
             image_embed, _ = self.clip.embed_image(image)
 
-        text_encodings = text_mask = None
         if exists(text) and not exists(text_encodings) and not self.unconditional:
             assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
            _, text_encodings, text_mask = self.clip.embed_text(text)
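
Training gets the same treatment: text_encodings (and optionally text_mask) passed to forward are now used directly rather than re-derived through the CLIP text tower or overwritten by the removed reset line. A hedged sketch of a training step under the same assumptions as above (decoder, clip, images and text_tokens are presumed to exist; unet_number = 1 targets the first unet of the cascade):

    # precompute the text conditioning outside the training step
    _, text_encodings, text_mask = clip.embed_text(text_tokens)

    loss = decoder(
        images,
        text_encodings = text_encodings,
        text_mask = text_mask,
        unet_number = 1
    )
    loss.backward()
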
@@ -2107,7 +2112,7 @@ class DALLE2(nn.Module):
         prior_cond_scale = 1.,
         return_pil_images = False
     ):
-        device = next(self.parameters()).device
+        device = module_device(self)
         one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
 
         if isinstance(text, str) or is_list_str(text):
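
The high-level DALLE2 wrapper is untouched apart from the device lookup; a text-to-image call still looks like the README usage (prior and decoder are assumed to be already-trained DiffusionPrior and Decoder instances):

    from dalle2_pytorch import DALLE2

    dalle2 = DALLE2(prior = prior, decoder = decoder)

    # a plain string or a list of strings is tokenized internally, as in the branch above
    images = dalle2(['cute puppy chasing after a squirrel'], cond_scale = 2.)
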