force first unet in the cascade to be conditioned on image embeds

This commit is contained in:
Phil Wang
2022-04-28 20:53:15 -07:00
parent cb26187450
commit aa900213e7
2 changed files with 7 additions and 5 deletions

View File

@@ -1066,13 +1066,14 @@ class Unet(nn.Module):
self, self,
*, *,
lowres_cond, lowres_cond,
channels channels,
cond_on_image_embeds
): ):
if lowres_cond == self.lowres_cond and channels == self.channels: if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
return self return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels} updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**updated_kwargs) return self.__class__(**{**self._locals, **updated_kwargs})
def forward_with_cond_scale( def forward_with_cond_scale(
self, self,
@@ -1279,6 +1280,7 @@ class Decoder(BaseGaussianDiffusion):
one_unet = one_unet.cast_model_parameters( one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first, lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels channels = unet_channels
) )

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.63', version = '0.0.64',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',