Compare commits

..

43 Commits

Author SHA1 Message Date
Phil Wang
1992d25cad project management 2022-05-04 11:18:54 -07:00
Phil Wang
5b619c2fd5 make sure some hyperparameters for unet block is configurable 2022-05-04 11:18:32 -07:00
Phil Wang
9359ad2e91 0.0.95 2022-05-04 10:53:05 -07:00
Phil Wang
9ff228188b offer old resnet blocks, from the original DDPM paper, just in case convnexts are unsuitable for generative work 2022-05-04 10:52:58 -07:00
Kumar R
2d9963d30e Reporting metrics - Cosine similarity. (#55)
* Update train_diffusion_prior.py

* Delete train_diffusion_prior.py

* Cosine similarity logging.

* Update train_diffusion_prior.py

* Report Cosine metrics every N steps.
2022-05-04 08:04:36 -07:00
Phil Wang
58d9b422f3 0.0.94 2022-05-04 07:42:33 -07:00
Ray Bell
44b319cb57 add missing import (#56) 2022-05-04 07:42:20 -07:00
Phil Wang
c30f380689 final reminder 2022-05-03 08:18:53 -07:00
Phil Wang
e4e884bb8b keep all doors open 2022-05-03 08:17:02 -07:00
Phil Wang
803ad9c17d product management again 2022-05-03 08:15:25 -07:00
Phil Wang
a88dd6a9c0 todo 2022-05-03 08:09:02 -07:00
Kumar R
72c16b496e Update train_diffusion_prior.py (#53) 2022-05-02 22:44:57 -07:00
z
81d83dd7f2 defaults align with paper (#52)
Co-authored-by: nousr <>
2022-05-02 13:52:11 -07:00
Phil Wang
fa66f7e1e9 todo 2022-05-02 12:57:15 -07:00
Phil Wang
aa8d135245 allow laion to experiment with normformer in diffusion prior 2022-05-02 11:35:00 -07:00
Phil Wang
70282de23b add ability to turn on normformer settings, given @borisdayma reported good results and some personal anecdata 2022-05-02 11:33:15 -07:00
Phil Wang
83f761847e todo 2022-05-02 10:52:39 -07:00
Phil Wang
11469dc0c6 makes more sense to keep this as True as default, for stability 2022-05-02 10:50:55 -07:00
Romain Beaumont
2d25c89f35 Fix passing of l2norm_output to DiffusionPriorNetwork (#51) 2022-05-02 10:48:16 -07:00
Phil Wang
3fe96c208a add ability to train diffusion prior with l2norm on output image embed 2022-05-02 09:53:20 -07:00
Phil Wang
0fc6c9cdf3 provide option to l2norm the output of the diffusion prior 2022-05-02 09:41:03 -07:00
Phil Wang
7ee0ecc388 mixed precision for training diffusion prior + save optimizer and scaler states 2022-05-02 09:31:04 -07:00
Phil Wang
1924c7cc3d fix issue with mixed precision and gradient clipping 2022-05-02 09:20:19 -07:00
Phil Wang
f7df3caaf3 address not calculating average eval / test loss when training diffusion prior https://github.com/lucidrains/DALLE2-pytorch/issues/49 2022-05-02 08:51:41 -07:00
Phil Wang
fc954ee788 fix calculation of adaptive weight for vit-vqgan, thanks to @CiaoHe 2022-05-02 07:58:14 -07:00
Phil Wang
c1db2753f5 todo 2022-05-01 18:02:30 -07:00
Phil Wang
ad87bfe28f switch to using linear attention for the sparse attention layers within unet, given success in GAN projects 2022-05-01 17:59:03 -07:00
Phil Wang
76c767b1ce update deps, commit to using webdatasets, per @rom1504 consultation 2022-05-01 12:22:15 -07:00
Phil Wang
d991b8c39c just clip the diffusion prior network parameters 2022-05-01 12:01:08 -07:00
Phil Wang
902693e271 todo 2022-05-01 11:57:08 -07:00
Phil Wang
35cd63982d add gradient clipping, make sure weight decay is configurable, make sure learning rate is actually passed into get_optimizer, make sure model is set to training mode at beginning of each epoch 2022-05-01 11:55:38 -07:00
Kumar R
53ce6dfdf6 All changes implemented, current run happening. Link to wandb run in comments. (#43)
* Train DiffusionPrior with pre-computed embeddings

This is in response to https://github.com/lucidrains/DALLE2-pytorch/issues/29 - more metrics will get added.
2022-05-01 11:46:59 -07:00
Phil Wang
ad8d7a368b product management 2022-05-01 11:26:21 -07:00
Phil Wang
b8cf1e5c20 more attention 2022-05-01 11:00:33 -07:00
Phil Wang
94aaa08d97 product management 2022-05-01 09:43:10 -07:00
Phil Wang
8b9bbec7d1 project management 2022-05-01 09:32:57 -07:00
Phil Wang
1bb9fc9829 add convnext backbone for vqgan-vae, still need to fix groupnorms in resnet encdec 2022-05-01 09:32:24 -07:00
Phil Wang
5e421bd5bb let researchers do the hyperparameter search 2022-05-01 08:46:21 -07:00
Phil Wang
67fcab1122 add MLP based time conditioning to all convnexts, in addition to cross attention. also add an initial convolution, given convnext first depthwise conv 2022-05-01 08:41:02 -07:00
Phil Wang
5bfbccda22 port over vqgan vae trainer 2022-05-01 08:09:15 -07:00
Phil Wang
989275ff59 product management 2022-04-30 16:57:56 -07:00
Phil Wang
56408f4a40 project management 2022-04-30 16:57:02 -07:00
Phil Wang
d1a697ac23 allows one to shortcut sampling at a specific unet number, if one were to be training in stages 2022-04-30 16:05:13 -07:00
8 changed files with 923 additions and 48 deletions

View File

@@ -783,7 +783,7 @@ for unet_number in (1, 2):
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
## CLI (wip)
@@ -818,14 +818,24 @@ Once built, images will be saved to the same directory the command is invoked
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [x] bring in tools to train vqgan-vae
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
## Citations
@@ -857,12 +867,22 @@ Once built, images will be saved to the same directory the command is invoked
```bibtex
@inproceedings{Liu2022ACF,
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
title = {A ConvNet for the 2020s},
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022}
}
```
```bibtex
@article{shen2019efficient,
author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
title = {Efficient Attention: Attention with Linear Complexities},
journal = {CoRR},
year = {2018},
url = {http://arxiv.org/abs/1812.01243},
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
@@ -881,4 +901,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
```bibtex
@article{Shleifer2021NormFormerIT,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Jason Weston and Myle Ott},
journal = {ArXiv},
year = {2021},
volume = {abs/2110.09456}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,7 @@
import click
import torch
import torchvision.transforms as T
from functools import reduce
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior

View File

@@ -29,6 +29,9 @@ from x_clip import CLIP
def exists(val):
return val is not None
def identity(t, *args, **kwargs):
return t
def default(val, d):
if exists(val):
return val
@@ -496,7 +499,12 @@ class SwiGLU(nn.Module):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
def FeedForward(
dim,
mult = 4,
dropout = 0.,
post_activation_norm = False
):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
inner_dim = int(mult * dim)
@@ -519,7 +527,8 @@ class Attention(nn.Module):
dim_head = 64,
heads = 8,
dropout = 0.,
causal = False
causal = False,
post_norm = False
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -534,7 +543,11 @@ class Attention(nn.Module):
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
)
def forward(self, x, mask = None, attn_bias = None):
b, n, device = *x.shape[:2], x.device
@@ -596,10 +609,11 @@ class CausalTransformer(nn.Module):
dim_head = 64,
heads = 8,
ff_mult = 4,
norm_out = False,
norm_out = True,
attn_dropout = 0.,
ff_dropout = 0.,
final_proj = True
final_proj = True,
normformer = False
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
@@ -607,8 +621,8 @@ class CausalTransformer(nn.Module):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
@@ -635,12 +649,14 @@ class DiffusionPriorNetwork(nn.Module):
self,
dim,
num_timesteps = None,
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
self.l2norm_output = l2norm_output
def forward_with_cond_scale(
self,
@@ -719,7 +735,8 @@ class DiffusionPriorNetwork(nn.Module):
pred_image_embed = tokens[..., -1, :]
return pred_image_embed
output_fn = l2norm if self.l2norm_output else identity
return output_fn(pred_image_embed)
class DiffusionPrior(BaseGaussianDiffusion):
def __init__(
@@ -913,6 +930,72 @@ class SinusoidalPosEmb(nn.Module):
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
class Block(nn.Module):
def __init__(
self,
dim,
dim_out,
groups = 8
):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
nn.GroupNorm(groups, dim_out),
nn.SiLU()
)
def forward(self, x):
return self.block(x)
class ResnetBlock(nn.Module):
def __init__(
self,
dim,
dim_out,
*,
cond_dim = None,
time_cond_dim = None,
groups = 8
):
super().__init__()
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_cond_dim, dim_out)
)
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim_out,
context_dim = cond_dim
)
)
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, cond = None, time_emb = None):
h = self.block1(x)
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
h = rearrange(time_emb, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h
h = self.block2(h)
return h + self.res_conv(x)
class ConvNextBlock(nn.Module):
""" https://arxiv.org/abs/2201.03545 """
@@ -922,8 +1005,8 @@ class ConvNextBlock(nn.Module):
dim_out,
*,
cond_dim = None,
mult = 2,
norm = True
time_cond_dim = None,
mult = 2
):
super().__init__()
need_projection = dim != dim_out
@@ -940,11 +1023,19 @@ class ConvNextBlock(nn.Module):
)
)
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.GELU(),
nn.Linear(time_cond_dim, dim)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanLayerNorm(dim) if norm else nn.Identity(),
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -952,9 +1043,13 @@ class ConvNextBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
def forward(self, x, cond = None):
def forward(self, x, cond = None, time = None):
h = self.ds_conv(x)
if exists(time) and exists(self.time_mlp):
t = self.time_mlp(time)
h = rearrange(t, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h
@@ -1037,6 +1132,46 @@ class GridAttention(nn.Module):
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
class LinearAttention(nn.Module):
def __init__(
self,
dim,
dim_head = 32,
heads = 8
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = ChanLayerNorm(dim)
self.nonlin = nn.GELU()
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1, bias = False),
ChanLayerNorm(dim)
)
def forward(self, fmap):
h, x, y = self.heads, *fmap.shape[-2:]
fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
q = q.softmax(dim = -1)
k = k.softmax(dim = -2)
q = q * self.scale
context = einsum('b n d, b n e -> b d e', k, v)
out = einsum('b n d, b d e -> b n e', q, context)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
out = self.nonlin(out)
return self.to_out(out)
class Unet(nn.Module):
def __init__(
self,
@@ -1051,14 +1186,19 @@ class Unet(nn.Module):
dim_mults=(1, 2, 4, 8),
channels = 3,
attn_dim_head = 32,
attn_heads = 8,
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
init_dim = None,
init_conv_kernel_size = 7,
block_type = 'resnet',
block_resnet_groups = 8,
block_convnext_mult = 2,
**kwargs
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -1076,22 +1216,34 @@ class Unet(nn.Module):
self.channels = channels
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 2)
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
assert (init_conv_kernel_size % 2) == 1
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
time_cond_dim = dim * 4
self.time_mlp = nn.Sequential(
self.to_time_hiddens = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim * num_time_tokens),
nn.Linear(dim, time_cond_dim),
nn.GELU()
)
self.to_time_tokens = nn.Sequential(
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
self.to_time_cond = nn.Sequential(
nn.Linear(time_cond_dim, time_cond_dim)
)
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
@@ -1121,6 +1273,15 @@ class Unet(nn.Module):
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# whether to use resnet or the (improved?) convnext blocks
if block_type == 'resnet':
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
elif block_type == 'convnext':
block_klass = partial(ConvNextBlock, mult = block_convnext_mult)
else:
raise ValueError(f'unimplemented block type {block_type}')
# layers
self.downs = nn.ModuleList([])
@@ -1133,32 +1294,32 @@ class Unet(nn.Module):
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
ConvNextBlock(dim, dim),
block_klass(dim, dim),
nn.Conv2d(dim, out_dim, 1)
)
@@ -1214,9 +1375,16 @@ class Unet(nn.Module):
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
# initial convolution
x = self.init_conv(x)
# time conditioning
time_tokens = self.time_mlp(time)
time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
# conditional dropout
@@ -1282,25 +1450,25 @@ class Unet(nn.Module):
hiddens = []
for convnext, sparse_attn, convnext2, downsample in self.downs:
x = convnext(x, c)
for block1, sparse_attn, block2, downsample in self.downs:
x = block1(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c)
x = block2(x, c, t)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, mid_c)
x = self.mid_block1(x, mid_c, t)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_block2(x, mid_c)
x = self.mid_block2(x, mid_c, t)
for convnext, sparse_attn, convnext2, upsample in self.ups:
for block1, sparse_attn, block2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c)
x = block1(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c)
x = block2(x, c, t)
x = upsample(x)
return self.final_conv(x)

