mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)

make sure cascading DDPM can be trained unconditionally, to ready it for one-command CLI training for the public
README.md
@@ -1001,8 +1001,8 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
 - [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
 - [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
+- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
-- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
@@ -1015,6 +1015,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
 - [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
 - [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
+- [ ] decoder needs one day worth of refactor for tech debt
 
 ## Citations
 
dalle2_pytorch/dalle2_pytorch.py
@@ -1382,12 +1382,13 @@ class Unet(nn.Module):
         *,
         lowres_cond,
         channels,
-        cond_on_image_embeds
+        cond_on_image_embeds,
+        cond_on_text_encodings
     ):
-        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
+        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
             return self
 
-        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
+        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
         return self.__class__(**{**self._locals, **updated_kwargs})
 
     def forward_with_cond_scale(
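For orientation, a hedged sketch of the extended cast in use: cast_model_parameters rebuilds the unet from its stored _locals with the conditioning-related settings overridden, and after this change it short-circuits only when all four settings already match. The Unet constructor arguments below are illustrative assumptions modeled on the repository's README, not taken from this commit.

# hypothetical unet; only cast_model_parameters and its four keyword
# arguments come from the diff above
unet = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4),
    cond_on_text_encodings = True
)

# returns a re-instantiated copy with text conditioning switched off;
# before this commit, cond_on_text_encodings was not part of the cast
plain_unet = unet.cast_model_parameters(
    lowres_cond = False,
    channels = 3,
    cond_on_image_embeds = False,
    cond_on_text_encodings = False
)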
@@ -1583,7 +1584,8 @@ class Decoder(BaseGaussianDiffusion):
         condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
         clip_denoised = True,
         clip_x_start = True,
-        clip_adapter_overrides = dict()
+        clip_adapter_overrides = dict(),
+        unconditional = False
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -1591,6 +1593,9 @@ class Decoder(BaseGaussianDiffusion):
             loss_type = loss_type
         )
 
+        self.unconditional = unconditional
+        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
+
         assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
 
         self.clip = None
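Note the invariant the new assert enforces: a decoder cannot be both unconditional and conditioned on text encodings, so a misconfiguration fails loudly at construction time rather than at sampling time.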
@@ -1632,7 +1637,8 @@ class Decoder(BaseGaussianDiffusion):
 
             one_unet = one_unet.cast_model_parameters(
                 lowres_cond = not is_first,
-                cond_on_image_embeds = is_first,
+                cond_on_image_embeds = is_first and not unconditional,
+                cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
                 channels = unet_channels
             )
 
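With this cast, an unconditional decoder strips both image-embedding and text conditioning from every unet in the cascade, so no conditioning pathway survives into the unconditional model.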
@@ -1767,12 +1773,16 @@ class Decoder(BaseGaussianDiffusion):
     @eval_decorator
     def sample(
         self,
-        image_embed,
+        image_embed = None,
         text = None,
+        batch_size = 1,
         cond_scale = 1.,
         stop_at_unet_number = None
     ):
-        batch_size = image_embed.shape[0]
+        assert self.unconditional or exists(image_embed), 'image embed must be present when sampling from the decoder unless it was trained unconditionally'
 
+        if not self.unconditional:
+            batch_size = image_embed.shape[0]
+
         text_encodings = text_mask = None
         if exists(text):
@@ -1782,10 +1792,11 @@ class Decoder(BaseGaussianDiffusion):
             assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
 
         img = None
+        is_cuda = next(self.parameters()).is_cuda
 
         for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
 
-            context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
+            context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
 
             with context:
                 lowres_cond_img = None
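Taken together, the diff suggests an end-to-end unconditional flow along these lines. This is a sketch under assumptions: the constructor arguments and the image-only forward pass are inferred from the asserts above and the README of the period, not quoted from the commit.

import torch
from dalle2_pytorch import Unet, Decoder

unet = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4)
)

decoder = Decoder(
    unet = unet,
    image_size = 64,       # given directly, since no CLIP is supplied (see the XOR assert above)
    timesteps = 100,
    unconditional = True   # the new flag introduced by this commit
)

# train on images alone, assuming the forward pass honors `unconditional`
# the same way sample does
images = torch.randn(4, 3, 64, 64)
loss = decoder(images)
loss.backward()

# sample from pure noise; the new batch_size argument replaces
# the old batch_size = image_embed.shape[0]
sampled = decoder.sample(batch_size = 4)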