complete ddim integration of the diffusion prior as well as the decoder for each unet; feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157

Phil Wang
2022-07-09 17:25:34 -07:00
parent 68de937aac
commit b7e22f7da0
5 changed files with 186 additions and 34 deletions

README.md

@@ -583,6 +583,7 @@ unet1 = Unet(
     cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8),
+    text_embed_dim = 512,
     cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
 ).cuda()
@@ -598,7 +599,8 @@ decoder = Decoder(
     unet = (unet1, unet2),
     image_sizes = (128, 256),
     clip = clip,
-    timesteps = 100,
+    timesteps = 1000,
+    sample_timesteps = (250, 27),
     image_cond_drop_prob = 0.1,
     text_cond_drop_prob = 0.5
 ).cuda()
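Hedged usage sketch, not part of the diff: with the configuration above the decoder still trains on the full 1000 timesteps, while sampling runs 250 DDIM steps for the first unet and 27 for the second. The image_embed and text tensors are assumed to be prepared as elsewhere in the README.

images = decoder.sample(image_embed = image_embed, text = text)  # DDIM sampling is used because sample_timesteps < timesteps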

dalle2_pytorch/dalle2_pytorch.py

@@ -505,6 +505,12 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )

+    def predict_noise_from_start(self, x_t, t, x0):
+        return (
+            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+        )
+
     def p2_reweigh_loss(self, loss, times):
         if not self.has_p2_loss_reweighting:
             return loss
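A standalone sanity check (not part of the diff) of the inversion added above, assuming the standard DDPM identity x0 = sqrt(1/acum) * x_t - sqrt(1/acum - 1) * noise that predict_start_from_noise implements; all values below are illustrative.

import torch

alphas_cumprod = torch.tensor(0.7)                                       # alpha-bar at some timestep (illustrative)
sqrt_recip = (1. / alphas_cumprod).sqrt()
sqrt_recipm1 = (1. / alphas_cumprod - 1.).sqrt()

x0, noise = torch.randn(4), torch.randn(4)
x_t = alphas_cumprod.sqrt() * x0 + (1. - alphas_cumprod).sqrt() * noise  # forward q(x_t | x_0)

x0_recon = sqrt_recip * x_t - sqrt_recipm1 * noise                       # predict_start_from_noise
noise_recon = (sqrt_recip * x_t - x0_recon) / sqrt_recipm1               # predict_noise_from_start
assert torch.allclose(x0_recon, x0, atol = 1e-5) and torch.allclose(noise_recon, noise, atol = 1e-5)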
@@ -911,6 +917,7 @@ class DiffusionPrior(nn.Module):
         image_size = None,
         image_channels = 3,
         timesteps = 1000,
+        sample_timesteps = None,
         cond_drop_prob = 0.,
         loss_type = "l2",
         predict_x_start = True,
@@ -924,6 +931,8 @@ class DiffusionPrior(nn.Module):
     ):
         super().__init__()

+        self.sample_timesteps = sample_timesteps
+
         self.noise_scheduler = NoiseScheduler(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
@@ -978,8 +987,6 @@ class DiffusionPrior(nn.Module):
         if self.predict_x_start:
             x_recon = pred
-            # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
-            # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
             x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1002,21 +1009,75 @@ class DiffusionPrior(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
-        device = self.device
-
-        b = shape[0]
-        image_embed = torch.randn(shape, device=device)
+    def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
+        batch, device = shape[0], self.device
+
+        image_embed = torch.randn(shape, device = device)

         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale

         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
-            times = torch.full((b,), i, device = device, dtype = torch.long)
+            times = torch.full((batch,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

         return image_embed

+    @torch.no_grad()
+    def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
+        batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
+
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        image_embed = torch.randn(shape, device = device)
+
+        if self.init_image_embed_l2norm:
+            image_embed = l2norm(image_embed) * self.image_embed_scale
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = alphas[time]
+            alpha_next = alphas[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+
+            if self.predict_x_start:
+                x_start = pred
+                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
+            else:
+                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
+                pred_noise = pred
+
+            if not self.predict_x_start:
+                x_start.clamp_(-1., 1.)
+
+            if self.predict_x_start and self.sampling_clamp_l2norm:
+                x_start = l2norm(x_start) * self.image_embed_scale
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+
+            new_noise = torch.randn_like(image_embed)
+
+            image_embed = x_start * alpha_next.sqrt() + \
+                          c1 * new_noise + \
+                          c2 * pred_noise
+
+        return image_embed
+
+    @torch.no_grad()
+    def p_sample_loop(self, *args, timesteps = None, **kwargs):
+        timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
+        assert timesteps <= self.noise_scheduler.num_timesteps
+
+        is_ddim = timesteps < self.noise_scheduler.num_timesteps
+
+        if not is_ddim:
+            return self.p_sample_loop_ddpm(*args, **kwargs)
+
+        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
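A minimal standalone sketch (not part of the diff) of how the linspace-plus-pairing above produces the (time, time_next) pairs the DDIM loop iterates over; total_timesteps = 1000 and timesteps = 4 are illustrative values only.

import torch

total_timesteps, timesteps = 1000, 4
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]   # tensor([0., 200., 400., 600., 800.])
times = list(reversed(times.int().tolist()))                              # [800, 600, 400, 200, 0]
time_pairs = list(zip(times[:-1], times[1:]))                             # [(800, 600), (600, 400), (400, 200), (200, 0)]
print(time_pairs)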
@@ -1051,7 +1112,15 @@ class DiffusionPrior(nn.Module):
     @torch.no_grad()
     @eval_decorator
-    def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
+    def sample(
+        self,
+        text,
+        num_samples_per_batch = 2,
+        cond_scale = 1.,
+        timesteps = None
+    ):
+        timesteps = default(timesteps, self.sample_timesteps)
+
         # in the paper, what they did was
         # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
         text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1066,7 +1135,7 @@ class DiffusionPrior(nn.Module):
         if self.condition_on_text_encodings:
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

-        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)

         # retrieve original unscaled image embed
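Hedged usage sketch (not part of the diff): with the new argument threaded through, a caller can request DDIM sampling from the prior directly; diffusion_prior and the tokenized text tensor are assumed to exist as in the README training example, and 64 steps is an arbitrary choice.

image_embed = diffusion_prior.sample(text, num_samples_per_batch = 2, cond_scale = 1., timesteps = 64)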
@@ -1853,6 +1922,7 @@ class Decoder(nn.Module):
         channels = 3,
         vae = tuple(),
         timesteps = 1000,
+        sample_timesteps = None,
         image_cond_drop_prob = 0.1,
         text_cond_drop_prob = 0.5,
         loss_type = 'l2',
@@ -1876,7 +1946,8 @@ class Decoder(nn.Module):
         use_dynamic_thres = False,            # from the Imagen paper
         dynamic_thres_percentile = 0.9,
         p2_loss_weight_gamma = 0.,            # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
-        p2_loss_weight_k = 1
+        p2_loss_weight_k = 1,
+        ddim_sampling_eta = 1.                # can be set to 0. for deterministic sampling afaict
     ):
         super().__init__()
@@ -1956,6 +2027,11 @@ class Decoder(nn.Module):
             self.unets.append(one_unet)
             self.vaes.append(one_vae.copy_for_eval())

+        # sampling timesteps, defaults to non-ddim with full timesteps sampling
+
+        self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
+        self.ddim_sampling_eta = ddim_sampling_eta
+
         # create noise schedulers per unet

         if not exists(beta_schedule):
@@ -1966,7 +2042,9 @@ class Decoder(nn.Module):
         self.noise_schedulers = nn.ModuleList([])

-        for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
+        for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
+            assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
+
             noise_scheduler = NoiseScheduler(
                 beta_schedule = unet_beta_schedule,
                 timesteps = timesteps,
@@ -2067,6 +2145,26 @@ class Decoder(nn.Module):
         for unet, device in zip(self.unets, devices):
             unet.to(device)

+    def dynamic_threshold(self, x):
+        """ proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
+
+        # s is the threshold amount
+        # static thresholding would just be s = 1
+        s = 1.
+        if self.use_dynamic_thres:
+            s = torch.quantile(
+                rearrange(x, 'b ... -> b (...)').abs(),
+                self.dynamic_thres_percentile,
+                dim = -1
+            )
+
+            s.clamp_(min = 1.)
+            s = s.view(-1, *((1,) * (x.ndim - 1)))
+
+        # clip by threshold, depending on whether static or dynamic
+        x = x.clamp(-s, s) / s
+        return x
+
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
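A standalone illustration (not part of the diff) of what the dynamic thresholding above does, assuming the default percentile of 0.9; the toy tensor is arbitrary.

import torch
from einops import rearrange

x = torch.tensor([[[-3.0, 0.5, 2.0, -0.2]]])                            # toy (batch, channel, n) activations
s = torch.quantile(rearrange(x, 'b ... -> b (...)').abs(), 0.9, dim = -1)
s = s.clamp(min = 1.).view(-1, *((1,) * (x.ndim - 1)))                  # per-sample threshold, never below 1
x_thresholded = x.clamp(-s, s) / s                                      # clipped, then rescaled back into [-1, 1]
print(s.squeeze().item(), x_thresholded)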
@@ -2081,21 +2179,7 @@ class Decoder(nn.Module):
             x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

         if clip_denoised:
-            # s is the threshold amount
-            # static thresholding would just be s = 1
-            s = 1.
-            if self.use_dynamic_thres:
-                s = torch.quantile(
-                    rearrange(x_recon, 'b ... -> b (...)').abs(),
-                    self.dynamic_thres_percentile,
-                    dim = -1
-                )
-
-                s.clamp_(min = 1.)
-                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
-
-            # clip by threshold, depending on whether static or dynamic
-            x_recon = x_recon.clamp(-s, s) / s
+            x_recon = self.dynamic_threshold(x_recon)

         model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -2125,7 +2209,7 @@ class Decoder(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
         device = self.device
         b = shape[0]
@@ -2153,6 +2237,61 @@ class Decoder(nn.Module):
             unnormalize_img = self.unnormalize_img(img)
         return unnormalize_img

+    @torch.no_grad()
+    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+        batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
+
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        img = torch.randn(shape, device = device)
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = alphas[time]
+            alpha_next = alphas[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+
+            if learned_variance:
+                pred, _ = pred.chunk(2, dim = 1)
+
+            if predict_x_start:
+                x_start = pred
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+            else:
+                x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+                pred_noise = pred
+
+            if clip_denoised:
+                x_start = self.dynamic_threshold(x_start)
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+
+            img = x_start * alpha_next.sqrt() + \
+                  c1 * torch.randn_like(img) + \
+                  c2 * pred_noise
+
+        img = self.unnormalize_img(img)
+        return img
+
+    @torch.no_grad()
+    def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
+        num_timesteps = noise_scheduler.num_timesteps
+
+        timesteps = default(timesteps, num_timesteps)
+        assert timesteps <= num_timesteps
+
+        is_ddim = timesteps < num_timesteps
+
+        if not is_ddim:
+            return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
+
+        return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
+
     def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
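A minimal standalone sketch (not part of the diff) of how ddim_sampling_eta controls the stochastic term in the update above; the alpha values are illustrative stand-ins for entries of noise_scheduler.alphas_cumprod_prev.

import torch

alpha, alpha_next = torch.tensor(0.5), torch.tensor(0.6)   # alpha_next > alpha since time_next < time
for eta in (0., 1.):
    c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
    c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
    # eta = 0. makes c1 = 0, so the random-noise term drops out and sampling is deterministic
    print(eta, round(c1.item(), 4), round(c2.item(), 4))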
@@ -2253,7 +2392,7 @@ class Decoder(nn.Module):
             img = None
             is_cuda = next(self.parameters()).is_cuda

-            for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
+            for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):

                 context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2282,7 +2421,8 @@ class Decoder(nn.Module):
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
                     is_latent_diffusion = is_latent_diffusion,
-                    noise_scheduler = noise_scheduler
+                    noise_scheduler = noise_scheduler,
+                    timesteps = sample_timesteps
                 )

                 img = vae.decode(img)

dalle2_pytorch/train_configs.py

@@ -154,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
     image_size: int
     image_channels: int = 3
     timesteps: int = 1000
+    sample_timesteps: Optional[int] = None
     cond_drop_prob: float = 0.
     loss_type: str = 'l2'
     predict_x_start: bool = True
@@ -233,6 +234,7 @@ class DecoderConfig(BaseModel):
     clip: Optional[AdapterConfig]   # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
+    sample_timesteps: Optional[SingularOrIterable(int)] = None
     loss_type: str = 'l2'
     beta_schedule: ListOrTuple(str) = 'cosine'
     learned_variance: bool = True
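Hedged sketch (not part of the diff) of the values the new DecoderConfig field is meant to accept, assuming SingularOrIterable(int) validates either a single int or one int per unet; surrounding required config fields are omitted.

sample_timesteps_examples = [
    None,        # default: full DDPM sampling for every unet
    250,         # the same DDIM step count applied to every unet
    [250, 27],   # per-unet DDIM step counts, matching the README example
]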

dalle2_pytorch/trainers.py

@@ -536,11 +536,19 @@ class DecoderTrainer(nn.Module):
             assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"

             clip = decoder.clip
             clip.to(precision_type)

-        decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
+        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+
+        self.decoder = decoder
+
+        # prepare dataloaders
+
+        train_loader = val_loader = None
+        if exists(dataloaders):
+            train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])

         self.train_loader = train_loader
         self.val_loader = val_loader
-        self.decoder = decoder

         # store optimizers
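A minimal standalone sketch (not part of the diff) of the prepare pattern above, assuming huggingface accelerate; the dummy model and optimizer are placeholders, and dataloaders is left as None to show that it is now optional.

import torch
from accelerate import Accelerator

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
dataloaders = None   # would be {"train": ..., "val": ...} when the built-in training loop is used

accelerator = Accelerator()
model, optimizer = accelerator.prepare(model, optimizer)    # model and optimizers are always prepared

train_loader = val_loader = None
if dataloaders is not None:                                 # dataloaders only prepared when provided
    train_loader, val_loader = accelerator.prepare(dataloaders["train"], dataloaders["val"])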

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '0.18.0'
+__version__ = '0.19.1'