mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19

complete ddim integration of diffusion prior as well as decoder for each unet, feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157
@@ -505,6 +505,12 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )

+    def predict_noise_from_start(self, x_t, t, x0):
+        return (
+            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+        )
+
     def p2_reweigh_loss(self, loss, times):
         if not self.has_p2_loss_reweighting:
             return loss
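The new `predict_noise_from_start` is the forward-process identity x_t = sqrt(a̅_t) * x0 + sqrt(1 - a̅_t) * eps solved for eps instead of x0, which makes it the exact inverse of the existing `predict_start_from_noise`. A minimal self-contained sketch of that algebra, using a toy linear schedule and standalone functions rather than the repo's `NoiseScheduler`:

```python
import torch

# toy linear beta schedule, just to exercise the algebra
timesteps = 10
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)

sqrt_recip_alphas_cumprod = torch.rsqrt(alphas_cumprod)          # 1 / sqrt(a̅_t)
sqrt_recipm1_alphas_cumprod = (1. / alphas_cumprod - 1.).sqrt()  # sqrt(1 / a̅_t - 1)

def predict_start_from_noise(x_t, t, noise):
    # x0 = x_t / sqrt(a̅_t) - sqrt(1 / a̅_t - 1) * noise
    return sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise

def predict_noise_from_start(x_t, t, x0):
    # same identity, solved for the noise instead of x0
    return (sqrt_recip_alphas_cumprod[t] * x_t - x0) / sqrt_recipm1_alphas_cumprod[t]

x0, noise, t = torch.randn(4), torch.randn(4), 7
x_t = alphas_cumprod[t].sqrt() * x0 + (1. - alphas_cumprod[t]).sqrt() * noise

assert torch.allclose(predict_start_from_noise(x_t, t, noise), x0, atol = 1e-5)
assert torch.allclose(predict_noise_from_start(x_t, t, x0), noise, atol = 1e-5)
```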
@@ -911,6 +917,7 @@ class DiffusionPrior(nn.Module):
         image_size = None,
         image_channels = 3,
         timesteps = 1000,
+        sample_timesteps = None,
         cond_drop_prob = 0.,
         loss_type = "l2",
         predict_x_start = True,
@@ -924,6 +931,8 @@ class DiffusionPrior(nn.Module):
     ):
         super().__init__()

+        self.sample_timesteps = sample_timesteps
+
         self.noise_scheduler = NoiseScheduler(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
@@ -978,8 +987,6 @@ class DiffusionPrior(nn.Module):

         if self.predict_x_start:
             x_recon = pred
-            # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
-            # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
             x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

@@ -1002,21 +1009,75 @@ class DiffusionPrior(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
-        device = self.device
-
-        b = shape[0]
-        image_embed = torch.randn(shape, device=device)
+    def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
+        batch, device = shape[0], self.device
+        image_embed = torch.randn(shape, device = device)

         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale

         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
-            times = torch.full((b,), i, device = device, dtype = torch.long)
+            times = torch.full((batch,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

         return image_embed

+    @torch.no_grad()
+    def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
+        batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
+
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        image_embed = torch.randn(shape, device = device)
+
+        if self.init_image_embed_l2norm:
+            image_embed = l2norm(image_embed) * self.image_embed_scale
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = alphas[time]
+            alpha_next = alphas[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+
+            if self.predict_x_start:
+                x_start = pred
+                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
+            else:
+                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
+                pred_noise = pred
+
+            if not self.predict_x_start:
+                x_start.clamp_(-1., 1.)
+
+            if self.predict_x_start and self.sampling_clamp_l2norm:
+                x_start = l2norm(x_start) * self.image_embed_scale
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+            new_noise = torch.randn_like(image_embed)
+
+            image_embed = x_start * alpha_next.sqrt() + \
+                c1 * new_noise + \
+                c2 * pred_noise
+
+        return image_embed
+
+    @torch.no_grad()
+    def p_sample_loop(self, *args, timesteps = None, **kwargs):
+        timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
+        assert timesteps <= self.noise_scheduler.num_timesteps
+        is_ddim = timesteps < self.noise_scheduler.num_timesteps
+
+        if not is_ddim:
+            return self.p_sample_loop_ddpm(*args, **kwargs)
+
+        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))

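Both new DDIM loops (this one and the decoder's below) apply the same transition rule from Song et al., https://arxiv.org/abs/2010.02502: `c1` is the noise scale sigma_t from eq. (16) of that paper, `c2` keeps the total variance consistent, and `eta = 1.` recovers DDPM-level stochasticity while `eta = 0.` makes sampling deterministic. A sketch of a single step, factored out for clarity (the function name and signature are illustrative, not part of the repo):

```python
import torch

def ddim_step(x, x_start, pred_noise, alpha, alpha_next, eta = 1.):
    """one DDIM transition, written against cumulative alphas a̅_t and a̅_t'
    exactly as in the loop above; alpha/alpha_next are 0-dim tensors"""
    # sigma_t from the DDIM paper; vanishes entirely when eta = 0.
    c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    # coefficient on the predicted noise so the variances still sum correctly
    c2 = ((1 - alpha_next) - c1 ** 2).sqrt()
    return x_start * alpha_next.sqrt() + c1 * torch.randn_like(x) + c2 * pred_noise
```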
@@ -1051,7 +1112,15 @@ class DiffusionPrior(nn.Module):

     @torch.no_grad()
     @eval_decorator
-    def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
+    def sample(
+        self,
+        text,
+        num_samples_per_batch = 2,
+        cond_scale = 1.,
+        timesteps = None
+    ):
+        timesteps = default(timesteps, self.sample_timesteps)
+
         # in the paper, what they did was
         # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
         text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
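With the widened signature, callers can opt into DDIM per call rather than only at construction time. A hypothetical invocation, assuming a constructed `prior` and tokenized `text_tokens` (both names are placeholders):

```python
# hypothetical usage - `prior` and `text_tokens` are assumed to already exist
image_embed = prior.sample(
    text_tokens,                 # (batch, seq_len) token ids
    num_samples_per_batch = 2,   # paper: sample 2, keep the most CLIP-similar
    cond_scale = 1.,
    timesteps = 64               # < training timesteps, so the DDIM path is taken
)
```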
@@ -1066,7 +1135,7 @@ class DiffusionPrior(nn.Module):
         if self.condition_on_text_encodings:
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

-        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)

         # retrieve original unscaled image embed

@@ -1853,6 +1922,7 @@ class Decoder(nn.Module):
         channels = 3,
         vae = tuple(),
         timesteps = 1000,
+        sample_timesteps = None,
         image_cond_drop_prob = 0.1,
         text_cond_drop_prob = 0.5,
         loss_type = 'l2',
@@ -1876,7 +1946,8 @@ class Decoder(nn.Module):
         use_dynamic_thres = False, # from the Imagen paper
         dynamic_thres_percentile = 0.9,
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
-        p2_loss_weight_k = 1
+        p2_loss_weight_k = 1,
+        ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling
     ):
         super().__init__()

@@ -1956,6 +2027,11 @@ class Decoder(nn.Module):
             self.unets.append(one_unet)
             self.vaes.append(one_vae.copy_for_eval())

+        # sampling timesteps, defaults to non-ddim with full timesteps sampling
+
+        self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
+        self.ddim_sampling_eta = ddim_sampling_eta
+
         # create noise schedulers per unet

         if not exists(beta_schedule):
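`cast_tuple` lets `sample_timesteps` be given once and broadcast to every unet, or supplied as a per-unet tuple. A re-implementation of the helper's behavior for illustration only (the repo ships its own `cast_tuple`):

```python
def cast_tuple(val, length = 1):
    # tuples pass through unchanged, scalars are repeated to the target length
    return val if isinstance(val, tuple) else (val,) * length

assert cast_tuple(None, 3) == (None, None, None)      # default: full DDPM sampling everywhere
assert cast_tuple(64, 3) == (64, 64, 64)              # same DDIM step count for every unet
assert cast_tuple((250, 64, 64), 3) == (250, 64, 64)  # per-unet override
```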
@@ -1966,7 +2042,9 @@ class Decoder(nn.Module):

         self.noise_schedulers = nn.ModuleList([])

-        for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
+        for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
+            assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
+
             noise_scheduler = NoiseScheduler(
                 beta_schedule = unet_beta_schedule,
                 timesteps = timesteps,
@@ -2067,6 +2145,26 @@ class Decoder(nn.Module):
         for unet, device in zip(self.unets, devices):
             unet.to(device)

+    def dynamic_threshold(self, x):
+        """ proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
+
+        # s is the threshold amount
+        # static thresholding would just be s = 1
+        s = 1.
+        if self.use_dynamic_thres:
+            s = torch.quantile(
+                rearrange(x, 'b ... -> b (...)').abs(),
+                self.dynamic_thres_percentile,
+                dim = -1
+            )
+
+            s.clamp_(min = 1.)
+            s = s.view(-1, *((1,) * (x.ndim - 1)))
+
+        # clip by threshold, depending on whether static or dynamic
+        x = x.clamp(-s, s) / s
+        return x
+
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

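Factoring the clamping into `dynamic_threshold` lets both the DDPM path (`p_mean_variance`) and the new DDIM loop reuse Imagen-style dynamic thresholding: pick a per-sample threshold `s` at a quantile of `|x|` (floored at 1), clip to `[-s, s]`, then rescale back into `[-1, 1]`. A standalone sketch of the same logic, outside the `Decoder` class:

```python
import torch
from einops import rearrange

def dynamic_threshold(x, percentile = 0.9):
    # per-sample threshold s at the given quantile of |x|, never below 1
    s = torch.quantile(rearrange(x, 'b ... -> b (...)').abs(), percentile, dim = -1)
    s.clamp_(min = 1.)
    s = s.view(-1, *((1,) * (x.ndim - 1)))
    return x.clamp(-s, s) / s  # clip to [-s, s], then rescale back into [-1, 1]

x = torch.randn(2, 3, 8, 8) * 3.   # a prediction that overshoots [-1, 1]
out = dynamic_threshold(x)
assert out.abs().max() <= 1. + 1e-6
```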
@@ -2081,21 +2179,7 @@ class Decoder(nn.Module):
             x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

         if clip_denoised:
-            # s is the threshold amount
-            # static thresholding would just be s = 1
-            s = 1.
-            if self.use_dynamic_thres:
-                s = torch.quantile(
-                    rearrange(x_recon, 'b ... -> b (...)').abs(),
-                    self.dynamic_thres_percentile,
-                    dim = -1
-                )
-
-                s.clamp_(min = 1.)
-                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
-
-            # clip by threshold, depending on whether static or dynamic
-            x_recon = x_recon.clamp(-s, s) / s
+            x_recon = self.dynamic_threshold(x_recon)

         model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)

@@ -2125,7 +2209,7 @@ class Decoder(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
         device = self.device

         b = shape[0]
@@ -2153,6 +2237,61 @@ class Decoder(nn.Module):
         unnormalize_img = self.unnormalize_img(img)
         return unnormalize_img

+    @torch.no_grad()
+    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+        batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
+
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        img = torch.randn(shape, device = device)
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = alphas[time]
+            alpha_next = alphas[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+
+            if learned_variance:
+                pred, _ = pred.chunk(2, dim = 1)
+
+            if predict_x_start:
+                x_start = pred
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+            else:
+                x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+                pred_noise = pred
+
+            if clip_denoised:
+                x_start = self.dynamic_threshold(x_start)
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+
+            img = x_start * alpha_next.sqrt() + \
+                c1 * torch.randn_like(img) + \
+                c2 * pred_noise
+
+        img = self.unnormalize_img(img)
+        return img
+
+    @torch.no_grad()
+    def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
+        num_timesteps = noise_scheduler.num_timesteps
+
+        timesteps = default(timesteps, num_timesteps)
+        assert timesteps <= num_timesteps
+        is_ddim = timesteps < num_timesteps
+
+        if not is_ddim:
+            return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
+
+        return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
+
     def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))

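The strided schedule both loops build is worth unpacking: `linspace` over `timesteps + 2` points with the last endpoint dropped yields `timesteps + 1` boundary times, i.e. exactly `timesteps` DDIM updates. A toy trace with assumed values:

```python
import torch

total_timesteps, timesteps = 1000, 4  # toy numbers for illustration

times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))

print(times)       # [800, 600, 400, 200, 0]
print(time_pairs)  # [(800, 600), (600, 400), (400, 200), (200, 0)]
```

The dispatching `p_sample_loop` then picks the path from the step count alone: requesting exactly the training timestep count falls back to ancestral DDPM sampling, anything fewer routes through DDIM.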
@@ -2253,7 +2392,7 @@ class Decoder(nn.Module):
         img = None
         is_cuda = next(self.parameters()).is_cuda

-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):

             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()

@@ -2282,7 +2421,8 @@ class Decoder(nn.Module):
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
                     is_latent_diffusion = is_latent_diffusion,
-                    noise_scheduler = noise_scheduler
+                    noise_scheduler = noise_scheduler,
+                    timesteps = sample_timesteps
                 )

                 img = vae.decode(img)
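Putting the per-unet plumbing together, a hypothetical decoder configuration; `unet1`, `unet2`, and `image_embed` are assumed to already exist, and the sizes and step counts are illustrative, not recommendations:

```python
# hypothetical usage - unets and image_embed are assumed to exist
decoder = Decoder(
    unet = (unet1, unet2),         # assumed pre-built unets
    image_sizes = (64, 256),
    timesteps = 1000,
    sample_timesteps = (250, 27),  # per-unet DDIM steps; None means full DDPM sampling
    ddim_sampling_eta = 0.         # deterministic DDIM
)

images = decoder.sample(image_embed = image_embed, cond_scale = 1.)
```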