mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 20:04:20 +01:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f23fab7ef7 | ||
|
|
857b9fbf1e | ||
|
|
8864fd0aa7 |
@@ -1195,4 +1195,12 @@ This library would not have gotten to this working state without the help of
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{Saharia2022,
|
||||||
|
title = {Imagen: unprecedented photorealism × deep level of language understanding},
|
||||||
|
author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
|
||||||
|
year = {2022}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||||
|
|||||||
@@ -1107,13 +1107,20 @@ class Block(nn.Module):
|
|||||||
groups = 8
|
groups = 8
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.block = nn.Sequential(
|
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
self.norm = nn.GroupNorm(groups, dim_out)
|
||||||
nn.GroupNorm(groups, dim_out),
|
self.act = nn.SiLU()
|
||||||
nn.SiLU()
|
|
||||||
)
|
def forward(self, x, scale_shift = None):
|
||||||
def forward(self, x):
|
x = self.project(x)
|
||||||
return self.block(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
|
if exists(scale_shift):
|
||||||
|
scale, shift = scale_shift
|
||||||
|
x = x * (scale + 1) + shift
|
||||||
|
|
||||||
|
x = self.act(x)
|
||||||
|
return x
|
||||||
|
|
||||||
class ResnetBlock(nn.Module):
|
class ResnetBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -1132,7 +1139,7 @@ class ResnetBlock(nn.Module):
|
|||||||
if exists(time_cond_dim):
|
if exists(time_cond_dim):
|
||||||
self.time_mlp = nn.Sequential(
|
self.time_mlp = nn.Sequential(
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
nn.Linear(time_cond_dim, dim_out)
|
nn.Linear(time_cond_dim, dim_out * 2)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cross_attn = None
|
self.cross_attn = None
|
||||||
@@ -1152,11 +1159,14 @@ class ResnetBlock(nn.Module):
|
|||||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, cond = None, time_emb = None):
|
def forward(self, x, cond = None, time_emb = None):
|
||||||
h = self.block1(x)
|
|
||||||
|
|
||||||
|
scale_shift = None
|
||||||
if exists(self.time_mlp) and exists(time_emb):
|
if exists(self.time_mlp) and exists(time_emb):
|
||||||
time_emb = self.time_mlp(time_emb)
|
time_emb = self.time_mlp(time_emb)
|
||||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
|
||||||
|
scale_shift = time_emb.chunk(2, dim = 1)
|
||||||
|
|
||||||
|
h = self.block1(x, scale_shift = scale_shift)
|
||||||
|
|
||||||
if exists(self.cross_attn):
|
if exists(self.cross_attn):
|
||||||
assert exists(cond)
|
assert exists(cond)
|
||||||
@@ -1704,6 +1714,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
vb_loss_weight = 0.001,
|
vb_loss_weight = 0.001,
|
||||||
unconditional = False,
|
unconditional = False,
|
||||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||||
|
use_dynamic_thres = False, # from the Imagen paper
|
||||||
|
dynamic_thres_percentile = 0.9
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -1826,6 +1838,11 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.clip_denoised = clip_denoised
|
self.clip_denoised = clip_denoised
|
||||||
self.clip_x_start = clip_x_start
|
self.clip_x_start = clip_x_start
|
||||||
|
|
||||||
|
# dynamic thresholding settings, if clipping denoised during sampling
|
||||||
|
|
||||||
|
self.use_dynamic_thres = use_dynamic_thres
|
||||||
|
self.dynamic_thres_percentile = dynamic_thres_percentile
|
||||||
|
|
||||||
# normalize and unnormalize image functions
|
# normalize and unnormalize image functions
|
||||||
|
|
||||||
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
|
||||||
@@ -1868,7 +1885,21 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||||
|
|
||||||
if clip_denoised:
|
if clip_denoised:
|
||||||
x_recon.clamp_(-1., 1.)
|
# s is the threshold amount
|
||||||
|
# static thresholding would just be s = 1
|
||||||
|
s = 1.
|
||||||
|
if self.use_dynamic_thres:
|
||||||
|
s = torch.quantile(
|
||||||
|
rearrange(x_recon, 'b ... -> b (...)').abs(),
|
||||||
|
self.dynamic_thres_percentile,
|
||||||
|
dim = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
s.clamp_(min = 1.)
|
||||||
|
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
|
||||||
|
|
||||||
|
# clip by threshold, depending on whether static or dynamic
|
||||||
|
x_recon = x_recon.clamp(-s, s) / s
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ def get_optimizer(
|
|||||||
betas = (0.9, 0.999),
|
betas = (0.9, 0.999),
|
||||||
eps = 1e-8,
|
eps = 1e-8,
|
||||||
filter_by_requires_grad = False,
|
filter_by_requires_grad = False,
|
||||||
|
group_wd_params = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if filter_by_requires_grad:
|
if filter_by_requires_grad:
|
||||||
@@ -21,11 +22,13 @@ def get_optimizer(
|
|||||||
return Adam(params, lr = lr, betas = betas, eps = eps)
|
return Adam(params, lr = lr, betas = betas, eps = eps)
|
||||||
|
|
||||||
params = set(params)
|
params = set(params)
|
||||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
|
||||||
|
|
||||||
param_groups = [
|
if group_wd_params:
|
||||||
{'params': list(wd_params)},
|
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
|
||||||
]
|
|
||||||
|
|
||||||
return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
params = [
|
||||||
|
{'params': list(wd_params)},
|
||||||
|
{'params': list(no_wd_params), 'weight_decay': 0},
|
||||||
|
]
|
||||||
|
|
||||||
|
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||||
|
|||||||
@@ -254,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
eps = 1e-6,
|
eps = 1e-6,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
amp = False,
|
amp = False,
|
||||||
|
group_wd_params = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -279,6 +280,7 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
lr = lr,
|
lr = lr,
|
||||||
wd = wd,
|
wd = wd,
|
||||||
eps = eps,
|
eps = eps,
|
||||||
|
group_wd_params = group_wd_params,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -410,6 +412,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
eps = 1e-8,
|
eps = 1e-8,
|
||||||
max_grad_norm = 0.5,
|
max_grad_norm = 0.5,
|
||||||
amp = False,
|
amp = False,
|
||||||
|
group_wd_params = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -435,6 +438,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
lr = unet_lr,
|
lr = unet_lr,
|
||||||
wd = unet_wd,
|
wd = unet_wd,
|
||||||
eps = unet_eps,
|
eps = unet_eps,
|
||||||
|
group_wd_params = group_wd_params,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user