quick fix for @marunine

This commit is contained in:
Phil Wang
2022-05-18 20:22:52 -07:00
parent bb86ab2404
commit db0642c4cd
2 changed files with 11 additions and 6 deletions

View File

@@ -1697,7 +1697,8 @@ class Decoder(BaseGaussianDiffusion):
clip_adapter_overrides = dict(), clip_adapter_overrides = dict(),
learned_variance = True, learned_variance = True,
vb_loss_weight = 0.001, vb_loss_weight = 0.001,
unconditional = False unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
): ):
super().__init__( super().__init__(
beta_schedule = beta_schedule, beta_schedule = beta_schedule,
@@ -1806,6 +1807,10 @@ class Decoder(BaseGaussianDiffusion):
self.clip_denoised = clip_denoised self.clip_denoised = clip_denoised
self.clip_x_start = clip_x_start self.clip_x_start = clip_x_start
# normalize and unnormalize image functions
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
def get_unet(self, unet_number): def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets) assert 0 < unet_number <= len(self.unets)
index = unet_number - 1 index = unet_number - 1
@@ -1877,7 +1882,7 @@ class Decoder(BaseGaussianDiffusion):
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
if not is_latent_diffusion: if not is_latent_diffusion:
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img) lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample( img = self.p_sample(
@@ -1894,7 +1899,7 @@ class Decoder(BaseGaussianDiffusion):
clip_denoised = clip_denoised clip_denoised = clip_denoised
) )
unnormalize_img = unnormalize_zero_to_one(img) unnormalize_img = self.unnormalize_img(img)
return unnormalize_img return unnormalize_img
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False): def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
@@ -1903,8 +1908,8 @@ class Decoder(BaseGaussianDiffusion):
# normalize to [-1, 1] # normalize to [-1, 1]
if not is_latent_diffusion: if not is_latent_diffusion:
x_start = normalize_neg_one_to_one(x_start) x_start = self.normalize_img(x_start)
lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img) lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
# get x_t # get x_t

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.3.2', version = '0.3.3',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',