diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c2bfa38..1253444 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1852,7 +1852,7 @@ class Decoder(nn.Module): one_unet = one_unet.cast_model_parameters( lowres_cond = not is_first, cond_on_image_embeds = not unconditional and is_first, - cond_on_text_encodings = not unconditional and (is_first or one_unet.cond_on_text_encodings), + cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings, channels = unet_channels, channels_out = unet_channels_out ) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 2c7bffb..f8d9095 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.12.0' +__version__ = '0.12.1'