mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
allow for returning low resolution conditioning image on forward through decoder with return_lowres_cond_image flag
This commit is contained in:
@@ -2225,7 +2225,8 @@ class Decoder(nn.Module):
|
|||||||
image_embed = None,
|
image_embed = None,
|
||||||
text_encodings = None,
|
text_encodings = None,
|
||||||
text_mask = None,
|
text_mask = None,
|
||||||
unet_number = None
|
unet_number = None,
|
||||||
|
return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
|
||||||
):
|
):
|
||||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||||
unet_number = default(unet_number, 1)
|
unet_number = default(unet_number, 1)
|
||||||
@@ -2275,7 +2276,12 @@ class Decoder(nn.Module):
|
|||||||
image = vae.encode(image)
|
image = vae.encode(image)
|
||||||
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
||||||
|
|
||||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
|
||||||
|
|
||||||
|
if not return_lowres_cond_image:
|
||||||
|
return losses
|
||||||
|
|
||||||
|
return losses, lowres_cond_img
|
||||||
|
|
||||||
# main class
|
# main class
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.15.0'
|
__version__ = '0.15.1'
|
||||||
|
|||||||
Reference in New Issue
Block a user