From 7b0edf9e42ffe531cf247d59ee1ac69699470e31 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 1 Jul 2022 09:35:39 -0700 Subject: [PATCH] allow for returning low resolution conditioning image on forward through decoder with return_lowres_cond_image flag --- dalle2_pytorch/dalle2_pytorch.py | 10 ++++++++-- dalle2_pytorch/version.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 0eb9e36..3a23703 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -2225,7 +2225,8 @@ class Decoder(nn.Module): image_embed = None, text_encodings = None, text_mask = None, - unet_number = None + unet_number = None, + return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes ): assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' unet_number = default(unet_number, 1) @@ -2275,7 +2276,12 @@ class Decoder(nn.Module): image = vae.encode(image) lowres_cond_img = maybe(vae.encode)(lowres_cond_img) - return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler) + losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler) + + if not return_lowres_cond_image: + return losses + + return losses, lowres_cond_img # main class diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index a842d05..6fccdee 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.15.0' +__version__ = '0.15.1'