better naming

Phil Wang
2022-04-25 07:44:33 -07:00
parent 863f4ef243
commit 8f2a0c7e00
2 changed files with 22 additions and 22 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -483,7 +483,7 @@ class DiffusionPrior(nn.Module):
         timesteps = 1000,
         cond_drop_prob = 0.2,
         loss_type = "l1",
-        predict_x0 = True,
+        predict_x_start = True,
         beta_schedule = "cosine",
     ):
         super().__init__()
@@ -497,7 +497,7 @@ class DiffusionPrior(nn.Module):
         self.image_size = clip.image_size
         self.cond_drop_prob = cond_drop_prob

-        self.predict_x0 = predict_x0
+        self.predict_x_start = predict_x_start
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

         if beta_schedule == "cosine":
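
A note on what this flag controls: with predict_x_start enabled, the network is trained to regress the clean image embedding x0 directly rather than the noise added to it; only the regression target changes, not the forward noising process. A minimal sketch of the distinction, with illustrative names (diffusion_loss, net, alphas_cumprod are assumptions for this example, not the repository's API):

    import torch
    import torch.nn.functional as F

    def diffusion_loss(net, x_start, t, alphas_cumprod, predict_x_start = False):
        # x_start: (batch, dim) clean embeddings; t: (batch,) integer timesteps
        noise = torch.randn_like(x_start)
        alpha_bar = alphas_cumprod[t].unsqueeze(-1)  # (batch, 1)
        # forward process: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
        x_noisy = alpha_bar.sqrt() * x_start + (1. - alpha_bar).sqrt() * noise
        pred = net(x_noisy, t)
        # the renamed flag only swaps the regression target
        target = x_start if predict_x_start else noise
        return F.l1_loss(pred, target)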
@@ -586,14 +586,14 @@ class DiffusionPrior(nn.Module):
     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
         pred = self.net(x, t, **text_cond)

-        if self.predict_x0:
+        if self.predict_x_start:
             x_recon = pred
             # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
             # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
             x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

-        if clip_denoised and not self.predict_x0:
+        if clip_denoised and not self.predict_x_start:
             x_recon.clamp_(-1., 1.)

         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
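
When the network predicts noise instead, the predict_start_from_noise branch above recovers the clean input by inverting the closed-form forward process: x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t). A sketch under the same illustrative names as the previous example:

    def predict_start_from_noise(x_t, t, noise, alphas_cumprod):
        # invert x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps for x_0
        alpha_bar = alphas_cumprod[t].unsqueeze(-1)
        return (x_t - (1. - alpha_bar).sqrt() * noise) / alpha_bar.sqrt()

With predict_x_start on, this inversion is skipped because the network output already is the estimate of x0; per the condition above, the clamp to [-1, 1] is likewise only applied in the noise-prediction case.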
@@ -639,7 +639,7 @@ class DiffusionPrior(nn.Module):
             **text_cond
         )

-        to_predict = noise if not self.predict_x0 else image_embed
+        to_predict = noise if not self.predict_x_start else image_embed

         if self.loss_type == 'l1':
             loss = F.l1_loss(to_predict, x_recon)
@@ -1121,8 +1121,8 @@ class Decoder(nn.Module):
         cond_drop_prob = 0.2,
         loss_type = 'l1',
         beta_schedule = 'cosine',
-        predict_x0 = False,
-        predict_x0_for_latent_diffusion = False,
+        predict_x_start = False,
+        predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
         lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
@@ -1173,7 +1173,7 @@ class Decoder(nn.Module):
         # predict x0 config

-        self.predict_x0 = cast_tuple(predict_x0, len(unets)) if not predict_x0_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
+        self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))

         # cascading ddpm related stuff
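
The cast_tuple call broadcasts a single flag to one entry per unet, so each stage of the cascade can independently predict x_start or noise; in latent-diffusion mode the per-stage flag is instead derived from whether that stage's VAE is a VQGanVAE. If cast_tuple follows the usual pattern in this codebase, it behaves roughly like the following (an assumption, not copied from the repository):

    def cast_tuple(val, length = 1):
        # broadcast a scalar to a fixed-length tuple; pass tuples through untouched
        return val if isinstance(val, tuple) else ((val,) * length)

    cast_tuple(False, 3)           # (False, False, False)
    cast_tuple((True, False), 2)   # (True, False)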
@@ -1293,31 +1293,31 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x0 = False, cond_scale = 1.):
+    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
         pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

-        if predict_x0:
+        if predict_x_start:
             x_recon = pred
         else:
             x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

-        if clip_denoised and not predict_x0:
+        if clip_denoised and not predict_x_start:
             x_recon.clamp_(-1., 1.)

         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x0 = False, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x0 = predict_x0)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x0 = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
         device = self.betas.device

         b = shape[0]
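
For reference, the p_sample step above draws x_{t-1} from a Gaussian with the posterior mean and variance, and the nonzero_mask suppresses the noise term at t == 0 so the final step returns the mean deterministically. A self-contained sketch of just that masking logic (names are illustrative):

    import torch

    def sample_step(model_mean, model_log_variance, t):
        # ancestral sampling: x_{t-1} = mu + sigma * z, with z zeroed out at t == 0
        noise = torch.randn_like(model_mean)
        b = model_mean.shape[0]
        nonzero_mask = (t != 0).float().reshape(b, *((1,) * (model_mean.ndim - 1)))
        # sigma = exp(0.5 * log(sigma^2))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise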
@@ -1332,7 +1332,7 @@ class Decoder(nn.Module):
                 text_encodings = text_encodings,
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
-                predict_x0 = predict_x0
+                predict_x_start = predict_x_start
             )

         return img
@@ -1345,7 +1345,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x0 = False, noise = None):
+    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -1359,7 +1359,7 @@ class Decoder(nn.Module):
             cond_drop_prob = self.cond_drop_prob
         )

-        target = noise if not predict_x0 else x_start
+        target = noise if not predict_x_start else x_start

         if self.loss_type == 'l1':
             loss = F.l1_loss(target, x_recon)
@@ -1381,7 +1381,7 @@ class Decoder(nn.Module):
         img = None

-        for unet, vae, channel, image_size, predict_x0 in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x0)):
+        for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):

             with self.one_unet_in_gpu(unet = unet):
                 lowres_cond_img = None
                 shape = (batch_size, channel, image_size, image_size)
@@ -1401,7 +1401,7 @@ class Decoder(nn.Module):
                     image_embed = image_embed,
                     text_encodings = text_encodings,
                     cond_scale = cond_scale,
-                    predict_x0 = predict_x0,
+                    predict_x_start = predict_x_start,
                     lowres_cond_img = lowres_cond_img
                 )

@@ -1425,7 +1425,7 @@ class Decoder(nn.Module):
         target_image_size = self.image_sizes[unet_index]
         vae = self.vaes[unet_index]
-        predict_x0 = self.predict_x0[unet_index]
+        predict_x_start = self.predict_x_start[unet_index]

         b, c, h, w, device, = *image.shape, image.device
@@ -1449,7 +1449,7 @@ class Decoder(nn.Module):
         if exists(lowres_cond_img):
             lowres_cond_img = vae.encode(lowres_cond_img)

-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x0 = predict_x0)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)

 # main class

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.42',
+  version = '0.0.43',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',