diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 6e4f48c..130c023 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -61,6 +61,9 @@ def default(val, d): def cast_tuple(val, length = 1): return val if isinstance(val, tuple) else ((val,) * length) +def module_device(module): + return next(module.parameters()).device + @contextmanager def null_context(*args, **kwargs): yield @@ -1817,7 +1820,7 @@ class Decoder(BaseGaussianDiffusion): self.cuda() - devices = [next(unet.parameters()).device for unet in self.unets] + devices = [module_device(unet) for unet in self.unets] self.unets.cpu() unet.cuda() @@ -1965,6 +1968,8 @@ class Decoder(BaseGaussianDiffusion): self, image_embed = None, text = None, + text_mask = None, + text_encodings = None, batch_size = 1, cond_scale = 1., stop_at_unet_number = None @@ -1974,8 +1979,8 @@ class Decoder(BaseGaussianDiffusion): if not self.unconditional: batch_size = image_embed.shape[0] - text_encodings = text_mask = None - if exists(text): + if exists(text) and not exists(text_encodings) and not self.unconditional: + assert exists(self.clip) _, text_encodings, text_mask = self.clip.embed_text(text) assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' @@ -2027,6 +2032,7 @@ class Decoder(BaseGaussianDiffusion): text = None, image_embed = None, text_encodings = None, + text_mask = None, unet_number = None ): assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' @@ -2051,7 +2057,6 @@ class Decoder(BaseGaussianDiffusion): assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init' image_embed, _ = self.clip.embed_image(image) - text_encodings = text_mask = None if exists(text) and not exists(text_encodings) and not self.unconditional: assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder' _, text_encodings, text_mask = self.clip.embed_text(text) @@ -2107,7 +2112,7 @@ class DALLE2(nn.Module): prior_cond_scale = 1., return_pil_images = False ): - device = next(self.parameters()).device + device = module_device(self) one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1) if isinstance(text, str) or is_list_str(text): diff --git a/setup.py b/setup.py index 805e501..6b7b924 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.40', + version = '0.2.42', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',