From 88f516b5dbb8d9be7a1c869b4578f7e19f31a13f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 7 Jul 2022 07:42:13 -0700 Subject: [PATCH] fix condition_on_text_encodings in dalle2 orchestrator class, fix readme --- README.md | 6 +++--- dalle2_pytorch/dalle2_pytorch.py | 8 ++++---- dalle2_pytorch/version.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index fa9a77a..fe58391 100644 --- a/README.md +++ b/README.md @@ -581,7 +581,8 @@ unet1 = Unet( image_embed_dim = 512, cond_dim = 128, channels = 3, - dim_mults=(1, 2, 4, 8) + dim_mults=(1, 2, 4, 8), + cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade) ).cuda() unet2 = Unet( @@ -598,8 +599,7 @@ decoder = Decoder( clip = clip, timesteps = 100, image_cond_drop_prob = 0.1, - text_cond_drop_prob = 0.5, - condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling + text_cond_drop_prob = 0.5 ).cuda() for unet_number in (1, 2): diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 878ab8a..8598fa4 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1930,10 +1930,6 @@ class Decoder(nn.Module): self.unets.append(one_unet) self.vaes.append(one_vae.copy_for_eval()) - # determine from unets whether conditioning on text encoding is needed - - self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets]) - # create noise schedulers per unet if not exists(beta_schedule): @@ -2012,6 +2008,10 @@ class Decoder(nn.Module): def device(self): return self._dummy.device + @property + def condition_on_text_encodings(self): + return any([unet.cond_on_text_encodings for unet in self.unets]) + def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 8e73b46..17136cb 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.14' +__version__ = '0.16.15'