mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-15 00:44:25 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
301a97197f |
10
README.md
10
README.md
@@ -1264,14 +1264,4 @@ For detailed information on training the diffusion prior, please refer to the [d
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@article{Qiao2019WeightS,
|
|
||||||
title = {Weight Standardization},
|
|
||||||
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
|
|
||||||
journal = {ArXiv},
|
|
||||||
year = {2019},
|
|
||||||
volume = {abs/1903.10520}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||||
|
|||||||
@@ -38,8 +38,6 @@ from coca_pytorch import CoCa
|
|||||||
|
|
||||||
NAT = 1. / math.log(2.)
|
NAT = 1. / math.log(2.)
|
||||||
|
|
||||||
UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
|
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
@@ -1279,12 +1277,9 @@ class DiffusionPrior(nn.Module):
|
|||||||
is_ddim = timesteps < self.noise_scheduler.num_timesteps
|
is_ddim = timesteps < self.noise_scheduler.num_timesteps
|
||||||
|
|
||||||
if not is_ddim:
|
if not is_ddim:
|
||||||
normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
|
return self.p_sample_loop_ddpm(*args, **kwargs)
|
||||||
else:
|
|
||||||
normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
|
|
||||||
|
|
||||||
image_embed = normalized_image_embed / self.image_embed_scale
|
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
|
||||||
return image_embed
|
|
||||||
|
|
||||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||||
@@ -1292,7 +1287,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||||
|
|
||||||
self_cond = None
|
self_cond = None
|
||||||
if self.net.self_cond and random.random() < 0.5:
|
if self.net.self_cond and random.random() < 1.5:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
|
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
|
||||||
|
|
||||||
@@ -1353,6 +1348,8 @@ class DiffusionPrior(nn.Module):
|
|||||||
|
|
||||||
# retrieve original unscaled image embed
|
# retrieve original unscaled image embed
|
||||||
|
|
||||||
|
image_embeds /= self.image_embed_scale
|
||||||
|
|
||||||
text_embeds = text_cond['text_embed']
|
text_embeds = text_cond['text_embed']
|
||||||
|
|
||||||
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||||
@@ -1451,30 +1448,6 @@ def Downsample(dim, *, dim_out = None):
|
|||||||
dim_out = default(dim_out, dim)
|
dim_out = default(dim_out, dim)
|
||||||
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||||
|
|
||||||
class WeightStandardizedConv2d(nn.Conv2d):
|
|
||||||
"""
|
|
||||||
https://arxiv.org/abs/1903.10520
|
|
||||||
weight standardization purportedly works synergistically with group normalization
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
|
|
||||||
|
|
||||||
weight = self.weight
|
|
||||||
flattened_weights = rearrange(weight, 'o ... -> o (...)')
|
|
||||||
|
|
||||||
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
|
|
||||||
|
|
||||||
var = torch.var(flattened_weights, dim = -1, unbiased = False)
|
|
||||||
var = rearrange(var, 'o -> o 1 1 1')
|
|
||||||
|
|
||||||
weight = (weight - mean) * (var + eps).rsqrt()
|
|
||||||
|
|
||||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
||||||
|
|
||||||
class SinusoidalPosEmb(nn.Module):
|
class SinusoidalPosEmb(nn.Module):
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1493,13 +1466,10 @@ class Block(nn.Module):
|
|||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
dim_out,
|
dim_out,
|
||||||
groups = 8,
|
groups = 8
|
||||||
weight_standardization = False
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
|
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||||
|
|
||||||
self.project = conv_klass(dim, dim_out, 3, padding = 1)
|
|
||||||
self.norm = nn.GroupNorm(groups, dim_out)
|
self.norm = nn.GroupNorm(groups, dim_out)
|
||||||
self.act = nn.SiLU()
|
self.act = nn.SiLU()
|
||||||
|
|
||||||
@@ -1523,7 +1493,6 @@ class ResnetBlock(nn.Module):
|
|||||||
cond_dim = None,
|
cond_dim = None,
|
||||||
time_cond_dim = None,
|
time_cond_dim = None,
|
||||||
groups = 8,
|
groups = 8,
|
||||||
weight_standardization = False,
|
|
||||||
cosine_sim_cross_attn = False
|
cosine_sim_cross_attn = False
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1549,8 +1518,8 @@ class ResnetBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
|
self.block1 = Block(dim, dim_out, groups = groups)
|
||||||
self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
|
self.block2 = Block(dim_out, dim_out, groups = groups)
|
||||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, time_emb = None, cond = None):
|
def forward(self, x, time_emb = None, cond = None):
|
||||||
@@ -1775,7 +1744,6 @@ class Unet(nn.Module):
|
|||||||
init_dim = None,
|
init_dim = None,
|
||||||
init_conv_kernel_size = 7,
|
init_conv_kernel_size = 7,
|
||||||
resnet_groups = 8,
|
resnet_groups = 8,
|
||||||
resnet_weight_standardization = False,
|
|
||||||
num_resnet_blocks = 2,
|
num_resnet_blocks = 2,
|
||||||
init_cross_embed = True,
|
init_cross_embed = True,
|
||||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||||
@@ -1923,7 +1891,7 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
# prepare resnet klass
|
# prepare resnet klass
|
||||||
|
|
||||||
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
|
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
|
||||||
|
|
||||||
# give memory efficient unet an initial resnet block
|
# give memory efficient unet an initial resnet block
|
||||||
|
|
||||||
@@ -2616,14 +2584,6 @@ class Decoder(nn.Module):
|
|||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
return self.unets[index]
|
return self.unets[index]
|
||||||
|
|
||||||
def parse_unet_output(self, learned_variance, output):
|
|
||||||
var_interp_frac_unnormalized = None
|
|
||||||
|
|
||||||
if learned_variance:
|
|
||||||
output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
|
|
||||||
|
|
||||||
return UnetOutput(output, var_interp_frac_unnormalized)
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def one_unet_in_gpu(self, unet_number = None, unet = None):
|
def one_unet_in_gpu(self, unet_number = None, unet = None):
|
||||||
assert exists(unet_number) ^ exists(unet)
|
assert exists(unet_number) ^ exists(unet)
|
||||||
@@ -2665,9 +2625,10 @@ class Decoder(nn.Module):
|
|||||||
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
||||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||||
|
|
||||||
model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
|
pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
|
||||||
|
|
||||||
pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
|
if learned_variance:
|
||||||
|
pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
|
||||||
|
|
||||||
if predict_x_start:
|
if predict_x_start:
|
||||||
x_start = pred
|
x_start = pred
|
||||||
@@ -2850,9 +2811,10 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
self_cond = x_start if unet.self_cond else None
|
self_cond = x_start if unet.self_cond else None
|
||||||
|
|
||||||
unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
|
||||||
|
|
||||||
pred, _ = self.parse_unet_output(learned_variance, unet_output)
|
if learned_variance:
|
||||||
|
pred, _ = pred.chunk(2, dim = 1)
|
||||||
|
|
||||||
if predict_x_start:
|
if predict_x_start:
|
||||||
x_start = pred
|
x_start = pred
|
||||||
@@ -2924,13 +2886,16 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
if unet.self_cond and random.random() < 0.5:
|
if unet.self_cond and random.random() < 0.5:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
unet_output = unet(x_noisy, times, **unet_kwargs)
|
self_cond = unet(x_noisy, times, **unet_kwargs)
|
||||||
self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
|
|
||||||
|
if learned_variance:
|
||||||
|
self_cond, _ = self_cond.chunk(2, dim = 1)
|
||||||
|
|
||||||
self_cond = self_cond.detach()
|
self_cond = self_cond.detach()
|
||||||
|
|
||||||
# forward to get model prediction
|
# forward to get model prediction
|
||||||
|
|
||||||
unet_output = unet(
|
model_output = unet(
|
||||||
x_noisy,
|
x_noisy,
|
||||||
times,
|
times,
|
||||||
**unet_kwargs,
|
**unet_kwargs,
|
||||||
@@ -2939,7 +2904,10 @@ class Decoder(nn.Module):
|
|||||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||||
)
|
)
|
||||||
|
|
||||||
pred, _ = self.parse_unet_output(learned_variance, unet_output)
|
if learned_variance:
|
||||||
|
pred, _ = model_output.chunk(2, dim = 1)
|
||||||
|
else:
|
||||||
|
pred = model_output
|
||||||
|
|
||||||
target = noise if not predict_x_start else x_start
|
target = noise if not predict_x_start else x_start
|
||||||
|
|
||||||
@@ -2962,7 +2930,7 @@ class Decoder(nn.Module):
|
|||||||
# if learning the variance, also include the extra weight kl loss
|
# if learning the variance, also include the extra weight kl loss
|
||||||
|
|
||||||
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
|
||||||
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
|
model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
|
||||||
|
|
||||||
# kl loss with detached model predicted mean, for stability reasons as in paper
|
# kl loss with detached model predicted mean, for stability reasons as in paper
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '1.7.0'
|
__version__ = '1.6.2'
|
||||||
|
|||||||
Reference in New Issue
Block a user