mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
better naming
This commit is contained in:
@@ -483,7 +483,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
timesteps = 1000,
|
timesteps = 1000,
|
||||||
cond_drop_prob = 0.2,
|
cond_drop_prob = 0.2,
|
||||||
loss_type = "l1",
|
loss_type = "l1",
|
||||||
predict_x0 = True,
|
predict_x_start = True,
|
||||||
beta_schedule = "cosine",
|
beta_schedule = "cosine",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -497,7 +497,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
self.image_size = clip.image_size
|
self.image_size = clip.image_size
|
||||||
self.cond_drop_prob = cond_drop_prob
|
self.cond_drop_prob = cond_drop_prob
|
||||||
|
|
||||||
self.predict_x0 = predict_x0
|
self.predict_x_start = predict_x_start
|
||||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||||
|
|
||||||
if beta_schedule == "cosine":
|
if beta_schedule == "cosine":
|
||||||
@@ -586,14 +586,14 @@ class DiffusionPrior(nn.Module):
|
|||||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||||
pred = self.net(x, t, **text_cond)
|
pred = self.net(x, t, **text_cond)
|
||||||
|
|
||||||
if self.predict_x0:
|
if self.predict_x_start:
|
||||||
x_recon = pred
|
x_recon = pred
|
||||||
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
|
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
|
||||||
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
|
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
|
||||||
else:
|
else:
|
||||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||||
|
|
||||||
if clip_denoised and not self.predict_x0:
|
if clip_denoised and not self.predict_x_start:
|
||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
@@ -639,7 +639,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
**text_cond
|
**text_cond
|
||||||
)
|
)
|
||||||
|
|
||||||
to_predict = noise if not self.predict_x0 else image_embed
|
to_predict = noise if not self.predict_x_start else image_embed
|
||||||
|
|
||||||
if self.loss_type == 'l1':
|
if self.loss_type == 'l1':
|
||||||
loss = F.l1_loss(to_predict, x_recon)
|
loss = F.l1_loss(to_predict, x_recon)
|
||||||
@@ -1121,8 +1121,8 @@ class Decoder(nn.Module):
|
|||||||
cond_drop_prob = 0.2,
|
cond_drop_prob = 0.2,
|
||||||
loss_type = 'l1',
|
loss_type = 'l1',
|
||||||
beta_schedule = 'cosine',
|
beta_schedule = 'cosine',
|
||||||
predict_x0 = False,
|
predict_x_start = False,
|
||||||
predict_x0_for_latent_diffusion = False,
|
predict_x_start_for_latent_diffusion = False,
|
||||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||||
@@ -1173,7 +1173,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
# predict x0 config
|
# predict x0 config
|
||||||
|
|
||||||
self.predict_x0 = cast_tuple(predict_x0, len(unets)) if not predict_x0_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||||
|
|
||||||
# cascading ddpm related stuff
|
# cascading ddpm related stuff
|
||||||
|
|
||||||
@@ -1293,31 +1293,31 @@ class Decoder(nn.Module):
|
|||||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||||
|
|
||||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x0 = False, cond_scale = 1.):
|
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||||
|
|
||||||
if predict_x0:
|
if predict_x_start:
|
||||||
x_recon = pred
|
x_recon = pred
|
||||||
else:
|
else:
|
||||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||||
|
|
||||||
if clip_denoised and not predict_x0:
|
if clip_denoised and not predict_x_start:
|
||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x0 = False, clip_denoised = True, repeat_noise = False):
|
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x0 = predict_x0)
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = noise_like(x.shape, device, repeat_noise)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_loop(self, unet, shape, image_embed, predict_x0 = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||||
device = self.betas.device
|
device = self.betas.device
|
||||||
|
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
@@ -1332,7 +1332,7 @@ class Decoder(nn.Module):
|
|||||||
text_encodings = text_encodings,
|
text_encodings = text_encodings,
|
||||||
cond_scale = cond_scale,
|
cond_scale = cond_scale,
|
||||||
lowres_cond_img = lowres_cond_img,
|
lowres_cond_img = lowres_cond_img,
|
||||||
predict_x0 = predict_x0
|
predict_x_start = predict_x_start
|
||||||
)
|
)
|
||||||
|
|
||||||
return img
|
return img
|
||||||
@@ -1345,7 +1345,7 @@ class Decoder(nn.Module):
|
|||||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||||
)
|
)
|
||||||
|
|
||||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x0 = False, noise = None):
|
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||||
|
|
||||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||||
@@ -1359,7 +1359,7 @@ class Decoder(nn.Module):
|
|||||||
cond_drop_prob = self.cond_drop_prob
|
cond_drop_prob = self.cond_drop_prob
|
||||||
)
|
)
|
||||||
|
|
||||||
target = noise if not predict_x0 else x_start
|
target = noise if not predict_x_start else x_start
|
||||||
|
|
||||||
if self.loss_type == 'l1':
|
if self.loss_type == 'l1':
|
||||||
loss = F.l1_loss(target, x_recon)
|
loss = F.l1_loss(target, x_recon)
|
||||||
@@ -1381,7 +1381,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
img = None
|
img = None
|
||||||
|
|
||||||
for unet, vae, channel, image_size, predict_x0 in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x0)):
|
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||||
with self.one_unet_in_gpu(unet = unet):
|
with self.one_unet_in_gpu(unet = unet):
|
||||||
lowres_cond_img = None
|
lowres_cond_img = None
|
||||||
shape = (batch_size, channel, image_size, image_size)
|
shape = (batch_size, channel, image_size, image_size)
|
||||||
@@ -1401,7 +1401,7 @@ class Decoder(nn.Module):
|
|||||||
image_embed = image_embed,
|
image_embed = image_embed,
|
||||||
text_encodings = text_encodings,
|
text_encodings = text_encodings,
|
||||||
cond_scale = cond_scale,
|
cond_scale = cond_scale,
|
||||||
predict_x0 = predict_x0,
|
predict_x_start = predict_x_start,
|
||||||
lowres_cond_img = lowres_cond_img
|
lowres_cond_img = lowres_cond_img
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1425,7 +1425,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
target_image_size = self.image_sizes[unet_index]
|
target_image_size = self.image_sizes[unet_index]
|
||||||
vae = self.vaes[unet_index]
|
vae = self.vaes[unet_index]
|
||||||
predict_x0 = self.predict_x0[unet_index]
|
predict_x_start = self.predict_x_start[unet_index]
|
||||||
|
|
||||||
b, c, h, w, device, = *image.shape, image.device
|
b, c, h, w, device, = *image.shape, image.device
|
||||||
|
|
||||||
@@ -1449,7 +1449,7 @@ class Decoder(nn.Module):
|
|||||||
if exists(lowres_cond_img):
|
if exists(lowres_cond_img):
|
||||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||||
|
|
||||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x0 = predict_x0)
|
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||||
|
|
||||||
# main class
|
# main class
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user