Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-15 12:14:28 +01:00
Compare commits
9 Commits
- 27f19ba7fa
- 8f38339c2b
- 6b9b4b9e5e
- 44e09d5a4d
- 34806663e3
- dc816b1b6e
- 05192ffac4
- 9440411954
- 981d407792
README.md (+10 lines)
````diff
@@ -1264,4 +1264,14 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
+```bibtex
+@article{Qiao2019WeightS,
+    title   = {Weight Standardization},
+    author  = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
+    journal = {ArXiv},
+    year    = {2019},
+    volume  = {abs/1903.10520}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
````
dalle2_pytorch/dalle2_pytorch.py
```diff
@@ -38,6 +38,8 @@ from coca_pytorch import CoCa
 
 NAT = 1. / math.log(2.)
 
+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
+
 # helper functions
 
 def exists(val):
```
```diff
@@ -937,9 +939,12 @@ class DiffusionPriorNetwork(nn.Module):
         num_image_embeds = 1,
         num_text_embeds = 1,
         max_text_len = 256,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
+
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
         self.num_text_embeds = num_text_embeds
```
```diff
@@ -967,6 +972,10 @@ class DiffusionPriorNetwork(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
+        # whether to use self conditioning, Hinton's group's new ddpm technique
+
+        self.self_cond = self_cond
+
     def forward_with_cond_scale(
         self,
         *args,
```
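Self-conditioning here is the technique from https://arxiv.org/abs/2208.04202 referenced in the comment above: the network is additionally fed its own previous estimate of the clean embedding. A self-contained toy sketch of the recurrence (the denoiser, dimensions, and step counts below are hypothetical, not this repo's API):

```python
import random
import torch
from torch import nn

# hypothetical toy denoiser: takes the noisy input, a timestep, and an optional
# previous x0 estimate (zeros when absent), and predicts x0
class ToyDenoiser(nn.Module):
    def __init__(self, dim = 64):
        super().__init__()
        self.net = nn.Linear(dim * 2 + 1, dim)

    def forward(self, x, t, self_cond = None):
        if self_cond is None:
            self_cond = torch.zeros_like(x)
        t = t.float().unsqueeze(-1) / 1000
        return self.net(torch.cat((x, self_cond, t), dim = -1))

model = ToyDenoiser()
x_noisy, times = torch.randn(4, 64), torch.randint(0, 1000, (4,))

# training: on a random half of the steps, condition on a detached first prediction
self_cond = None
if random.random() < 0.5:
    with torch.no_grad():
        self_cond = model(x_noisy, times, None).detach()
pred_x0 = model(x_noisy, times, self_cond)

# sampling: thread the previous estimate through the reverse loop
x_start = None
for t in reversed(range(1000)):
    x_start = model(x_noisy, torch.full((4,), t), x_start)
    # ...the posterior update of x_noisy from x_start would go here...
```

The hunks below wire this same pattern into `forward`, `p_sample_loop_ddpm`, `p_sample_loop_ddim`, and `p_losses`.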
```diff
@@ -988,12 +997,19 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
+        self_cond = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
 
         num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
 
+        # setup self conditioning
+
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
+            self_cond = rearrange(self_cond, 'b d -> b 1 d')
+
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
```
```diff
@@ -1043,13 +1059,16 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-            attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+            attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
             mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
 
+        if self.self_cond:
+            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+
         tokens = torch.cat((
             text_encodings,
             text_embed,
```
```diff
@@ -1151,10 +1170,10 @@ class DiffusionPrior(nn.Module):
     def l2norm_clamp_embed(self, image_embed):
         return l2norm(image_embed) * self.image_embed_scale
 
-    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
         if self.predict_x_start:
             x_start = pred
```
```diff
@@ -1168,28 +1187,33 @@ class DiffusionPrior(nn.Module):
             x_start = l2norm(x_start) * self.image_embed_scale
 
         model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
-        return model_mean, posterior_variance, posterior_log_variance
+        return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start
 
     @torch.no_grad()
     def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
         batch, device = shape[0], self.device
 
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning
 
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
+
+            self_cond = x_start if self.net.self_cond else None
+            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
 
         if self.sampling_final_clamp_l2norm and self.predict_x_start:
             image_embed = self.l2norm_clamp_embed(image_embed)
```
```diff
@@ -1207,6 +1231,8 @@ class DiffusionPrior(nn.Module):
 
         image_embed = torch.randn(shape, device = device)
 
+        x_start = None # for self-conditioning
+
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
```
```diff
@@ -1216,7 +1242,9 @@ class DiffusionPrior(nn.Module):
 
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
 
-            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+            self_cond = x_start if self.net.self_cond else None
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
 
             if self.predict_x_start:
                 x_start = pred
```
```diff
@@ -1251,18 +1279,27 @@ class DiffusionPrior(nn.Module):
         is_ddim = timesteps < self.noise_scheduler.num_timesteps
 
         if not is_ddim:
-            return self.p_sample_loop_ddpm(*args, **kwargs)
+            normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
+        else:
+            normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
 
-        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+        image_embed = normalized_image_embed / self.image_embed_scale
+        return image_embed
 
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
 
         image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
 
+        self_cond = None
+        if self.net.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
+
         pred = self.net(
             image_embed_noisy,
             times,
+            self_cond = self_cond,
             cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
```
```diff
@@ -1316,8 +1353,6 @@ class DiffusionPrior(nn.Module):
 
         # retrieve original unscaled image embed
 
-        image_embeds /= self.image_embed_scale
-
         text_embeds = text_cond['text_embed']
 
         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
```
```diff
@@ -1416,6 +1451,26 @@ def Downsample(dim, *, dim_out = None):
     dim_out = default(dim_out, dim)
     return nn.Conv2d(dim, dim_out, 4, 2, 1)
 
+class WeightStandardizedConv2d(nn.Conv2d):
+    """
+    https://arxiv.org/abs/1903.10520
+    weight standardization purportedly works synergistically with group normalization
+    """
+    def forward(self, x):
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+        weight = self.weight
+        flattened_weights = rearrange(weight, 'o ... -> o (...)')
+
+        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+
+        var = torch.var(flattened_weights, dim = -1, unbiased = False)
+        var = rearrange(var, 'o -> o 1 1 1')
+
+        weight = (weight - mean) * (var + eps).rsqrt()
+
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
```
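As a sanity check on the class above: after the transform, every output filter of the weight has zero mean and unit variance, which is what is meant to combine well with the `GroupNorm` that follows in `Block`. The same einops math, standalone:

```python
import torch
from einops import rearrange, reduce

weight = torch.randn(16, 3, 3, 3)  # (out_channels, in_channels, kh, kw)

flattened = rearrange(weight, 'o ... -> o (...)')
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
var = rearrange(torch.var(flattened, dim = -1, unbiased = False), 'o -> o 1 1 1')

standardized = (weight - mean) * (var + 1e-5).rsqrt()

print(standardized.mean(dim = (1, 2, 3)))                   # ~0 per filter
print(standardized.var(dim = (1, 2, 3), unbiased = False))  # ~1 per filter
```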
```diff
@@ -1434,10 +1489,13 @@ class Block(nn.Module):
         self,
         dim,
         dim_out,
-        groups = 8
+        groups = 8,
+        weight_standardization = False
     ):
         super().__init__()
-        self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
+        conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
+
+        self.project = conv_klass(dim, dim_out, 3, padding = 1)
         self.norm = nn.GroupNorm(groups, dim_out)
         self.act = nn.SiLU()
 
```
```diff
@@ -1461,6 +1519,7 @@ class ResnetBlock(nn.Module):
         cond_dim = None,
         time_cond_dim = None,
         groups = 8,
+        weight_standardization = False,
         cosine_sim_cross_attn = False
     ):
         super().__init__()
```
```diff
@@ -1486,8 +1545,8 @@ class ResnetBlock(nn.Module):
             )
         )
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
+        self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, time_emb = None, cond = None):
```
```diff
@@ -1700,7 +1759,7 @@ class Unet(nn.Module):
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
-        self_cond = False,
+        self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
         sparse_attn = False,
         cosine_sim_cross_attn = False,
         cosine_sim_self_attn = False,
```
```diff
@@ -1712,6 +1771,7 @@ class Unet(nn.Module):
         init_dim = None,
         init_conv_kernel_size = 7,
         resnet_groups = 8,
+        resnet_weight_standardization = False,
         num_resnet_blocks = 2,
         init_cross_embed = True,
         init_cross_embed_kernel_sizes = (3, 7, 15),
```
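Both new knobs surface as plain `Unet` constructor flags. A hypothetical configuration sketch (the surrounding argument values are illustrative, in the style of the repo's README, and not prescribed by this commit):

```python
from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    self_cond = True,                      # self-conditioning, https://arxiv.org/abs/2208.04202
    resnet_weight_standardization = True   # WeightStandardizedConv2d inside every resnet Block
)
```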
```diff
@@ -1859,7 +1919,7 @@ class Unet(nn.Module):
 
         # prepare resnet klass
 
-        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
+        resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
 
         # give memory efficient unet an initial resnet block
 
```
```diff
@@ -2552,6 +2612,14 @@ class Decoder(nn.Module):
         index = unet_number - 1
         return self.unets[index]
 
+    def parse_unet_output(self, learned_variance, output):
+        var_interp_frac_unnormalized = None
+
+        if learned_variance:
+            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+
+        return UnetOutput(output, var_interp_frac_unnormalized)
+
     @contextmanager
     def one_unet_in_gpu(self, unet_number = None, unet = None):
         assert exists(unet_number) ^ exists(unet)
```
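`parse_unet_output` centralizes the `chunk(2, dim = 1)` split that the call sites below previously repeated inline, returning the `UnetOutput` namedtuple added near the top of the file. The pattern in isolation, as a standalone sketch:

```python
import torch
from collections import namedtuple

UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])

def parse_unet_output(learned_variance, output):
    var_interp_frac_unnormalized = None
    if learned_variance:
        output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
    return UnetOutput(output, var_interp_frac_unnormalized)

out = torch.randn(2, 6, 8, 8)             # channels are doubled when variance is learned
pred, var = parse_unet_output(True, out)  # two (2, 3, 8, 8) halves
pred, var = parse_unet_output(False, torch.randn(2, 3, 8, 8))  # var is None
```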
```diff
@@ -2593,10 +2661,9 @@ class Decoder(nn.Module):
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
+        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
 
-        if learned_variance:
-            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
+        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
 
         if predict_x_start:
             x_start = pred
```
```diff
@@ -2779,10 +2846,9 @@ class Decoder(nn.Module):
 
             self_cond = x_start if unet.self_cond else None
 
-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
 
-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+            pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
             if predict_x_start:
                 x_start = pred
```
```diff
@@ -2854,16 +2920,13 @@ class Decoder(nn.Module):
 
             if unet.self_cond and random.random() < 0.5:
                 with torch.no_grad():
-                    self_cond = unet(x_noisy, times, **unet_kwargs)
-
-                    if learned_variance:
-                        self_cond, _ = self_cond.chunk(2, dim = 1)
-
+                    unet_output = unet(x_noisy, times, **unet_kwargs)
+                    self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
                     self_cond = self_cond.detach()
 
         # forward to get model prediction
 
-        model_output = unet(
+        unet_output = unet(
             x_noisy,
             times,
             **unet_kwargs,
```
```diff
@@ -2872,10 +2935,7 @@ class Decoder(nn.Module):
             text_cond_drop_prob = self.text_cond_drop_prob,
         )
 
-        if learned_variance:
-            pred, _ = model_output.chunk(2, dim = 1)
-        else:
-            pred = model_output
+        pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
         target = noise if not predict_x_start else x_start
 
```
```diff
@@ -2898,7 +2958,7 @@ class Decoder(nn.Module):
         # if learning the variance, also include the extra weight kl loss
 
         true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+        model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
 
         # kl loss with detached model predicted mean, for stability reasons as in paper
 
```
dalle2_pytorch/trainer.py
```diff
@@ -9,7 +9,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.optim.lr_scheduler import LambdaLR
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
 from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
```
```diff
@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         group_wd_params = True,
-        warmup_steps = 1,
+        warmup_steps = None,
+        cosine_decay_max_steps = None,
         **kwargs
     ):
         super().__init__()
```
```diff
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
             **self.optim_kwargs,
             **kwargs
         )
 
-        self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+        if exists(cosine_decay_max_steps):
+            self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
+        else:
+            self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
 
         self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
 
```
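When `cosine_decay_max_steps` is given, the constant `LambdaLR` schedule is swapped for cosine annealing, and the optional linear warmup (via the `pytorch-warmup` package already used here) dampens it during the first steps. A standalone sketch of how the two compose, with illustrative step counts:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_warmup as warmup

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr = 1e-4)

scheduler = CosineAnnealingLR(optimizer, T_max = 10_000)               # cosine_decay_max_steps
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = 500) # warmup_steps

for step in range(1_000):
    optimizer.step()
    with warmup_scheduler.dampening():  # scales the lr set by the scheduler while warming up
        scheduler.step()
```

The `DecoderTrainer` hunks further down accept the same `cosine_decay_max_steps`, cast to a per-unet tuple.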
```diff
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
         # FIXME: LambdaLR can't be saved due to pickling issues
         save_obj = dict(
             optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
             warmup_scheduler = self.warmup_scheduler,
             model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
             version = version.parse(__version__),
```
```diff
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
         # unwrap the model when loading from checkpoint
         self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
         self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
 
         self.optimizer.load_state_dict(loaded_obj['optimizer'])
+        self.scheduler.load_state_dict(loaded_obj['scheduler'])
+
         # set warmupstep
         if exists(self.warmup_scheduler):
```
```diff
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
 
             # accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
             if not self.accelerator.optimizer_step_was_skipped:
-                with self.warmup_scheduler.dampening():
+                sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+                with sched_context():
                     self.scheduler.step()
 
         if self.use_ema:
```
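Since `warmup_steps` now defaults to `None`, the warmup scheduler may not exist, so the step path falls back to a null context manager (the module is assumed to have `contextlib.nullcontext` in scope). The fallback pattern reduced to its essentials:

```python
from contextlib import nullcontext

warmup_scheduler = None  # e.g. when warmup_steps was None

sched_context = warmup_scheduler.dampening if warmup_scheduler is not None else nullcontext

with sched_context():
    pass  # scheduler.step() would run here
```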
```diff
@@ -433,6 +441,7 @@ class DecoderTrainer(nn.Module):
         wd = 1e-2,
         eps = 1e-8,
         warmup_steps = None,
+        cosine_decay_max_steps = None,
         max_grad_norm = 0.5,
         amp = False,
         group_wd_params = True,
```
```diff
@@ -454,7 +463,7 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
+        lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
 
         assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
 
```
```diff
@@ -462,7 +471,7 @@ class DecoderTrainer(nn.Module):
         schedulers = []
         warmup_schedulers = []
 
-        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
+        for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
             if isinstance(unet, nn.Identity):
                 optimizers.append(None)
                 schedulers.append(None)
```
```diff
@@ -478,7 +487,11 @@ class DecoderTrainer(nn.Module):
                 )
 
                 optimizers.append(optimizer)
-                scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+                if exists(unet_cosine_decay_max_steps):
+                    scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+                else:
+                    scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
 
                 warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
                 warmup_schedulers.append(warmup_scheduler)
```
```diff
@@ -558,9 +571,15 @@ class DecoderTrainer(nn.Module):
 
         for ind in range(0, self.num_unets):
             optimizer_key = f'optim{ind}'
+            scheduler_key = f'sched{ind}'
+
             optimizer = getattr(self, optimizer_key)
-            state_dict = optimizer.state_dict() if optimizer is not None else None
-            save_obj = {**save_obj, optimizer_key: state_dict}
+            scheduler = getattr(self, scheduler_key)
+
+            optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+            scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+            save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
 
         if self.use_ema:
             save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
```
```diff
@@ -581,10 +600,18 @@ class DecoderTrainer(nn.Module):
 
             optimizer_key = f'optim{ind}'
             optimizer = getattr(self, optimizer_key)
 
+            scheduler_key = f'sched{ind}'
+            scheduler = getattr(self, scheduler_key)
+
             warmup_scheduler = self.warmup_schedulers[ind]
-            if optimizer is not None:
+            if exists(optimizer):
                 optimizer.load_state_dict(loaded_obj[optimizer_key])
+
+            if exists(scheduler):
+                scheduler.load_state_dict(loaded_obj[scheduler_key])
 
             if exists(warmup_scheduler):
                 warmup_scheduler.last_step = last_step
 
```
dalle2_pytorch/version.py
```diff
@@ -1 +1 @@
-__version__ = '1.6.0'
+__version__ = '1.8.2'
```