diff --git a/README.md b/README.md index fa9a77a..f205024 100644 --- a/README.md +++ b/README.md @@ -581,7 +581,8 @@ unet1 = Unet( image_embed_dim = 512, cond_dim = 128, channels = 3, - dim_mults=(1, 2, 4, 8) + dim_mults=(1, 2, 4, 8), + cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade) ).cuda() unet2 = Unet( @@ -598,12 +599,11 @@ decoder = Decoder( clip = clip, timesteps = 100, image_cond_drop_prob = 0.1, - text_cond_drop_prob = 0.5, - condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling + text_cond_drop_prob = 0.5 ).cuda() for unet_number in (1, 2): - loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much + loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much loss.backward() # do above for many steps diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 878ab8a..8598fa4 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1930,10 +1930,6 @@ class Decoder(nn.Module): self.unets.append(one_unet) self.vaes.append(one_vae.copy_for_eval()) - # determine from unets whether conditioning on text encoding is needed - - self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets]) - # create noise schedulers per unet if not exists(beta_schedule): @@ -2012,6 +2008,10 @@ class Decoder(nn.Module): def device(self): return self._dummy.device + @property + def condition_on_text_encodings(self): + return any([unet.cond_on_text_encodings for unet in self.unets]) + def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 8e73b46..17136cb 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.14' +__version__ = '0.16.15'