View File

@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
index = unet_number - 1
unet = self.decoder.unets[index]
if exists(self.max_grad_norm):
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

View File

@@ -0,0 +1,266 @@
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree
import torch
from torch import nn
from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid, save_image
from einops import rearrange
from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer
# helpers
def exists(val):
return val is not None
def noop(*args, **kwargs):
pass
def cycle(dl):
while True:
for data in dl:
yield data
def cast_tuple(t):
return t if isinstance(t, (tuple, list)) else (t,)
def yes_or_no(question):
answer = input(f'{question} (y/n) ')
return answer.lower() in ('yes', 'y')
def accum_log(log, new_logs):
for key, new_value in new_logs.items():
old_value = log.get(key, 0.)
log[key] = old_value + new_value
return log
# classes
class ImageDataset(Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
print(f'{len(self.paths)} training samples found at {folder}')
self.transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize(image_size),
T.RandomHorizontalFlip(),
T.CenterCrop(image_size),
T.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# main trainer class
class VQGanVAETrainer(nn.Module):
def __init__(
self,
vae,
*,
num_train_steps,
lr,
batch_size,
folder,
grad_accum_every,
wd = 0.,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
image_size = vae.image_size
self.vae = vae
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
self.register_buffer('steps', torch.Tensor([0]))
self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.grad_accum_every = grad_accum_every
all_parameters = set(vae.parameters())
discr_parameters = set(vae.discr.parameters())
vae_parameters = all_parameters - discr_parameters
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
# create dataset
self.ds = ImageDataset(folder, image_size = image_size)
# split for validation
if valid_frac > 0:
train_size = int((1 - valid_frac) * len(self.ds))
valid_size = len(self.ds) - train_size
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
else:
self.valid_ds = self.ds
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
# dataloader
self.dl = cycle(DataLoader(
self.ds,
batch_size = batch_size,
shuffle = True
))
self.valid_dl = cycle(DataLoader(
self.valid_ds,
batch_size = batch_size,
shuffle = True
))
self.save_model_every = save_model_every
self.save_results_every = save_results_every
self.apply_grad_penalty_every = apply_grad_penalty_every
self.results_folder = Path(results_folder)
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
rmtree(str(self.results_folder))
self.results_folder.mkdir(parents = True, exist_ok = True)
def train_step(self):
device = next(self.vae.parameters()).device
steps = int(self.steps.item())
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
self.vae.train()
# logs
logs = {}
# update vae (generator)
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.optim.step()
self.optim.zero_grad()
# update discriminator
if exists(self.vae.discr):
discr_loss = 0
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
loss = self.vae(img, return_discr_loss = True)
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.discr_optim.step()
self.discr_optim.zero_grad()
# log
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
# update exponential moving averaged generator
self.ema_vae.update()
# sample results every so often
if not (steps % self.save_results_every):
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
model.eval()
imgs = next(self.dl)
imgs = imgs.to(device)
recons = model(imgs)
nrows = int(sqrt(self.batch_size))
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
logs['reconstructions'] = grid
save_image(grid, str(self.results_folder / f'{filename}.png'))
print(f'{steps}: saving to {str(self.results_folder)}')
# save model every so often
if not (steps % self.save_model_every):
state_dict = self.vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.pt')
torch.save(state_dict, model_path)
ema_state_dict = self.ema_vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
torch.save(ema_state_dict, model_path)
print(f'{steps}: saving model to {str(self.results_folder)}')
self.steps += 1
return logs
def train(self, log_fn = noop):
device = next(self.vae.parameters()).device
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
print('training complete')

View File

@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
@@ -327,6 +331,112 @@ class ResBlock(nn.Module):
def forward(self, x):
return self.net(x) + x
# convnext enc dec
class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g
class ConvNext(nn.Module):
def __init__(self, dim, mult = 4, kernel_size = 3, ds_kernel_size = 7):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.Conv2d(dim, dim, ds_kernel_size, padding = ds_kernel_size // 2, groups = dim),
ChanLayerNorm(dim),
nn.Conv2d(dim, inner_dim, kernel_size, padding = kernel_size // 2),
nn.GELU(),
nn.Conv2d(inner_dim, dim, kernel_size, padding = kernel_size // 2)
)
def forward(self, x):
return self.net(x) + x
class ConvNextEncDec(nn.Module):
def __init__(
self,
dim,
*,
channels = 3,
layers = 4,
layer_mults = None,
num_blocks = 1,
first_conv_kernel_size = 5,
use_attn = True,
attn_dim_head = 64,
attn_heads = 8,
attn_dropout = 0.,
):
super().__init__()
self.layers = layers
self.encoders = MList([])
self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_blocks, tuple):
num_blocks = (*((0,) * (layers - 1)), num_blocks)
if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_blocks) == layers, 'number of blocks config must be equal to number of layers'
assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
if layer_use_attn:
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
for _ in range(layer_num_blocks):
append(self.encoders, ConvNext(dim_out))
prepend(self.decoders, ConvNext(dim_out))
if layer_use_attn:
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
append(self.decoders, nn.Conv2d(dim, channels, 1))
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
@property
def last_dec_layer(self):
return self.decoders[-1].weight
def encode(self, x):
for enc in self.encoders:
x = enc(x)
return x
def decode(self, x):
for dec in self.decoders:
x = dec(x)
return x
# vqgan attention layer
class VQGanAttention(nn.Module):
@@ -504,6 +614,10 @@ class ViTEncDec(nn.Module):
def get_encoded_fmap_size(self, image_size):
return image_size // self.patch_size
@property
def last_dec_layer(self):
return self.decoder[-3][-1].weight
def encode(self, x):
return self.encoder(x)
@@ -568,6 +682,8 @@ class VQGanVAE(nn.Module):
enc_dec_klass = ResnetEncDec
elif vae_type == 'vit':
enc_dec_klass = ViTEncDec
elif vae_type == 'convnext':
enc_dec_klass = ConvNextEncDec
else:
raise ValueError(f'{vae_type} not valid')
@@ -739,7 +855,7 @@ class VQGanVAE(nn.Module):
# calculate adaptive weight
last_dec_layer = self.decoders[-1].weight
last_dec_layer = self.enc_dec.last_dec_layer
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.82',
version = '0.0.96',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -26,12 +26,14 @@ setup(
'clip-anytorch',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'pillow',
'torch>=1.10',
'torchvision',
'tqdm',
'vector-quantize-pytorch',
'webdataset',
'x-clip>=0.5.1',
'youtokentome'
],

291
train_diffusion_prior.py Normal file
View File

@@ -0,0 +1,291 @@
import os
import math
import argparse
import numpy as np
import torch
from torch import nn
from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler
import time
from tqdm import tqdm
import wandb
os.environ["WANDB_SILENT"] = "true"
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
model.eval()
with torch.no_grad():
total_loss = 0.
total_samples = 0.
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
text_reader(batch_size=batch_size, start=start, end=end)):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
batches = emb_images_tensor.shape[0]
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
total_loss += loss.item() * batches
total_samples += batches
avg_loss = (total_loss / total_samples)
wandb.log({f'{phase} {loss_type}': avg_loss})
def save_model(save_path, state_dict):
# Saving State Dict
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size+val_set_size
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
text_embed = torch.tensor(embt[0]).to(device)
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed = text_embed)
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
return np.mean(predicted_similarity - original_similarity)
def train(image_embed_dim,
image_embed_url,
text_embed_url,
batch_size,
train_percent,
val_percent,
test_percent,
num_epochs,
dp_loss_type,
clip,
dp_condition_on_text_encodings,
dp_timesteps,
dp_l2norm_output,
dp_normformer,
dp_cond_drop_prob,
dpn_depth,
dpn_dim_head,
dpn_heads,
save_interval,
save_path,
device,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01,
amp=False):
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork(
dim = image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
normformer = dp_normformer,
l2norm_output = dp_l2norm_output).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
image_embed_dim = image_embed_dim,
timesteps = dp_timesteps,
cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
# Get image and text embeddings from the servers
print("==============Downloading embeddings - image and text====================")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count
# Create save_path if it doesn't exist
if not os.path.exists(save_path):
os.makedirs(save_path)
### Training code ###
scaler = GradScaler(enabled=amp)
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
epochs = num_epochs
step = 0
t = time.time()
train_set_size = int(train_percent*num_data_points)
val_set_size = int(val_percent*num_data_points)
for _ in range(epochs):
diffusion_prior.train()
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
with autocast(enabled=amp):
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
scaler.scale(loss).backward()
# Samples per second
step+=1
samples_per_sec = batch_size*step/(time.time()-t)
# Save checkpoint every save_interval minutes
if(int(time.time()-t) >= 60*save_interval):
t = time.time()
save_model(
save_path,
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
# Log to wandb
wandb.log({"Training loss": loss.item(),
"Steps": step,
"Samples per second": samples_per_sec})
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0:
diff_cosine_sim = report_cosine_sims(diffusion_prior,
image_reader,
text_reader,
train_set_size,
val_set_size,
NUM_TEST_EMBEDDINGS,
device)
wandb.log({"Cosine similarity difference": diff_cosine_sim})
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
### Evaluate model(validation run) ###
start = train_set_size
end=start+val_set_size
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation")
### Test run ###
test_set_size = int(test_percent*train_set_size)
start=train_set_size+val_set_size
end=num_data_points
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Test")
def main():
parser = argparse.ArgumentParser()
# Logging
parser.add_argument("--wandb-entity", type=str, default="laion")
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
parser.add_argument("--wandb-name", type=str, default="laion-dprior")
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
# URLs for embeddings
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
# Hyperparameters
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5)
# Image embed dimension
parser.add_argument("--image-embed-dim", type=int, default=768)
# Train-test split
parser.add_argument("--train-percent", type=float, default=0.7)
parser.add_argument("--val-percent", type=float, default=0.2)
parser.add_argument("--test-percent", type=float, default=0.1)
# LAION training(pre-computed embeddings)
# DiffusionPriorNetwork(dpn) parameters
parser.add_argument("--dpn-depth", type=int, default=6)
parser.add_argument("--dpn-dim-head", type=int, default=64)
parser.add_argument("--dpn-heads", type=int, default=8)
# DiffusionPrior(dp) parameters
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None)
parser.add_argument("--amp", type=bool, default=False)
# Model checkpointing interval(minutes)
parser.add_argument("--save-interval", type=int, default=30)
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
args = parser.parse_args()
print("Setting up wandb logging... Please wait...")
wandb.init(
entity=args.wandb_entity,
project=args.wandb_project,
config={
"learning_rate": args.learning_rate,
"architecture": args.wandb_arch,
"dataset": args.wandb_dataset,
"epochs": args.num_epochs,
})
print("wandb logging setup done!")
# Obtain the utilized device.
has_cuda = torch.cuda.is_available()
if has_cuda:
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Training loop
train(args.image_embed_dim,
args.image_embed_url,
args.text_embed_url,
args.batch_size,
args.train_percent,
args.val_percent,
args.test_percent,
args.num_epochs,
args.dp_loss_type,
args.clip,
args.dp_condition_on_text_encodings,
args.dp_timesteps,
args.dp_l2norm_output,
args.dp_normformer,
args.dp_cond_drop_prob,
args.dpn_depth,
args.dpn_dim_head,
args.dpn_heads,
args.save_interval,
args.save_path,
device,
args.learning_rate,
args.max_grad_norm,
args.weight_decay,
args.amp)
if __name__ == "__main__":
main()