mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9359ad2e91 | ||
|
|
9ff228188b | ||
|
|
2d9963d30e | ||
|
|
58d9b422f3 | ||
|
|
44b319cb57 | ||
|
|
c30f380689 | ||
|
|
e4e884bb8b | ||
|
|
803ad9c17d | ||
|
|
a88dd6a9c0 | ||
|
|
72c16b496e | ||
|
|
81d83dd7f2 | ||
|
|
fa66f7e1e9 | ||
|
|
aa8d135245 | ||
|
|
70282de23b | ||
|
|
83f761847e | ||
|
|
11469dc0c6 | ||
|
|
2d25c89f35 | ||
|
|
3fe96c208a | ||
|
|
0fc6c9cdf3 | ||
|
|
7ee0ecc388 | ||
|
|
1924c7cc3d | ||
|
|
f7df3caaf3 | ||
|
|
fc954ee788 | ||
|
|
c1db2753f5 |
20
README.md
20
README.md
@@ -821,7 +821,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
|
||||
- [x] bring in tools to train vqgan-vae
|
||||
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
@@ -830,6 +830,12 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -895,4 +901,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
```bibtex
|
||||
@article{Shleifer2021NormFormerIT,
|
||||
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
|
||||
author = {Sam Shleifer and Jason Weston and Myle Ott},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2110.09456}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import click
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
|
||||
|
||||
@@ -29,6 +29,9 @@ from x_clip import CLIP
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def identity(t, *args, **kwargs):
|
||||
return t
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
@@ -496,7 +499,12 @@ class SwiGLU(nn.Module):
|
||||
x, gate = x.chunk(2, dim = -1)
|
||||
return x * F.silu(gate)
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||
def FeedForward(
|
||||
dim,
|
||||
mult = 4,
|
||||
dropout = 0.,
|
||||
post_activation_norm = False
|
||||
):
|
||||
""" post-activation norm https://arxiv.org/abs/2110.09456 """
|
||||
|
||||
inner_dim = int(mult * dim)
|
||||
@@ -519,7 +527,8 @@ class Attention(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False
|
||||
causal = False,
|
||||
post_norm = False
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -534,7 +543,11 @@ class Attention(nn.Module):
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim) if post_norm else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, mask = None, attn_bias = None):
|
||||
b, n, device = *x.shape[:2], x.device
|
||||
@@ -596,10 +609,11 @@ class CausalTransformer(nn.Module):
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
ff_mult = 4,
|
||||
norm_out = False,
|
||||
norm_out = True,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
final_proj = True
|
||||
final_proj = True,
|
||||
normformer = False
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
@@ -607,8 +621,8 @@ class CausalTransformer(nn.Module):
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
@@ -635,12 +649,14 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = None,
|
||||
l2norm_output = False, # whether to restrict image embedding output with l2norm at the end (may make it easier to learn?)
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
self.l2norm_output = l2norm_output
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
@@ -719,7 +735,8 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
pred_image_embed = tokens[..., -1, :]
|
||||
|
||||
return pred_image_embed
|
||||
output_fn = l2norm if self.l2norm_output else identity
|
||||
return output_fn(pred_image_embed)
|
||||
|
||||
class DiffusionPrior(BaseGaussianDiffusion):
|
||||
def __init__(
|
||||
@@ -913,6 +930,72 @@ class SinusoidalPosEmb(nn.Module):
|
||||
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
||||
return torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(dim, dim_out, 3, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
nn.SiLU()
|
||||
)
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
*,
|
||||
cond_dim = None,
|
||||
time_cond_dim = None,
|
||||
groups = 8
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_mlp = None
|
||||
|
||||
if exists(time_cond_dim):
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_cond_dim, dim_out)
|
||||
)
|
||||
|
||||
self.cross_attn = None
|
||||
|
||||
if exists(cond_dim):
|
||||
self.cross_attn = EinopsToAndFrom(
|
||||
'b c h w',
|
||||
'b (h w) c',
|
||||
CrossAttention(
|
||||
dim = dim_out,
|
||||
context_dim = cond_dim
|
||||
)
|
||||
)
|
||||
|
||||
self.block1 = Block(dim, dim_out, groups = groups)
|
||||
self.block2 = Block(dim_out, dim_out, groups = groups)
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, cond = None, time_emb = None):
|
||||
h = self.block1(x)
|
||||
|
||||
if exists(self.time_mlp) and exists(time_emb):
|
||||
time_emb = self.time_mlp(time_emb)
|
||||
h = rearrange(time_emb, 'b c -> b c 1 1') + h
|
||||
|
||||
if exists(self.cross_attn):
|
||||
assert exists(cond)
|
||||
h = self.cross_attn(h, context = cond) + h
|
||||
|
||||
h = self.block2(h)
|
||||
return h + self.res_conv(x)
|
||||
|
||||
class ConvNextBlock(nn.Module):
|
||||
""" https://arxiv.org/abs/2201.03545 """
|
||||
|
||||
@@ -923,8 +1006,7 @@ class ConvNextBlock(nn.Module):
|
||||
*,
|
||||
cond_dim = None,
|
||||
time_cond_dim = None,
|
||||
mult = 2,
|
||||
norm = True
|
||||
mult = 2
|
||||
):
|
||||
super().__init__()
|
||||
need_projection = dim != dim_out
|
||||
@@ -953,7 +1035,7 @@ class ConvNextBlock(nn.Module):
|
||||
|
||||
inner_dim = int(dim_out * mult)
|
||||
self.net = nn.Sequential(
|
||||
ChanLayerNorm(dim) if norm else nn.Identity(),
|
||||
ChanLayerNorm(dim),
|
||||
nn.Conv2d(dim, inner_dim, 3, padding = 1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
|
||||
@@ -1065,7 +1147,11 @@ class LinearAttention(nn.Module):
|
||||
|
||||
self.nonlin = nn.GELU()
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim, 1, bias = False),
|
||||
ChanLayerNorm(dim)
|
||||
)
|
||||
|
||||
def forward(self, fmap):
|
||||
h, x, y = self.heads, *fmap.shape[-2:]
|
||||
@@ -1108,7 +1194,9 @@ class Unet(nn.Module):
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7
|
||||
init_conv_kernel_size = 7,
|
||||
block_type = 'resnet',
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
@@ -1183,6 +1271,15 @@ class Unet(nn.Module):
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
|
||||
# whether to use resnet or the (improved?) convnext blocks
|
||||
|
||||
if block_type == 'resnet':
|
||||
block_klass = ResnetBlock
|
||||
elif block_type == 'convnext':
|
||||
block_klass = ConvNextBlock
|
||||
else:
|
||||
raise ValueError(f'unimplemented block type {block_type}')
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
@@ -1195,32 +1292,32 @@ class Unet(nn.Module):
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
|
||||
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
self.final_conv = nn.Sequential(
|
||||
ConvNextBlock(dim, dim),
|
||||
block_klass(dim, dim),
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
@@ -1351,10 +1448,10 @@ class Unet(nn.Module):
|
||||
|
||||
hiddens = []
|
||||
|
||||
for convnext, sparse_attn, convnext2, downsample in self.downs:
|
||||
x = convnext(x, c, t)
|
||||
for block1, sparse_attn, block2, downsample in self.downs:
|
||||
x = block1(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c, t)
|
||||
x = block2(x, c, t)
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
@@ -1365,11 +1462,11 @@ class Unet(nn.Module):
|
||||
|
||||
x = self.mid_block2(x, mid_c, t)
|
||||
|
||||
for convnext, sparse_attn, convnext2, upsample in self.ups:
|
||||
for block1, sparse_attn, block2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = convnext(x, c, t)
|
||||
x = block1(x, c, t)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c, t)
|
||||
x = block2(x, c, t)
|
||||
x = upsample(x)
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
|
||||
index = unet_number - 1
|
||||
unet = self.decoder.unets[index]
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||
|
||||
optimizer = getattr(self, f'optim{index}')
|
||||
scaler = getattr(self, f'scaler{index}')
|
||||
|
||||
if exists(self.max_grad_norm):
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
@@ -285,6 +285,10 @@ class ResnetEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoders[-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
@@ -419,6 +423,10 @@ class ConvNextEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoders[-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
@@ -606,6 +614,10 @@ class ViTEncDec(nn.Module):
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // self.patch_size
|
||||
|
||||
@property
|
||||
def last_dec_layer(self):
|
||||
return self.decoder[-3][-1].weight
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
@@ -843,7 +855,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
# calculate adaptive weight
|
||||
|
||||
last_dec_layer = self.decoders[-1].weight
|
||||
last_dec_layer = self.enc_dec.last_dec_layer
|
||||
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||
|
||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.88',
|
||||
version = '0.0.95',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
|
||||
@@ -1,36 +1,78 @@
|
||||
import os
|
||||
import math
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from torch.cuda.amp import autocast,GradScaler
|
||||
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
|
||||
import wandb
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
|
||||
REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
|
||||
|
||||
|
||||
def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
|
||||
total_loss = 0.
|
||||
total_samples = 0.
|
||||
|
||||
for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
|
||||
text_reader(batch_size=batch_size, start=start, end=end)):
|
||||
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
|
||||
batches = emb_images_tensor.shape[0]
|
||||
|
||||
loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
|
||||
|
||||
# Log to wandb
|
||||
wandb.log({f'{phase} {loss_type}': loss})
|
||||
total_loss += loss.item() * batches
|
||||
total_samples += batches
|
||||
|
||||
def save_model(save_path,state_dict):
|
||||
avg_loss = (total_loss / total_samples)
|
||||
wandb.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def save_model(save_path, state_dict):
|
||||
# Saving State Dict
|
||||
print("====================================== Saving checkpoint ======================================")
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
tstart = train_set_size+val_set_size
|
||||
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
|
||||
|
||||
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
|
||||
text_embed = torch.tensor(embt[0]).to(device)
|
||||
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
|
||||
test_text_cond = dict(text_embed = text_embed)
|
||||
|
||||
test_image_embeddings = torch.tensor(embi[0]).to(device)
|
||||
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
|
||||
|
||||
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
|
||||
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
|
||||
|
||||
return np.mean(predicted_similarity - original_similarity)
|
||||
|
||||
|
||||
|
||||
def train(image_embed_dim,
|
||||
image_embed_url,
|
||||
text_embed_url,
|
||||
@@ -43,6 +85,8 @@ def train(image_embed_dim,
|
||||
clip,
|
||||
dp_condition_on_text_encodings,
|
||||
dp_timesteps,
|
||||
dp_l2norm_output,
|
||||
dp_normformer,
|
||||
dp_cond_drop_prob,
|
||||
dpn_depth,
|
||||
dpn_dim_head,
|
||||
@@ -52,14 +96,17 @@ def train(image_embed_dim,
|
||||
device,
|
||||
learning_rate=0.001,
|
||||
max_grad_norm=0.5,
|
||||
weight_decay=0.01):
|
||||
weight_decay=0.01,
|
||||
amp=False):
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads).to(device)
|
||||
heads = dpn_heads,
|
||||
normformer = dp_normformer,
|
||||
l2norm_output = dp_l2norm_output).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(
|
||||
@@ -82,6 +129,7 @@ def train(image_embed_dim,
|
||||
os.makedirs(save_path)
|
||||
|
||||
### Training code ###
|
||||
scaler = GradScaler(enabled=amp)
|
||||
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
epochs = num_epochs
|
||||
|
||||
@@ -98,23 +146,45 @@ def train(image_embed_dim,
|
||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
optimizer.zero_grad()
|
||||
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
|
||||
loss.backward()
|
||||
|
||||
with autocast(enabled=amp):
|
||||
loss = diffusion_prior(text_embed = emb_text_tensor,image_embed = emb_images_tensor)
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
# Samples per second
|
||||
step+=1
|
||||
samples_per_sec = batch_size*step/(time.time()-t)
|
||||
# Save checkpoint every save_interval minutes
|
||||
if(int(time.time()-t) >= 60*save_interval):
|
||||
t = time.time()
|
||||
save_model(save_path,diffusion_prior.state_dict())
|
||||
|
||||
save_model(
|
||||
save_path,
|
||||
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
|
||||
|
||||
# Log to wandb
|
||||
wandb.log({"Training loss": loss.item(),
|
||||
"Steps": step,
|
||||
"Samples per second": samples_per_sec})
|
||||
# Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
diff_cosine_sim = report_cosine_sims(diffusion_prior,
|
||||
image_reader,
|
||||
text_reader,
|
||||
train_set_size,
|
||||
val_set_size,
|
||||
NUM_TEST_EMBEDDINGS,
|
||||
device)
|
||||
wandb.log({"Cosine similarity difference": diff_cosine_sim})
|
||||
|
||||
nn.init.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
||||
optimizer.step()
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
### Evaluate model(validation run) ###
|
||||
start = train_set_size
|
||||
@@ -139,8 +209,8 @@ def main():
|
||||
parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/")
|
||||
parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/")
|
||||
# Hyperparameters
|
||||
parser.add_argument("--learning-rate", type=float, default=0.001)
|
||||
parser.add_argument("--weight-decay", type=float, default=0.01)
|
||||
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
|
||||
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
|
||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
||||
parser.add_argument("--batch-size", type=int, default=10**4)
|
||||
parser.add_argument("--num-epochs", type=int, default=5)
|
||||
@@ -158,15 +228,20 @@ def main():
|
||||
# DiffusionPrior(dp) parameters
|
||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
||||
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
|
||||
parser.add_argument("--dp-normformer", type=bool, default=False)
|
||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
|
||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||
parser.add_argument("--clip", type=str, default=None)
|
||||
parser.add_argument("--amp", type=bool, default=False)
|
||||
# Model checkpointing interval(minutes)
|
||||
parser.add_argument("--save-interval", type=int, default=30)
|
||||
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Setting up wandb logging... Please wait...")
|
||||
|
||||
wandb.init(
|
||||
entity=args.wandb_entity,
|
||||
project=args.wandb_project,
|
||||
@@ -176,6 +251,7 @@ def main():
|
||||
"dataset": args.wandb_dataset,
|
||||
"epochs": args.num_epochs,
|
||||
})
|
||||
|
||||
print("wandb logging setup done!")
|
||||
# Obtain the utilized device.
|
||||
|
||||
@@ -197,6 +273,8 @@ def main():
|
||||
args.clip,
|
||||
args.dp_condition_on_text_encodings,
|
||||
args.dp_timesteps,
|
||||
args.dp_l2norm_output,
|
||||
args.dp_normformer,
|
||||
args.dp_cond_drop_prob,
|
||||
args.dpn_depth,
|
||||
args.dpn_dim_head,
|
||||
@@ -206,7 +284,8 @@ def main():
|
||||
device,
|
||||
args.learning_rate,
|
||||
args.max_grad_norm,
|
||||
args.weight_decay)
|
||||
args.weight_decay,
|
||||
args.amp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user