diff --git a/README.md b/README.md
index 9a16c9c..8235ea0 100644
--- a/README.md
+++ b/README.md
@@ -1001,8 +1001,8 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
 - [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
 - [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
+- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [ ] become an expert with unets, clean up unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
-- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
@@ -1015,6 +1015,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
 - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either the decoder unet or vqgan-vae training
+- [ ] decoder needs one day's worth of refactoring for tech debt
 
 ## Citations
 
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index ffad621..611f0c1 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1382,12 +1382,13 @@ class Unet(nn.Module):
         *,
         lowres_cond,
         channels,
-        cond_on_image_embeds
+        cond_on_image_embeds,
+        cond_on_text_encodings
     ):
-        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
+        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
             return self
 
-        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
+        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
         return self.__class__(**{**self._locals, **updated_kwargs})
 
     def forward_with_cond_scale(
@@ -1583,7 +1584,8 @@ class Decoder(BaseGaussianDiffusion):
         condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
         clip_denoised = True,
         clip_x_start = True,
-        clip_adapter_overrides = dict()
+        clip_adapter_overrides = dict(),
+        unconditional = False
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -1591,6 +1593,9 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )
 
+        self.unconditional = unconditional
+        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
+
         assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
@@ -1632,7 +1637,8 @@ class Decoder(BaseGaussianDiffusion):
 
             one_unet = one_unet.cast_model_parameters(
                 lowres_cond = not is_first,
-                cond_on_image_embeds = is_first,
+                cond_on_image_embeds = is_first and not unconditional,
+                cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
                 channels = unet_channels
             )
 
@@ -1767,12 +1773,16 @@ class Decoder(BaseGaussianDiffusion):
     @eval_decorator
     def sample(
         self,
-        image_embed,
+        image_embed = None,
         text = None,
+        batch_size = 1,
         cond_scale = 1.,
         stop_at_unet_number = None
     ):
-        batch_size = image_embed.shape[0]
+        assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless it was trained unconditionally'
+
+        if not self.unconditional:
+            batch_size = image_embed.shape[0]
 
         text_encodings = text_mask = None
         if exists(text):
@@ -1782,10 +1792,11 @@ class Decoder(BaseGaussianDiffusion):
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet text is presented'
 
         img = None
+        is_cuda = next(self.parameters()).is_cuda
 
         for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
 
-            context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
 
             with context:
                 lowres_cond_img = None
diff --git a/setup.py b/setup.py
index 4da8c61..058a51e 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.5',
+  version = '0.2.6',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
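
A note on the main user-facing change above: the new `unconditional` flag on `Decoder` lets the cascading DDPM train and sample without CLIP image embeddings, and `sample` gains a `batch_size` argument since the batch dimension can no longer be inferred from `image_embed.shape[0]`. Below is a minimal sketch of the intended usage; the `Unet` / `Decoder` hyperparameters are illustrative assumptions, not prescribed values.

```python
import torch
from dalle2_pytorch import Unet, Decoder

# hyperparameters below are illustrative assumptions
unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8)
)

# per the assert in __init__, image_size must be given when no CLIP is supplied
decoder = Decoder(
    unet = unet,
    image_size = 128,
    timesteps = 1000,
    unconditional = True  # the new flag from this diff
)

images = torch.randn(4, 3, 128, 128)

loss = decoder(images)  # no image_embed needed when training unconditionally
loss.backward()

# sampling likewise needs no image_embed; batch_size supplies
# the batch dimension previously read off image_embed.shape[0]
sampled_images = decoder.sample(batch_size = 4)
```

Relatedly, `sample` now reads its device from the decoder's own parameters (`next(self.parameters()).is_cuda`) rather than from `image_embed`, which may now be `None`.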
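
The `cast_model_parameters` change is worth a brief gloss: `Unet` stashes its constructor arguments in `self._locals` and rebuilds itself with selected overrides, and the diff adds `cond_on_text_encodings` to the parameters reconciled this way, so `Decoder.__init__` can strip text conditioning off user-supplied unets when `unconditional = True`. A self-contained toy sketch of the idiom (the class and fields here are stand-ins, not the repository's actual code):

```python
class ToyUnet:
    def __init__(self, *, lowres_cond = False, channels = 3, cond_on_image_embeds = False, cond_on_text_encodings = False):
        # capture the constructor kwargs so the instance can rebuild itself later
        self._locals = dict(
            lowres_cond = lowres_cond,
            channels = channels,
            cond_on_image_embeds = cond_on_image_embeds,
            cond_on_text_encodings = cond_on_text_encodings
        )
        for name, value in self._locals.items():
            setattr(self, name, value)

    def cast_model_parameters(self, **overrides):
        # no-op when every requested parameter already matches
        if all(getattr(self, name) == value for name, value in overrides.items()):
            return self
        # otherwise re-instantiate with the original kwargs, selectively overridden
        return self.__class__(**{**self._locals, **overrides})

unet = ToyUnet(cond_on_text_encodings = True)
assert unet.cast_model_parameters(cond_on_text_encodings = True) is unet  # unchanged, reuse
recast = unet.cast_model_parameters(cond_on_text_encodings = False)      # fresh instance
assert recast is not unet and not recast.cond_on_text_encodings
```

Note that the guard clause must compare every castable parameter; otherwise a requested change could be silently skipped whenever the remaining parameters already match.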