mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27f19ba7fa | ||
|
|
8f38339c2b | ||
|
|
6b9b4b9e5e | ||
|
|
44e09d5a4d | ||
|
|
34806663e3 |
10
README.md
10
README.md
@@ -1264,4 +1264,14 @@ For detailed information on training the diffusion prior, please refer to the [d
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Qiao2019WeightS,
|
||||
title = {Weight Standardization},
|
||||
author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
|
||||
journal = {ArXiv},
|
||||
year = {2019},
|
||||
volume = {abs/1903.10520}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -1279,9 +1279,12 @@ class DiffusionPrior(nn.Module):
|
||||
is_ddim = timesteps < self.noise_scheduler.num_timesteps
|
||||
|
||||
if not is_ddim:
|
||||
return self.p_sample_loop_ddpm(*args, **kwargs)
|
||||
normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
|
||||
else:
|
||||
normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
|
||||
|
||||
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
|
||||
image_embed = normalized_image_embed / self.image_embed_scale
|
||||
return image_embed
|
||||
|
||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
@@ -1350,8 +1353,6 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
# retrieve original unscaled image embed
|
||||
|
||||
image_embeds /= self.image_embed_scale
|
||||
|
||||
text_embeds = text_cond['text_embed']
|
||||
|
||||
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
|
||||
@@ -1450,6 +1451,26 @@ def Downsample(dim, *, dim_out = None):
|
||||
dim_out = default(dim_out, dim)
|
||||
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||
|
||||
class WeightStandardizedConv2d(nn.Conv2d):
|
||||
"""
|
||||
https://arxiv.org/abs/1903.10520
|
||||
weight standardization purportedly works synergistically with group normalization
|
||||
"""
|
||||
def forward(self, x):
|
||||
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
|
||||
|
||||
weight = self.weight
|
||||
flattened_weights = rearrange(weight, 'o ... -> o (...)')
|
||||
|
||||
mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
|
||||
|
||||
var = torch.var(flattened_weights, dim = -1, unbiased = False)
|
||||
var = rearrange(var, 'o -> o 1 1 1')
|
||||
|
||||
weight = (weight - mean) * (var + eps).rsqrt()
|
||||
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@@ -1468,10 +1489,13 @@ class Block(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
groups = 8
|
||||
groups = 8,
|
||||
weight_standardization = False
|
||||
):
|
||||
super().__init__()
|
||||
self.project = nn.Conv2d(dim, dim_out, 3, padding = 1)
|
||||
conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
|
||||
|
||||
self.project = conv_klass(dim, dim_out, 3, padding = 1)
|
||||
self.norm = nn.GroupNorm(groups, dim_out)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
@@ -1495,6 +1519,7 @@ class ResnetBlock(nn.Module):
|
||||
cond_dim = None,
|
||||
time_cond_dim = None,
|
||||
groups = 8,
|
||||
weight_standardization = False,
|
||||
cosine_sim_cross_attn = False
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1520,8 +1545,8 @@ class ResnetBlock(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
self.block1 = Block(dim, dim_out, groups = groups)
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups)
|
||||
self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, time_emb = None, cond = None):
|
||||
@@ -1746,6 +1771,7 @@ class Unet(nn.Module):
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
resnet_weight_standardization = False,
|
||||
num_resnet_blocks = 2,
|
||||
init_cross_embed = True,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
@@ -1893,7 +1919,7 @@ class Unet(nn.Module):
|
||||
|
||||
# prepare resnet klass
|
||||
|
||||
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn)
|
||||
resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
|
||||
|
||||
# give memory efficient unet an initial resnet block
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from collections.abc import Iterable
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||
@@ -181,7 +181,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
eps = 1e-6,
|
||||
max_grad_norm = None,
|
||||
group_wd_params = True,
|
||||
warmup_steps = 1,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -233,8 +234,11 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
**self.optim_kwargs,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||
|
||||
if exists(cosine_decay_max_steps):
|
||||
self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
|
||||
else:
|
||||
self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
|
||||
|
||||
self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
|
||||
|
||||
@@ -271,6 +275,7 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
# FIXME: LambdaLR can't be saved due to pickling issues
|
||||
save_obj = dict(
|
||||
optimizer = self.optimizer.state_dict(),
|
||||
scheduler = self.scheduler.state_dict(),
|
||||
warmup_scheduler = self.warmup_scheduler,
|
||||
model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
|
||||
version = version.parse(__version__),
|
||||
@@ -317,7 +322,9 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
# unwrap the model when loading from checkpoint
|
||||
self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
|
||||
self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
|
||||
|
||||
self.optimizer.load_state_dict(loaded_obj['optimizer'])
|
||||
self.scheduler.load_state_dict(loaded_obj['scheduler'])
|
||||
|
||||
# set warmupstep
|
||||
if exists(self.warmup_scheduler):
|
||||
@@ -350,7 +357,8 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
|
||||
# accelerator will ocassionally skip optimizer steps in a "dynamic loss scaling strategy"
|
||||
if not self.accelerator.optimizer_step_was_skipped:
|
||||
with self.warmup_scheduler.dampening():
|
||||
sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
|
||||
with sched_context():
|
||||
self.scheduler.step()
|
||||
|
||||
if self.use_ema:
|
||||
@@ -433,6 +441,7 @@ class DecoderTrainer(nn.Module):
|
||||
wd = 1e-2,
|
||||
eps = 1e-8,
|
||||
warmup_steps = None,
|
||||
cosine_decay_max_steps = None,
|
||||
max_grad_norm = 0.5,
|
||||
amp = False,
|
||||
group_wd_params = True,
|
||||
@@ -454,7 +463,7 @@ class DecoderTrainer(nn.Module):
|
||||
# be able to finely customize learning rate, weight decay
|
||||
# per unet
|
||||
|
||||
lr, wd, eps, warmup_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps))
|
||||
lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
|
||||
|
||||
assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
|
||||
|
||||
@@ -462,7 +471,7 @@ class DecoderTrainer(nn.Module):
|
||||
schedulers = []
|
||||
warmup_schedulers = []
|
||||
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps in zip(decoder.unets, lr, wd, eps, warmup_steps):
|
||||
for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
|
||||
if isinstance(unet, nn.Identity):
|
||||
optimizers.append(None)
|
||||
schedulers.append(None)
|
||||
@@ -478,7 +487,11 @@ class DecoderTrainer(nn.Module):
|
||||
)
|
||||
|
||||
optimizers.append(optimizer)
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
if exists(unet_cosine_decay_max_steps):
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
|
||||
else:
|
||||
scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
|
||||
|
||||
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
|
||||
warmup_schedulers.append(warmup_scheduler)
|
||||
@@ -558,9 +571,15 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
for ind in range(0, self.num_unets):
|
||||
optimizer_key = f'optim{ind}'
|
||||
scheduler_key = f'sched{ind}'
|
||||
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
state_dict = optimizer.state_dict() if optimizer is not None else None
|
||||
save_obj = {**save_obj, optimizer_key: state_dict}
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
|
||||
scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
|
||||
|
||||
save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
|
||||
|
||||
if self.use_ema:
|
||||
save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
|
||||
@@ -581,10 +600,18 @@ class DecoderTrainer(nn.Module):
|
||||
|
||||
optimizer_key = f'optim{ind}'
|
||||
optimizer = getattr(self, optimizer_key)
|
||||
|
||||
scheduler_key = f'sched{ind}'
|
||||
scheduler = getattr(self, scheduler_key)
|
||||
|
||||
warmup_scheduler = self.warmup_schedulers[ind]
|
||||
if optimizer is not None:
|
||||
|
||||
if exists(optimizer):
|
||||
optimizer.load_state_dict(loaded_obj[optimizer_key])
|
||||
|
||||
if exists(scheduler):
|
||||
scheduler.load_state_dict(loaded_obj[scheduler_key])
|
||||
|
||||
if exists(warmup_scheduler):
|
||||
warmup_scheduler.last_step = last_step
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = '1.6.4'
|
||||
__version__ = '1.8.2'
|
||||
|
||||
Reference in New Issue
Block a user