allow for returning low resolution conditioning image on forward through decoder with return_lowres_cond_image flag

This commit is contained in:
Phil Wang
2022-07-01 09:35:39 -07:00
parent a922a539de
commit 7b0edf9e42
2 changed files with 9 additions and 3 deletions

View File

@@ -2225,7 +2225,8 @@ class Decoder(nn.Module):
image_embed = None, image_embed = None,
text_encodings = None, text_encodings = None,
text_mask = None, text_mask = None,
unet_number = None unet_number = None,
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
): ):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1) unet_number = default(unet_number, 1)
@@ -2275,7 +2276,12 @@ class Decoder(nn.Module):
image = vae.encode(image) image = vae.encode(image)
lowres_cond_img = maybe(vae.encode)(lowres_cond_img) lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler) losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
if not return_lowres_cond_image:
return losses
return losses, lowres_cond_img
# main class # main class

View File

@@ -1 +1 @@
__version__ = '0.15.0' __version__ = '0.15.1'