make sure cascading DDPM can be trained unconditionally, in preparation for one-command CLI training for the public

Phil Wang
2022-05-10 10:48:10 -07:00
parent a1bfb03ba4
commit fc8fce38fb
3 changed files with 22 additions and 10 deletions

README.md

@@ -1001,8 +1001,8 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
 - [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
 - [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
+- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
-- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
@@ -1015,6 +1015,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
 - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
+- [ ] decoder needs one day's worth of refactoring for tech debt
 
 ## Citations

dalle2_pytorch/dalle2_pytorch.py

@@ -1382,12 +1382,13 @@ class Unet(nn.Module):
         *,
         lowres_cond,
         channels,
-        cond_on_image_embeds
+        cond_on_image_embeds,
+        cond_on_text_encodings
     ):
-        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
+        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
             return self
 
-        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
+        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
         return self.__class__(**{**self._locals, **updated_kwargs})
 
     def forward_with_cond_scale(
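An aside on the hunk above: `cast_model_parameters` lets the `Decoder` (further down) hand back a reconfigured copy of a `Unet`, and this commit adds `cond_on_text_encodings` to the set of castable parameters so text conditioning can be stripped for unconditional training. The pattern relies on the instance stashing its init kwargs in `self._locals`. A minimal standalone sketch of the idea, with a hypothetical `Widget` class standing in for the repository's `Unet`:

# editor's sketch of the recast pattern used above; 'Widget' is hypothetical
class Widget:
    def __init__(self, *, lowres_cond = False, cond_on_text_encodings = False):
        # stash init kwargs so a reconfigured copy can be built later
        self._locals = dict(lowres_cond = lowres_cond, cond_on_text_encodings = cond_on_text_encodings)
        self.lowres_cond = lowres_cond
        self.cond_on_text_encodings = cond_on_text_encodings

    def cast_model_parameters(self, **overrides):
        # reuse the existing instance if nothing would change
        if all(getattr(self, key) == value for key, value in overrides.items()):
            return self
        # otherwise construct a fresh instance with the overridden kwargs
        return self.__class__(**{**self._locals, **overrides})

w = Widget()
assert w.cast_model_parameters(lowres_cond = False) is w        # no change, same instance
w2 = w.cast_model_parameters(cond_on_text_encodings = True)     # changed, fresh instance
assert w2.cond_on_text_encodings and w2 is not w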
@@ -1583,7 +1584,8 @@ class Decoder(BaseGaussianDiffusion):
         condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
         clip_denoised = True,
         clip_x_start = True,
-        clip_adapter_overrides = dict()
+        clip_adapter_overrides = dict(),
+        unconditional = False
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -1591,6 +1593,9 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )
 
+        self.unconditional = unconditional
+        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
+
         assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
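The new assert above makes `unconditional` and `condition_on_text_encodings` mutually exclusive, so a misconfigured decoder fails fast at construction rather than mid-training. A hedged sketch of that failure mode, assuming the constructor arguments shown in this diff (the small `Unet` arguments are illustrative only):

from dalle2_pytorch import Unet, Decoder

unet = Unet(dim = 32, image_embed_dim = 512, dim_mults = (1, 2))  # illustrative sizes

# raises AssertionError, per the guard added above: a decoder cannot both
# condition on text encodings and be unconditional
decoder = Decoder(
    unet = unet,
    image_size = 64,
    timesteps = 100,
    condition_on_text_encodings = True,
    unconditional = True
)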
@@ -1632,7 +1637,8 @@ class Decoder(BaseGaussianDiffusion):
             one_unet = one_unet.cast_model_parameters(
                 lowres_cond = not is_first,
-                cond_on_image_embeds = is_first,
+                cond_on_image_embeds = is_first and not unconditional,
+                cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
                 channels = unet_channels
             )
@@ -1767,12 +1773,16 @@ class Decoder(BaseGaussianDiffusion):
     @eval_decorator
     def sample(
         self,
-        image_embed,
+        image_embed = None,
         text = None,
+        batch_size = 1,
         cond_scale = 1.,
         stop_at_unet_number = None
     ):
-        batch_size = image_embed.shape[0]
+        assert self.unconditional or exists(image_embed), 'image embed must be present when sampling from the decoder, unless it was trained unconditionally'
+
+        if not self.unconditional:
+            batch_size = image_embed.shape[0]
 
         text_encodings = text_mask = None
 
         if exists(text):
@@ -1782,10 +1792,11 @@ class Decoder(BaseGaussianDiffusion):
         assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
 
         img = None
+        is_cuda = next(self.parameters()).is_cuda
 
         for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
-            context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
 
             with context:
                 lowres_cond_img = None
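Taken together, these hunks deliver what the commit message promises: the cascading DDPM can train and sample with no CLIP image embed at all (note the `is_cuda` line above, which stops `sample` from dereferencing a possibly absent `image_embed` to pick a device). A minimal sketch of the intended unconditional workflow, assuming the usual `Unet` / `Decoder` constructor arguments (`dim`, `image_size`, `timesteps`, ...) are otherwise unchanged at this commit:

import torch
from dalle2_pytorch import Unet, Decoder

unet = Unet(
    dim = 32,                 # small, illustrative sizes
    image_embed_dim = 512,    # still accepted by the constructor; unused when unconditional
    dim_mults = (1, 2, 4)
)

decoder = Decoder(
    unet = unet,
    image_size = 64,          # no CLIP supplied, so image_size must be given (per the XOR assert)
    channels = 3,
    timesteps = 100,
    unconditional = True      # the new flag from this commit
)

images = torch.randn(2, 3, 64, 64)

loss = decoder(images)        # no image_embed or text needed
loss.backward()

# after sufficient training, sample from pure noise with just a batch size
sampled = decoder.sample(batch_size = 2)   # -> (2, 3, 64, 64)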

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.5',
+  version = '0.2.6',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',