diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 1a44a74..218d931 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -2328,6 +2328,9 @@ class Decoder(nn.Module): img = torch.randn(shape, device = device) + if not is_latent_diffusion: + lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) + for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'): alpha = alphas[time] alpha_next = alphas[time_next] diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 7e8f349..4946c77 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.23.3' +__version__ = '0.23.4'