diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 1253444..66c8ccc 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1781,13 +1781,6 @@ class Decoder(nn.Module):
     ):
         super().__init__()
 
-        self.unconditional = unconditional
-
-        # text conditioning
-
-        assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
-        self.condition_on_text_encodings = condition_on_text_encodings
-
         # clip
 
         self.clip = None
@@ -1819,12 +1812,18 @@ class Decoder(nn.Module):
 
         self.channels = channels
 
-        # automatically take care of ensuring that first unet is unconditional
-        # while the rest of the unets are conditioned on the low resolution image produced by previous unet
+        # verify conditioning method
 
         unets = cast_tuple(unet)
         num_unets = len(unets)
 
+        self.unconditional = unconditional
+        self.condition_on_text_encodings = unets[0].cond_on_text_encodings
+        assert not (self.condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
+
+        # automatically take care of ensuring that first unet is unconditional
+        # while the rest of the unets are conditioned on the low resolution image produced by previous unet
+
         vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
 
         # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 08f93dc..d6c9dbf 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -158,6 +158,8 @@ class UnetConfig(BaseModel):
     dim: int
     dim_mults: ListOrTuple(int)
     image_embed_dim: int = None
+    text_embed_dim: int = None
+    cond_on_text_encodings: bool = None
     cond_dim: int = None
     channels: int = 3
     attn_dim_head: int = 32
@@ -170,7 +172,6 @@ class DecoderConfig(BaseModel):
     unets: ListOrTuple(UnetConfig)
     image_size: int = None
     image_sizes: ListOrTuple(int) = None
-    condition_on_text_encodings: bool = False
     clip: Optional[AdapterConfig]  # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
@@ -286,16 +287,16 @@ class TrainDecoderConfig(BaseModel):
         if data_config is None or decoder_config is None:
             # Then something else errored and we should just pass through
             return values
-        using_text_embeddings = decoder_config.condition_on_text_encodings
+        using_text_embeddings = decoder_config.unets[0].cond_on_text_encodings  # in dalle2 only the first UNet is text conditioned
         using_clip = exists(decoder_config.clip)
         img_emb_url = data_config.img_embeddings_url
         text_emb_url = data_config.text_embeddings_url
         if using_text_embeddings:
             # Then we need some way to get the embeddings
-            assert using_clip or text_emb_url is not None, 'If condition_on_text_encodings is true, either clip or text_embeddings_url must be provided'
+            assert using_clip or text_emb_url is not None, 'If conditioning on text, either clip or text_embeddings_url must be provided'
         if using_clip:
             if using_text_embeddings:
-                assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
+                assert text_emb_url is None or img_emb_url is None, 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
             else:
                 assert img_emb_url is None, 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
         if text_emb_url:
diff --git a/train_decoder.py b/train_decoder.py
index 3fe1289..d608d1e 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -596,9 +596,10 @@ def initialize_training(config, config_path):
 
     has_img_embeddings = config.data.img_embeddings_url is not None
    has_text_embeddings = config.data.text_embeddings_url is not None
-    conditioning_on_text = config.decoder.condition_on_text_encodings
+    conditioning_on_text = config.decoder.unets[0].cond_on_text_encodings
     has_clip_model = config.decoder.clip is not None
     data_source_string = ""
+
     if has_img_embeddings:
         data_source_string += "precomputed image embeddings"
     elif has_clip_model:
@@ -622,7 +623,7 @@ def initialize_training(config, config_path):
         inference_device=accelerator.device,
         load_config=config.load,
         evaluate_config=config.evaluate,
-        condition_on_text_encodings=config.decoder.condition_on_text_encodings,
+        condition_on_text_encodings=conditioning_on_text,
         **config.train.dict(),
     )
 
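For anyone migrating existing training configs, here is a minimal before/after sketch of the decoder section. Only the field movements come from this patch (condition_on_text_encodings dropped from DecoderConfig; text_embed_dim and cond_on_text_encodings added to UnetConfig, with unets[0] now the value the Decoder, the TrainDecoderConfig validator, and train_decoder.py consult); the dims and embedding sizes are hypothetical placeholders.

# Illustrative sketch only; all numeric values below are placeholders.

# Before this patch: text conditioning was declared once, on the decoder.
old_decoder_config = {
    'unets': [
        {'dim': 128, 'dim_mults': [1, 2, 4, 8], 'image_embed_dim': 512},
    ],
    'image_sizes': [64],
    'condition_on_text_encodings': True,  # field removed by this patch
}

# After this patch: each unet declares its own conditioning, and only
# unets[0].cond_on_text_encodings is consulted for text conditioning.
new_decoder_config = {
    'unets': [
        {
            'dim': 128,
            'dim_mults': [1, 2, 4, 8],
            'image_embed_dim': 512,
            'text_embed_dim': 512,           # new UnetConfig field
            'cond_on_text_encodings': True,  # new UnetConfig field
        },
    ],
    'image_sizes': [64],
}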