mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
36 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
908088cfea | ||
|
|
8dc8a3de0d | ||
|
|
35f89556ba | ||
|
|
2b55f753b9 | ||
|
|
fc8fce38fb | ||
|
|
a1bfb03ba4 | ||
|
|
b1e7b5f6bb | ||
|
|
10b905b445 | ||
|
|
9b322ea634 | ||
|
|
ba64ea45cc | ||
|
|
64f7be1926 | ||
|
|
db805e73e1 | ||
|
|
cb07b37970 | ||
|
|
a774bfefe2 | ||
|
|
2ae57f0cf5 | ||
|
|
e46eaec817 | ||
|
|
8647cb5e76 | ||
|
|
53c189e46a | ||
|
|
dde51fd362 | ||
|
|
2eac7996fa | ||
|
|
4010aec033 | ||
|
|
c87b84a259 | ||
|
|
8b05468653 | ||
|
|
830afd3c15 | ||
|
|
8f93729d19 | ||
|
|
cd5f2c1de4 | ||
|
|
85ed77d512 | ||
|
|
fd53fa17db | ||
|
|
3676ef4d49 | ||
|
|
28e944f328 | ||
|
|
14e63a3f67 | ||
|
|
09e9eaa5a6 | ||
|
|
e6d752cf4a | ||
|
|
ad20a14a4d | ||
|
|
0be1e0d64c | ||
|
|
98df1ba51e |
70
README.md
70
README.md
@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
$ pyhon train_diffusion_prior.py
|
||||
$ python train_diffusion_prior.py
|
||||
```
|
||||
|
||||
The most significant parameters for the script are as follows:
|
||||
@@ -927,7 +927,39 @@ The most significant parameters for the script are as follows:
|
||||
|
||||
### Sample wandb run log
|
||||
|
||||
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace=
|
||||
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j
|
||||
|
||||
### Loading and saving the Diffusion Prior model
|
||||
|
||||
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
|
||||
|
||||
## from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model
|
||||
|
||||
load_diffusion_model(dprior_path, device)
|
||||
|
||||
dprior_path : path to saved model(.pth)
|
||||
|
||||
device : the cuda device you're running on
|
||||
|
||||
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
|
||||
|
||||
save_path : path to save at
|
||||
|
||||
model : object of Diffusion_Prior
|
||||
|
||||
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
|
||||
|
||||
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
|
||||
scaler : a GradScaler object.
|
||||
|
||||
e.g: scaler = GradScaler(enabled=amp)
|
||||
|
||||
config : config object created in train_diffusion_prior.py - see file for example.
|
||||
|
||||
image_embed_dim - the dimension of the image_embedding
|
||||
|
||||
e.g: 768
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
@@ -966,20 +998,25 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
|
||||
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
|
||||
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
|
||||
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -1047,4 +1084,25 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Yu2022CoCaCC,
|
||||
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
|
||||
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
|
||||
journal = {ArXiv},
|
||||
year = {2022},
|
||||
volume = {abs/2205.01917}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{wang2021crossformer,
|
||||
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
|
||||
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
|
||||
year = {2021},
|
||||
eprint = {2108.00154},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -4,6 +4,7 @@ from inspect import isfunction
|
||||
from functools import partial
|
||||
from contextlib import contextmanager
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -23,9 +24,14 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
|
||||
from resize_right import resize
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
from rotary_embedding_torch import RotaryEmbedding
|
||||
|
||||
# use x-clip
|
||||
|
||||
from x_clip import CLIP
|
||||
from coca_pytorch import CoCa
|
||||
|
||||
# helper functions
|
||||
|
||||
@@ -113,9 +119,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 't
|
||||
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
|
||||
|
||||
class BaseClipAdapter(nn.Module):
|
||||
def __init__(self, clip):
|
||||
def __init__(self, clip, **kwargs):
|
||||
super().__init__()
|
||||
self.clip = clip
|
||||
self.overrides = kwargs
|
||||
|
||||
@property
|
||||
def dim_latent(self):
|
||||
@@ -173,6 +180,39 @@ class XClipAdapter(BaseClipAdapter):
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return EmbeddedImage(l2norm(image_embed), image_encodings)
|
||||
|
||||
class CoCaAdapter(BaseClipAdapter):
|
||||
@property
|
||||
def dim_latent(self):
|
||||
return self.clip.dim
|
||||
|
||||
@property
|
||||
def image_size(self):
|
||||
assert 'image_size' in self.overrides
|
||||
return self.overrides['image_size']
|
||||
|
||||
@property
|
||||
def image_channels(self):
|
||||
assert 'image_channels' in self.overrides
|
||||
return self.overrides['image_channels']
|
||||
|
||||
@property
|
||||
def max_text_len(self):
|
||||
assert 'max_text_len' in self.overrides
|
||||
return self.overrides['max_text_len']
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_text(self, text):
|
||||
text = text[..., :self.max_text_len]
|
||||
text_mask = text != 0
|
||||
text_embed, text_encodings = self.clip.embed_text(text)
|
||||
return EmbeddedText(text_embed, text_encodings, text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image_embed, image_encodings = self.clip.embed_image(image)
|
||||
return EmbeddedImage(image_embed, image_encodings)
|
||||
|
||||
class OpenAIClipAdapter(BaseClipAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -225,7 +265,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
text_embed = self.clip.encode_text(text)
|
||||
text_encodings = self.text_encodings
|
||||
del self.text_encodings
|
||||
return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
|
||||
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask)
|
||||
|
||||
@torch.no_grad()
|
||||
def embed_image(self, image):
|
||||
@@ -233,7 +273,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
|
||||
image = resize_image_to(image, self.image_size)
|
||||
image = self.clip_normalize(unnormalize_img(image))
|
||||
image_embed = self.clip.encode_image(image)
|
||||
return EmbeddedImage(image_embed.float(), None)
|
||||
return EmbeddedImage(l2norm(image_embed.float()), None)
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
@@ -531,7 +571,8 @@ class Attention(nn.Module):
|
||||
heads = 8,
|
||||
dropout = 0.,
|
||||
causal = False,
|
||||
post_norm = False
|
||||
post_norm = False,
|
||||
rotary_emb = None
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
@@ -547,6 +588,8 @@ class Attention(nn.Module):
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, dim, bias = False),
|
||||
LayerNorm(dim) if post_norm else nn.Identity()
|
||||
@@ -559,6 +602,12 @@ class Attention(nn.Module):
|
||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
|
||||
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
|
||||
q = q * self.scale
|
||||
|
||||
# rotary embeddings
|
||||
|
||||
if exists(self.rotary_emb):
|
||||
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
|
||||
|
||||
# add null key / value for classifier free guidance in prior net
|
||||
|
||||
@@ -566,7 +615,7 @@ class Attention(nn.Module):
|
||||
k = torch.cat((nk, k), dim = -2)
|
||||
v = torch.cat((nv, v), dim = -2)
|
||||
|
||||
q = q * self.scale
|
||||
# calculate query / key similarities
|
||||
|
||||
sim = einsum('b h i d, b j d -> b h i j', q, k)
|
||||
|
||||
@@ -591,7 +640,7 @@ class Attention(nn.Module):
|
||||
|
||||
# attention
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
@@ -616,15 +665,18 @@ class CausalTransformer(nn.Module):
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.,
|
||||
final_proj = True,
|
||||
normformer = False
|
||||
normformer = False,
|
||||
rotary_emb = True
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
|
||||
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
|
||||
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
|
||||
]))
|
||||
|
||||
@@ -652,10 +704,31 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self,
|
||||
dim,
|
||||
num_timesteps = None,
|
||||
num_time_embeds = 1,
|
||||
num_image_embeds = 1,
|
||||
num_text_embeds = 1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
self.num_time_embeds = num_time_embeds
|
||||
self.num_image_embeds = num_image_embeds
|
||||
self.num_text_embeds = num_text_embeds
|
||||
|
||||
self.to_text_embeds = nn.Sequential(
|
||||
nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
|
||||
Rearrange('b (n d) -> b n d', n = num_text_embeds)
|
||||
)
|
||||
|
||||
self.to_time_embeds = nn.Sequential(
|
||||
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
||||
Rearrange('b (n d) -> b n d', n = num_time_embeds)
|
||||
)
|
||||
|
||||
self.to_image_embeds = nn.Sequential(
|
||||
nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),
|
||||
Rearrange('b (n d) -> b n d', n = num_image_embeds)
|
||||
)
|
||||
|
||||
self.learned_query = nn.Parameter(torch.randn(dim))
|
||||
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
|
||||
|
||||
@@ -685,10 +758,13 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
):
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
||||
text_embed = self.to_text_embeds(text_embed)
|
||||
image_embed = self.to_image_embeds(image_embed)
|
||||
|
||||
# make text encodings optional
|
||||
# although the paper seems to suggest it is present <--
|
||||
@@ -708,16 +784,17 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
|
||||
|
||||
keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
|
||||
mask = torch.cat((mask, keep_mask), dim = 1)
|
||||
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
||||
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.time_embeddings(diffusion_timesteps)
|
||||
time_embed = rearrange(time_embed, 'b d -> b 1 d')
|
||||
time_embed = self.to_time_embeds(diffusion_timesteps)
|
||||
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
|
||||
@@ -725,6 +802,7 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
text_encodings,
|
||||
text_embed,
|
||||
time_embed,
|
||||
image_embed,
|
||||
learned_queries
|
||||
), dim = -2)
|
||||
|
||||
@@ -748,13 +826,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
image_size = None,
|
||||
image_channels = 3,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
cond_drop_prob = 0.,
|
||||
loss_type = "l1",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
sampling_clamp_l2norm = False,
|
||||
training_clamp_l2norm = False,
|
||||
init_image_embed_l2norm = False,
|
||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||
clip_adapter_overrides = dict()
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -764,7 +845,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
if exists(clip):
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip)
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
clip = CoCaAdapter(clip, **clip_adapter_overrides)
|
||||
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
@@ -788,6 +871,8 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
|
||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||
self.training_clamp_l2norm = training_clamp_l2norm
|
||||
self.init_image_embed_l2norm = init_image_embed_l2norm
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
@@ -822,11 +907,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
image_embed = torch.randn(shape, device=device)
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
return img
|
||||
times = torch.full((b,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
|
||||
|
||||
return image_embed
|
||||
|
||||
def p_losses(self, image_embed, times, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
@@ -840,6 +930,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
||||
**text_cond
|
||||
)
|
||||
|
||||
if self.predict_x_start and self.training_clamp_l2norm:
|
||||
pred = l2norm(pred) * self.image_embed_scale
|
||||
|
||||
target = noise if not self.predict_x_start else image_embed
|
||||
|
||||
loss = self.loss_fn(pred, target)
|
||||
@@ -1074,7 +1167,7 @@ class CrossAttention(nn.Module):
|
||||
mask = rearrange(mask, 'b j -> b 1 1 j')
|
||||
sim = sim.masked_fill(~mask, max_neg_value)
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
@@ -1135,6 +1228,33 @@ class LinearAttention(nn.Module):
|
||||
out = self.nonlin(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class CrossEmbedLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
kernel_sizes,
|
||||
dim_out = None,
|
||||
stride = 2
|
||||
):
|
||||
super().__init__()
|
||||
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
|
||||
dim_out = default(dim_out, dim_in)
|
||||
|
||||
kernel_sizes = sorted(kernel_sizes)
|
||||
num_scales = len(kernel_sizes)
|
||||
|
||||
# calculate the dimension at each scale
|
||||
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
|
||||
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
|
||||
|
||||
self.convs = nn.ModuleList([])
|
||||
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
|
||||
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
|
||||
|
||||
def forward(self, x):
|
||||
fmaps = tuple(map(lambda conv: conv(x), self.convs))
|
||||
return torch.cat(fmaps, dim = 1)
|
||||
|
||||
class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1158,8 +1278,10 @@ class Unet(nn.Module):
|
||||
cond_on_image_embeds = False,
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
block_type = 'resnet',
|
||||
block_resnet_groups = 8,
|
||||
resnet_groups = 8,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1178,10 +1300,9 @@ class Unet(nn.Module):
|
||||
self.channels = channels
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 2)
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
|
||||
assert (init_conv_kernel_size % 2) == 1
|
||||
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
@@ -1237,7 +1358,15 @@ class Unet(nn.Module):
|
||||
|
||||
# resnet block klass
|
||||
|
||||
block_klass = partial(ResnetBlock, groups = block_resnet_groups)
|
||||
resnet_groups = cast_tuple(resnet_groups, len(in_out))
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
# downsample klass
|
||||
|
||||
downsample_klass = Downsample
|
||||
if cross_embed_downsample:
|
||||
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
||||
|
||||
# layers
|
||||
|
||||
@@ -1245,38 +1374,39 @@ class Unet(nn.Module):
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
|
||||
is_first = ind == 0
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
layer_cond_dim = cond_dim if not is_first else None
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
block_klass(dim_in, dim_out, time_cond_dim = time_cond_dim),
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
block_klass(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = block_klass(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
|
||||
is_last = ind >= (num_resolutions - 2)
|
||||
layer_cond_dim = cond_dim if not is_last else None
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
block_klass(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
block_klass(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
|
||||
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
|
||||
out_dim = default(out_dim, channels)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
block_klass(dim, dim),
|
||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
@@ -1287,12 +1417,13 @@ class Unet(nn.Module):
|
||||
*,
|
||||
lowres_cond,
|
||||
channels,
|
||||
cond_on_image_embeds
|
||||
cond_on_image_embeds,
|
||||
cond_on_text_encodings
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
|
||||
return self
|
||||
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
|
||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -1487,7 +1618,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict(),
|
||||
unconditional = False
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -1495,12 +1628,17 @@ class Decoder(BaseGaussianDiffusion):
|
||||
loss_type = loss_type
|
||||
)
|
||||
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
self.clip = None
|
||||
if exists(clip):
|
||||
if isinstance(clip, CLIP):
|
||||
clip = XClipAdapter(clip)
|
||||
clip = XClipAdapter(clip, **clip_adapter_overrides)
|
||||
elif isinstance(clip, CoCa):
|
||||
clip = CoCaAdapter(clip, **clip_adapter_overrides)
|
||||
|
||||
freeze_model_and_make_eval_(clip)
|
||||
assert isinstance(clip, BaseClipAdapter)
|
||||
@@ -1534,7 +1672,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
channels = unet_channels
|
||||
)
|
||||
|
||||
@@ -1669,12 +1808,16 @@ class Decoder(BaseGaussianDiffusion):
|
||||
@eval_decorator
|
||||
def sample(
|
||||
self,
|
||||
image_embed,
|
||||
image_embed = None,
|
||||
text = None,
|
||||
batch_size = 1,
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None
|
||||
):
|
||||
batch_size = image_embed.shape[0]
|
||||
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
||||
|
||||
if not self.unconditional:
|
||||
batch_size = image_embed.shape[0]
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text):
|
||||
@@ -1684,10 +1827,11 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
|
||||
with context:
|
||||
lowres_cond_img = None
|
||||
@@ -1827,3 +1971,4 @@ class DALLE2(nn.Module):
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
@@ -39,6 +40,50 @@ def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# print helpers
|
||||
|
||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||
flank = symbol * repeat
|
||||
return f'{flank} {s} {flank}'
|
||||
|
||||
# saving and loading functions
|
||||
|
||||
# for diffusion prior
|
||||
|
||||
def load_diffusion_model(dprior_path, device):
|
||||
dprior_path = Path(dprior_path)
|
||||
assert dprior_path.exists(), 'Dprior model file does not exist'
|
||||
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
|
||||
|
||||
# Get hyperparameters of loaded model
|
||||
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
|
||||
dp_config = loaded_obj['hparams']['diffusion_prior']
|
||||
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
|
||||
|
||||
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
|
||||
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
print_ribbon('Saving checkpoint')
|
||||
|
||||
state_dict = dict(model=model.state_dict(),
|
||||
optimizer=optimizer.state_dict(),
|
||||
scaler=scaler.state_dict(),
|
||||
hparams = config,
|
||||
image_embed_dim = {"image_embed_dim":image_embed_dim})
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
@@ -111,11 +156,6 @@ class DiffusionPriorTrainer(nn.Module):
|
||||
# exponential moving average
|
||||
|
||||
self.use_ema = use_ema
|
||||
|
||||
if use_ema:
|
||||
has_lazy_linear = any([type(module) == nn.LazyLinear for module in diffusion_prior.modules()])
|
||||
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
|
||||
|
||||
if self.use_ema:
|
||||
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
|
||||
|
||||
|
||||
@@ -3,14 +3,15 @@ import copy
|
||||
from random import choice
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms as T
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
|
||||
import torchvision.transforms as T
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
amp = False
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
|
||||
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
|
||||
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
|
||||
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
|
||||
|
||||
self.amp = amp
|
||||
self.scaler = GradScaler(enabled = amp)
|
||||
self.discr_scaler = GradScaler(enabled = amp)
|
||||
|
||||
# create dataset
|
||||
|
||||
self.ds = ImageDataset(folder, image_size = image_size)
|
||||
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
|
||||
|
||||
self.scaler.scale(loss / self.grad_accum_every).backward()
|
||||
|
||||
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.optim.step()
|
||||
self.scaler.step(self.optim)
|
||||
self.scaler.update()
|
||||
self.optim.zero_grad()
|
||||
|
||||
|
||||
# update discriminator
|
||||
|
||||
if exists(self.vae.discr):
|
||||
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
with autocast(enabled = self.amp):
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
|
||||
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
|
||||
|
||||
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.discr_optim.step()
|
||||
self.discr_scaler.step(self.discr_optim)
|
||||
self.discr_scaler.update()
|
||||
self.discr_optim.zero_grad()
|
||||
|
||||
# log
|
||||
|
||||
5
setup.py
5
setup.py
@@ -10,11 +10,12 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.108',
|
||||
version = '0.2.8',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
long_description_content_type = 'text/markdown',
|
||||
url = 'https://github.com/lucidrains/dalle2-pytorch',
|
||||
keywords = [
|
||||
'artificial intelligence',
|
||||
@@ -24,12 +25,14 @@ setup(
|
||||
install_requires=[
|
||||
'click',
|
||||
'clip-anytorch',
|
||||
'coca-pytorch>=0.0.5',
|
||||
'einops>=0.4',
|
||||
'einops-exts>=0.0.3',
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'pillow',
|
||||
'resize-right>=0.0.2',
|
||||
'rotary-embedding-torch',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
from torch import nn
|
||||
from embedding_reader import EmbeddingReader
|
||||
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
|
||||
from dalle2_pytorch.train import load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
from torch.cuda.amp import autocast,GradScaler
|
||||
|
||||
@@ -41,37 +42,56 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
|
||||
avg_loss = (total_loss / total_samples)
|
||||
wandb.log({f'{phase} {loss_type}': avg_loss})
|
||||
|
||||
def save_model(save_path, state_dict):
|
||||
# Saving State Dict
|
||||
print("====================================== Saving checkpoint ======================================")
|
||||
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
|
||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||
diffusion_prior.eval()
|
||||
|
||||
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
|
||||
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
||||
|
||||
tstart = train_set_size+val_set_size
|
||||
tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
|
||||
|
||||
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
|
||||
text_embed = torch.tensor(embt[0]).to(device)
|
||||
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
|
||||
test_text_cond = dict(text_embed = text_embed)
|
||||
|
||||
test_image_embeddings = torch.tensor(embi[0]).to(device)
|
||||
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
|
||||
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
|
||||
|
||||
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
|
||||
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
|
||||
|
||||
return np.mean(predicted_similarity - original_similarity)
|
||||
|
||||
tstart = train_set_size
|
||||
tend = train_set_size+NUM_TEST_EMBEDDINGS
|
||||
|
||||
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
|
||||
image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
|
||||
# make a copy of the text embeddings for shuffling
|
||||
text_embed = torch.tensor(embt[0]).to(device)
|
||||
text_embed_shuffled = text_embed.clone()
|
||||
# roll the text embeddings to simulate "unrelated" captions
|
||||
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
|
||||
text_embed_shuffled = text_embed_shuffled[rolled_idx]
|
||||
text_embed_shuffled = text_embed_shuffled / \
|
||||
text_embed_shuffled.norm(dim=1, keepdim=True)
|
||||
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
|
||||
# prepare the text embedding
|
||||
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
|
||||
test_text_cond = dict(text_embed=text_embed)
|
||||
# prepare image embeddings
|
||||
test_image_embeddings = torch.tensor(embi[0]).to(device)
|
||||
test_image_embeddings = test_image_embeddings / \
|
||||
test_image_embeddings.norm(dim=1, keepdim=True)
|
||||
# predict on the unshuffled text embeddings
|
||||
predicted_image_embeddings = diffusion_prior.p_sample_loop(
|
||||
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
|
||||
predicted_image_embeddings = predicted_image_embeddings / \
|
||||
predicted_image_embeddings.norm(dim=1, keepdim=True)
|
||||
# predict on the shuffled embeddings
|
||||
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
|
||||
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
|
||||
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
|
||||
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
|
||||
# calculate similarities
|
||||
original_similarity = cos(
|
||||
text_embed, test_image_embeddings).cpu().numpy()
|
||||
predicted_similarity = cos(
|
||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||
unrelated_similarity = cos(
|
||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||
predicted_img_similarity = cos(
|
||||
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
|
||||
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
|
||||
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
|
||||
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
|
||||
|
||||
def train(image_embed_dim,
|
||||
image_embed_url,
|
||||
@@ -93,9 +113,15 @@ def train(image_embed_dim,
|
||||
save_interval,
|
||||
save_path,
|
||||
device,
|
||||
RESUME,
|
||||
DPRIOR_PATH,
|
||||
config,
|
||||
wandb_entity,
|
||||
wandb_project,
|
||||
learning_rate=0.001,
|
||||
max_grad_norm=0.5,
|
||||
weight_decay=0.01,
|
||||
dropout=0.05,
|
||||
amp=False):
|
||||
|
||||
# DiffusionPriorNetwork
|
||||
@@ -104,6 +130,8 @@ def train(image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads,
|
||||
attn_dropout = dropout,
|
||||
ff_dropout = dropout,
|
||||
normformer = dp_normformer).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
@@ -116,16 +144,21 @@ def train(image_embed_dim,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
||||
|
||||
# Get image and text embeddings from the servers
|
||||
print("==============Downloading embeddings - image and text====================")
|
||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||
num_data_points = text_reader.count
|
||||
# Load pre-trained model from DPRIOR_PATH
|
||||
if RESUME:
|
||||
diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
|
||||
wandb.init( entity=wandb_entity, project=wandb_project, config=config)
|
||||
|
||||
# Create save_path if it doesn't exist
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
|
||||
# Get image and text embeddings from the servers
|
||||
print_ribbon("Downloading embeddings - image and text")
|
||||
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
|
||||
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
|
||||
num_data_points = text_reader.count
|
||||
|
||||
### Training code ###
|
||||
scaler = GradScaler(enabled=amp)
|
||||
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
|
||||
@@ -136,12 +169,15 @@ def train(image_embed_dim,
|
||||
|
||||
train_set_size = int(train_percent*num_data_points)
|
||||
val_set_size = int(val_percent*num_data_points)
|
||||
eval_start = train_set_size
|
||||
|
||||
for _ in range(epochs):
|
||||
diffusion_prior.train()
|
||||
|
||||
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
|
||||
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
|
||||
|
||||
diffusion_prior.train()
|
||||
|
||||
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
|
||||
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
|
||||
|
||||
@@ -156,9 +192,13 @@ def train(image_embed_dim,
|
||||
if(int(time.time()-t) >= 60*save_interval):
|
||||
t = time.time()
|
||||
|
||||
save_model(
|
||||
save_diffusion_model(
|
||||
save_path,
|
||||
dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
|
||||
diffusion_prior,
|
||||
optimizer,
|
||||
scaler,
|
||||
config,
|
||||
image_embed_dim)
|
||||
|
||||
# Log to wandb
|
||||
wandb.log({"Training loss": loss.item(),
|
||||
@@ -168,14 +208,22 @@ def train(image_embed_dim,
|
||||
# Use NUM_TEST_EMBEDDINGS samples from the test set each time
|
||||
# Get embeddings from the most recently saved model
|
||||
if(step % REPORT_METRICS_EVERY) == 0:
|
||||
diff_cosine_sim = report_cosine_sims(diffusion_prior,
|
||||
report_cosine_sims(diffusion_prior,
|
||||
image_reader,
|
||||
text_reader,
|
||||
train_set_size,
|
||||
val_set_size,
|
||||
NUM_TEST_EMBEDDINGS,
|
||||
device)
|
||||
wandb.log({"Cosine similarity difference": diff_cosine_sim})
|
||||
### Evaluate model(validation run) ###
|
||||
eval_model(diffusion_prior,
|
||||
device,
|
||||
image_reader,
|
||||
text_reader,
|
||||
eval_start,
|
||||
eval_start+NUM_TEST_EMBEDDINGS,
|
||||
NUM_TEST_EMBEDDINGS,
|
||||
dp_loss_type,
|
||||
phase="Validation")
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
|
||||
@@ -184,11 +232,6 @@ def train(image_embed_dim,
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
### Evaluate model(validation run) ###
|
||||
start = train_set_size
|
||||
end=start+val_set_size
|
||||
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation")
|
||||
|
||||
### Test run ###
|
||||
test_set_size = int(test_percent*train_set_size)
|
||||
start=train_set_size+val_set_size
|
||||
@@ -200,7 +243,6 @@ def main():
|
||||
# Logging
|
||||
parser.add_argument("--wandb-entity", type=str, default="laion")
|
||||
parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
|
||||
parser.add_argument("--wandb-name", type=str, default="laion-dprior")
|
||||
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
|
||||
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
|
||||
# URLs for embeddings
|
||||
@@ -209,6 +251,7 @@ def main():
|
||||
# Hyperparameters
|
||||
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
|
||||
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
|
||||
parser.add_argument("--dropout", type=float, default=5e-2)
|
||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
||||
parser.add_argument("--batch-size", type=int, default=10**4)
|
||||
parser.add_argument("--num-epochs", type=int, default=5)
|
||||
@@ -226,7 +269,6 @@ def main():
|
||||
# DiffusionPrior(dp) parameters
|
||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
|
||||
parser.add_argument("--dp-normformer", type=bool, default=False)
|
||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
|
||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||
@@ -235,22 +277,40 @@ def main():
|
||||
# Model checkpointing interval(minutes)
|
||||
parser.add_argument("--save-interval", type=int, default=30)
|
||||
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
|
||||
# Saved model path
|
||||
parser.add_argument("--pretrained-model-path", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Setting up wandb logging... Please wait...")
|
||||
config = ({"learning_rate": args.learning_rate,
|
||||
"architecture": args.wandb_arch,
|
||||
"dataset": args.wandb_dataset,
|
||||
"weight_decay":args.weight_decay,
|
||||
"max_gradient_clipping_norm":args.max_grad_norm,
|
||||
"batch_size":args.batch_size,
|
||||
"epochs": args.num_epochs,
|
||||
"diffusion_prior_network":{"depth":args.dpn_depth,
|
||||
"dim_head":args.dpn_dim_head,
|
||||
"heads":args.dpn_heads,
|
||||
"normformer":args.dp_normformer},
|
||||
"diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
|
||||
"timesteps": args.dp_timesteps,
|
||||
"cond_drop_prob":args.dp_cond_drop_prob,
|
||||
"loss_type":args.dp_loss_type,
|
||||
"clip":args.clip}
|
||||
})
|
||||
|
||||
wandb.init(
|
||||
entity=args.wandb_entity,
|
||||
project=args.wandb_project,
|
||||
config={
|
||||
"learning_rate": args.learning_rate,
|
||||
"architecture": args.wandb_arch,
|
||||
"dataset": args.wandb_dataset,
|
||||
"epochs": args.num_epochs,
|
||||
})
|
||||
RESUME = False
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
DPRIOR_PATH = args.pretrained_model_path
|
||||
if(DPRIOR_PATH is not None):
|
||||
RESUME = True
|
||||
else:
|
||||
wandb.init(
|
||||
entity=args.wandb_entity,
|
||||
project=args.wandb_project,
|
||||
config=config)
|
||||
|
||||
print("wandb logging setup done!")
|
||||
# Obtain the utilized device.
|
||||
|
||||
has_cuda = torch.cuda.is_available()
|
||||
@@ -279,9 +339,15 @@ def main():
|
||||
args.save_interval,
|
||||
args.save_path,
|
||||
device,
|
||||
RESUME,
|
||||
DPRIOR_PATH,
|
||||
config,
|
||||
args.wandb_entity,
|
||||
args.wandb_project,
|
||||
args.learning_rate,
|
||||
args.max_grad_norm,
|
||||
args.weight_decay,
|
||||
args.dropout,
|
||||
args.amp)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user