From aa900213e7f59b5e987810453e8f0957a9cdef9e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 28 Apr 2022 20:53:15 -0700 Subject: [PATCH] force first unet in the cascade to be conditioned on image embeds --- dalle2_pytorch/dalle2_pytorch.py | 10 ++++++---- setup.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index a62171c..58920fa 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1066,13 +1066,14 @@ class Unet(nn.Module): self, *, lowres_cond, - channels + channels, + cond_on_image_embeds ): - if lowres_cond == self.lowres_cond and channels == self.channels: + if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds: return self - updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels} - return self.__class__(**updated_kwargs) + updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds} + return self.__class__(**{**self._locals, **updated_kwargs}) def forward_with_cond_scale( self, @@ -1279,6 +1280,7 @@ class Decoder(BaseGaussianDiffusion): one_unet = one_unet.cast_model_parameters( lowres_cond = not is_first, + cond_on_image_embeds = is_first, channels = unet_channels ) diff --git a/setup.py b/setup.py index c171129..a843941 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.63', + version = '0.0.64', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',