From b90364695da3745f1ef9b566c6b5dfef8a88a74f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 26 Jun 2022 21:07:42 -0700 Subject: [PATCH] fix remaining issues with deriving cond_on_text_encodings from child unet settings --- README.md | 6 +++--- dalle2_pytorch/dalle2_pytorch.py | 8 +++++--- dalle2_pytorch/train_configs.py | 16 +++++++++++----- dalle2_pytorch/version.py | 2 +- train_decoder.py | 3 ++- 5 files changed, 22 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0bbbfd1..a9c5e9e 100644 --- a/README.md +++ b/README.md @@ -368,7 +368,8 @@ unet1 = Unet( image_embed_dim = 512, cond_dim = 128, channels = 3, - dim_mults=(1, 2, 4, 8) + dim_mults=(1, 2, 4, 8), + cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings ).cuda() unet2 = Unet( @@ -385,8 +386,7 @@ decoder = Decoder( clip = clip, timesteps = 100, image_cond_drop_prob = 0.1, - text_cond_drop_prob = 0.5, - condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling + text_cond_drop_prob = 0.5 ).cuda() for unet_number in (1, 2): diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 66c8ccc..3ed29c0 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1817,9 +1817,7 @@ class Decoder(nn.Module): unets = cast_tuple(unet) num_unets = len(unets) - self.unconditional = unconditional - self.condition_on_text_encodings = unets[0].cond_on_text_encodings - assert not (self.condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present' + self.unconditional = unconditional # automatically take care of ensuring that first unet is unconditional # while the rest of the unets are conditioned on the low resolution image produced by previous unet @@ -1859,6 +1857,10 @@ class Decoder(nn.Module): self.unets.append(one_unet) self.vaes.append(one_vae.copy_for_eval()) + # determine from unets whether conditioning on text encoding is needed + + self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets]) + # create noise schedulers per unet if not exists(beta_schedule): diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py index d6c9dbf..1f9ccec 100644 --- a/dalle2_pytorch/train_configs.py +++ b/dalle2_pytorch/train_configs.py @@ -284,21 +284,27 @@ class TrainDecoderConfig(BaseModel): def check_has_embeddings(cls, values): # Makes sure that enough information is provided to get the embeddings specified for training data_config, decoder_config = values.get('data'), values.get('decoder') - if data_config is None or decoder_config is None: + + if not exists(data_config) or not exists(decoder_config): # Then something else errored and we should just pass through return values - using_text_encodings = decoder_config.unets[0].cond_on_text_encodings # in dalle2 only the first UNet is text conditioned + + using_text_encodings = any([unet.cond_on_text_encodings for unet in decoder_config.unets]) using_clip = exists(decoder_config.clip) img_emb_url = data_config.img_embeddings_url text_emb_url = data_config.text_embeddings_url + if using_text_embeddings: # Then we need some way to get the embeddings - assert using_clip or text_emb_url is not None, 'If text conditioning, either clip or text_embeddings_url must be provided' + assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided' + if using_clip: if using_text_embeddings: - assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings' + assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings' else: - assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings' + assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings' + if text_emb_url: assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason." + return values diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index f8d9095..92a60bd 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.12.1' +__version__ = '0.12.2' diff --git a/train_decoder.py b/train_decoder.py index d608d1e..946aa6d 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -596,7 +596,8 @@ def initialize_training(config, config_path): has_img_embeddings = config.data.img_embeddings_url is not None has_text_embeddings = config.data.text_embeddings_url is not None - conditioning_on_text = config.decoder.unets[0].cond_on_text_encodings + conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets]) + has_clip_model = config.decoder.clip is not None data_source_string = ""