mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-21 18:44:20 +01:00
quick fix for @marunine
This commit is contained in:
@@ -1697,7 +1697,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
clip_adapter_overrides = dict(),
|
clip_adapter_overrides = dict(),
|
||||||
learned_variance = True,
|
learned_variance = True,
|
||||||
vb_loss_weight = 0.001,
|
vb_loss_weight = 0.001,
|
||||||
unconditional = False
|
unconditional = False,
|
||||||
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -1806,6 +1807,10 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip_denoised = clip_denoised
|
self.clip_denoised = clip_denoised
|
||||||
self.clip_x_start = clip_x_start
|
self.clip_x_start = clip_x_start
|
||||||
|
|
||||||
|
# normalize and unnormalize image functions
|
||||||
|
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||||
|
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
|
||||||
|
|
||||||
def get_unet(self, unet_number):
|
def get_unet(self, unet_number):
|
||||||
assert 0 < unet_number <= len(self.unets)
|
assert 0 < unet_number <= len(self.unets)
|
||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
@@ -1877,7 +1882,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
img = torch.randn(shape, device = device)
|
img = torch.randn(shape, device = device)
|
||||||
|
|
||||||
if not is_latent_diffusion:
|
if not is_latent_diffusion:
|
||||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||||
|
|
||||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||||
img = self.p_sample(
|
img = self.p_sample(
|
||||||
@@ -1894,7 +1899,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
clip_denoised = clip_denoised
|
clip_denoised = clip_denoised
|
||||||
)
|
)
|
||||||
|
|
||||||
unnormalize_img = unnormalize_zero_to_one(img)
|
unnormalize_img = self.unnormalize_img(img)
|
||||||
return unnormalize_img
|
return unnormalize_img
|
||||||
|
|
||||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
|
||||||
@@ -1903,8 +1908,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
|
|
||||||
if not is_latent_diffusion:
|
if not is_latent_diffusion:
|
||||||
x_start = normalize_neg_one_to_one(x_start)
|
x_start = self.normalize_img(x_start)
|
||||||
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
|
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
|
||||||
|
|
||||||
# get x_t
|
# get x_t
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user