Compare commits

..

1 Commits
1.8.3 ... 1.6.2

Author SHA1 Message Date
Phil Wang
301a97197f fix self conditioning shape in diffusion prior 2022-08-12 12:29:25 -07:00
4 changed files with 41 additions and 121 deletions

View File

@@ -49,7 +49,6 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
- <a href="https://github.com/arogozhnikov">Alex</a> for <a href="https://github.com/arogozhnikov/einops">einops</a>, indispensable tool for tensor manipulation
... and many others. Thank you! 🙏
@@ -1265,24 +1264,4 @@ For detailed information on training the diffusion prior, please refer to the [d
}
```
```bibtex
@article{Qiao2019WeightS,
title = {Weight Standardization},
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
journal = {ArXiv},
year = {2019},
volume = {abs/1903.10520}
}
```
```bibtex
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -38,8 +38,6 @@ from coca_pytorch import CoCa
NAT = 1. / math.log(2.)
UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
# helper functions
def exists(val):
@@ -250,13 +248,9 @@ class XClipAdapter(BaseClipAdapter):
text = text[..., :self.max_text_len]
text_mask = text != 0
encoder_output = self.clip.text_transformer(text)
text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output.ndim == 3 else (encoder_output, None)
text_cls, text_encodings = encoder_output[:, 0], encoder_output[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
if exists(text_encodings):
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
return EmbeddedText(l2norm(text_embed), text_encodings)
@torch.no_grad()
@@ -1283,12 +1277,9 @@ class DiffusionPrior(nn.Module):
is_ddim = timesteps < self.noise_scheduler.num_timesteps
if not is_ddim:
normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
else:
normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
return self.p_sample_loop_ddpm(*args, **kwargs)
image_embed = normalized_image_embed / self.image_embed_scale
return image_embed
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -1296,7 +1287,7 @@ class DiffusionPrior(nn.Module):
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
self_cond = None
if self.net.self_cond and random.random() < 0.5:
if self.net.self_cond and random.random() < 1.5:
with torch.no_grad():
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
@@ -1357,6 +1348,8 @@ class DiffusionPrior(nn.Module):
# retrieve original unscaled image embed
image_embeds /= self.image_embed_scale
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
@@ -1455,26 +1448,6 @@ def Downsample(dim, *, dim_out = None):
dim_out = default(dim_out, dim)
return nn.Conv2d(dim, dim_out, 4, 2, 1)
class WeightStandardizedConv2d(nn.Conv2d):
"""
https://arxiv.org/abs/1903.10520
weight standardization purportedly works synergistically with group normalization
"""
def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
weight = self.weight
flattened_weights = rearrange(weight, 'o ... -> o (...)')
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
var = torch.var(flattened_weights, dim = -1, unbiased = False)
var = rearrange(var, 'o -> o 1 1 1')
weight = (weight - mean) * (var + eps).rsqrt()
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -1493,13 +1466,10 @@ class Block(nn.Module):
self,
dim,
dim_out,
groups = 8,
weight_standardization = False
groups = 8
):
super().__init__()
conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
self.project = conv_klass(dim, dim_out, 3, padding = 1)
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
@@ -1523,7 +1493,6 @@ class ResnetBlock(nn.Module):
cond_dim = None,
time_cond_dim = None,
groups = 8,
weight_standardization = False,
cosine_sim_cross_attn = False
):
super().__init__()
@@ -1549,8 +1518,8 @@ class ResnetBlock(nn.Module):
)
)
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb = None, cond = None):
@@ -1775,7 +1744,6 @@ class Unet(nn.Module):
init_dim = None,
init_conv_kernel_size = 7,
resnet_groups = 8,
resnet_weight_standardization = False,
num_resnet_blocks = 2,
init_cross_embed = True,
init_cross_embed_kernel_sizes = (3, 7, 15),
@@ -1923,7 +1891,7 @@ class Unet(nn.Module):
# prepare resnet klass
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
# give memory efficient unet an initial resnet block
@@ -2616,14 +2584,6 @@ class Decoder(nn.Module):
index = unet_number - 1
return self.unets[index]
def parse_unet_output(self, learned_variance, output):
var_interp_frac_unnormalized = None
if learned_variance:
output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
return UnetOutput(output, var_interp_frac_unnormalized)
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet)
@@ -2665,9 +2625,10 @@ class Decoder(nn.Module):
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
if learned_variance:
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
if predict_x_start:
x_start = pred
@@ -2850,9 +2811,10 @@ class Decoder(nn.Module):
self_cond = x_start if unet.self_cond else None
unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
pred, _ = self.parse_unet_output(learned_variance, unet_output)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
if predict_x_start:
x_start = pred
@@ -2924,13 +2886,16 @@ class Decoder(nn.Module):
if unet.self_cond and random.random() < 0.5:
with torch.no_grad():
unet_output = unet(x_noisy, times, **unet_kwargs)
self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
self_cond = unet(x_noisy, times, **unet_kwargs)
if learned_variance:
self_cond, _ = self_cond.chunk(2, dim = 1)
self_cond = self_cond.detach()
# forward to get model prediction
unet_output = unet(
model_output = unet(
x_noisy,
times,
**unet_kwargs,
@@ -2939,7 +2904,10 @@ class Decoder(nn.Module):
text_cond_drop_prob = self.text_cond_drop_prob,
)
pred, _ = self.parse_unet_output(learned_variance, unet_output)
if learned_variance:
pred, _ = model_output.chunk(2, dim = 1)
else:
pred = model_output
target = noise if not predict_x_start else x_start
@@ -2962,7 +2930,7 @@ class Decoder(nn.Module):
# if learning the variance, also include the extra weight kl loss
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
# kl loss with detached model predicted mean, for stability reasons as in paper

View File

@@ -9,7 +9,7 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
@@ -181,8 +181,7 @@ class DiffusionPriorTrainer(nn.Module):
eps = 1e-6,
max_grad_norm = None,
group_wd_params = True,
warmup_steps = None,
cosine_decay_max_steps = None,
warmup_steps = 1,
**kwargs
):
super().__init__()
@@ -234,11 +233,8 @@ class DiffusionPriorTrainer(nn.Module):
**self.optim_kwargs,
**kwargs
)
if exists(cosine_decay_max_steps):
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
else:
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
@@ -275,7 +271,6 @@ class DiffusionPriorTrainer(nn.Module):
# FIXME: LambdaLR can't be saved due to pickling issues
save_obj = dict(
optimizer = self.optimizer.state_dict(),
scheduler = self.scheduler.state_dict(),
warmup_scheduler = self.warmup_scheduler,
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
version = version.parse(__version__),
@@ -322,9 +317,7 @@ class DiffusionPriorTrainer(nn.Module):
# unwrap the model when loading from checkpoint
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
self.optimizer.load_state_dict(loaded_obj['optimizer'])
self.scheduler.load_state_dict(loaded_obj['scheduler'])
# set warmupstep
if exists(self.warmup_scheduler):
@@ -357,8 +350,7 @@ class DiffusionPriorTrainer(nn.Module):
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
if not self.accelerator.optimizer_step_was_skipped:
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
with sched_context():
with self.warmup_scheduler.dampening():
self.scheduler.step()
if self.use_ema:
@@ -441,7 +433,6 @@ class DecoderTrainer(nn.Module):
wd = 1e-2,
eps = 1e-8,
warmup_steps = None,
cosine_decay_max_steps = None,
max_grad_norm = 0.5,
amp = False,
group_wd_params = True,
@@ -463,7 +454,7 @@ class DecoderTrainer(nn.Module):
# be able to finely customize learning rate, weight decay
# per unet
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
@@ -471,7 +462,7 @@ class DecoderTrainer(nn.Module):
schedulers = []
warmup_schedulers = []
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
if isinstance(unet, nn.Identity):
optimizers.append(None)
schedulers.append(None)
@@ -487,11 +478,7 @@ class DecoderTrainer(nn.Module):
)
optimizers.append(optimizer)
if exists(unet_cosine_decay_max_steps):
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
else:
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
warmup_schedulers.append(warmup_scheduler)
@@ -571,15 +558,9 @@ class DecoderTrainer(nn.Module):
for ind in range(0, self.num_unets):
optimizer_key = f'optim{ind}'
scheduler_key = f'sched{ind}'
optimizer = getattr(self, optimizer_key)
scheduler = getattr(self, scheduler_key)
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
state_dict = optimizer.state_dict() if optimizer is not None else None
save_obj = {**save_obj, optimizer_key: state_dict}
if self.use_ema:
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
@@ -600,18 +581,10 @@ class DecoderTrainer(nn.Module):
optimizer_key = f'optim{ind}'
optimizer = getattr(self, optimizer_key)
scheduler_key = f'sched{ind}'
scheduler = getattr(self, scheduler_key)
warmup_scheduler = self.warmup_schedulers[ind]
if exists(optimizer):
if optimizer is not None:
optimizer.load_state_dict(loaded_obj[optimizer_key])
if exists(scheduler):
scheduler.load_state_dict(loaded_obj[scheduler_key])
if exists(warmup_scheduler):
warmup_scheduler.last_step = last_step

View File

@@ -1 +1 @@
__version__ = '1.8.3'
__version__ = '1.6.2'