Compare commits


1 Commit

7 changed files with 101 additions and 353 deletions

View File

@@ -902,7 +902,7 @@ Please note that the script internally passes text_embed and image_embed to the
### Usage ### Usage
```bash ```bash
$ python train_diffusion_prior.py $ python train_diffusion_prior.py
``` ```
The most significant parameters for the script are as follows: The most significant parameters for the script are as follows:
@@ -927,39 +927,7 @@ The most significant parameters for the script are as follows:
### Sample wandb run log ### Sample wandb run log
Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/1blxu24j Please find a sample wandb run log at : https://wandb.ai/laion/diffusion-prior/runs/aul0rhv5?workspace=
### Loading and saving the Diffusion Prior model
Two methods are provided, load_diffusion_model and save_diffusion_model, the names being self-explanatory.
## from dalle2_pytorch import load_diffusion_model, save_diffusion_model
load_diffusion_model(dprior_path, device)
dprior_path : path to saved model(.pth)
device : the cuda device you're running on
save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim)
save_path : path to save at
model : object of Diffusion_Prior
optimizer : optimizer object - see train_diffusion_prior.py for how to create one.
e.g: optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
scaler : a GradScaler object.
e.g: scaler = GradScaler(enabled=amp)
config : config object created in train_diffusion_prior.py - see file for example.
image_embed_dim - the dimension of the image_embedding
e.g: 768
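For reference, here is a minimal usage sketch of the two helpers documented above (they exist only on the left-hand side of this comparison). It is based solely on the signatures and parameter notes in this section; the checkpoint path and the hyperparameter dict are placeholders, and the optimizer / scaler are built the way train_diffusion_prior.py describes.
```python
import torch
from torch.cuda.amp import GradScaler
from dalle2_pytorch import load_diffusion_model, save_diffusion_model
from dalle2_pytorch.optimizer import get_optimizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load a previously saved prior checkpoint (.pth) - the path here is a placeholder
diffusion_prior = load_diffusion_model('./diffusion_prior_checkpoints/prior.pth', device)

# objects save_diffusion_model expects, created as in train_diffusion_prior.py
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd = 6.02e-2, lr = 1.1e-4)
scaler = GradScaler(enabled = True)
config = {'learning_rate': 1.1e-4}  # placeholder for the full hyperparameter dict the script assembles

save_diffusion_model(
    './diffusion_prior_checkpoints',  # directory; a timestamped *_saved_model.pth is written inside
    diffusion_prior,
    optimizer,
    scaler,
    config,
    768                               # image_embed_dim
)
```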
## CLI (wip) ## CLI (wip)
@@ -998,9 +966,8 @@ Once built, images will be saved to the same directory the command is invoked
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet) - [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation) - [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias) - [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training - [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab
@@ -1013,8 +980,6 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor) - [ ] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
## Citations ## Citations
@@ -1082,14 +1047,4 @@ Once built, images will be saved to the same directory the command is invoked
} }
``` ```
```bibtex
@article{Yu2022CoCaCC,
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
journal = {ArXiv},
year = {2022},
volume = {abs/2205.01917}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a> *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,4 +1,4 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder,load_diffusion_model,save_diffusion_model from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer

View File

@@ -4,8 +4,6 @@ from inspect import isfunction
from functools import partial from functools import partial
from contextlib import contextmanager from contextlib import contextmanager
from collections import namedtuple from collections import namedtuple
from pathlib import Path
import time
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -25,50 +23,9 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from resize_right import resize from resize_right import resize
# rotary embeddings
from rotary_embedding_torch import RotaryEmbedding
# use x-clip # use x-clip
from x_clip import CLIP from x_clip import CLIP
from coca_pytorch import CoCa
# Diffusion Prior model loading and saving functions
def load_diffusion_model(dprior_path, device ):
dprior_path = Path(dprior_path)
assert dprior_path.exists(), 'Dprior model file does not exist'
loaded_obj = torch.load(str(dprior_path), map_location='cpu')
# Get hyperparameters of loaded model
dpn_config = loaded_obj['hparams']['diffusion_prior_network']
dp_config = loaded_obj['hparams']['diffusion_prior']
image_embed_dim = loaded_obj['image_embed_dim']['image_embed_dim']
# Create DiffusionPriorNetwork and DiffusionPrior with loaded hyperparameters
# DiffusionPriorNetwork
prior_network = DiffusionPriorNetwork( dim = image_embed_dim, **dpn_config).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior(net = prior_network, **dp_config, image_embed_dim = image_embed_dim).to(device)
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict
print("====================================== Saving checkpoint ======================================")
state_dict = dict(model=model.state_dict(),
optimizer=optimizer.state_dict(),
scaler=scaler.state_dict(),
hparams = config,
image_embed_dim = {"image_embed_dim":image_embed_dim})
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
# helper functions # helper functions
@@ -156,10 +113,9 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 't
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings']) EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module): class BaseClipAdapter(nn.Module):
def __init__(self, clip, **kwargs): def __init__(self, clip):
super().__init__() super().__init__()
self.clip = clip self.clip = clip
self.overrides = kwargs
@property @property
def dim_latent(self): def dim_latent(self):
@@ -217,39 +173,6 @@ class XClipAdapter(BaseClipAdapter):
image_embed = self.clip.to_visual_latent(image_cls) image_embed = self.clip.to_visual_latent(image_cls)
return EmbeddedImage(l2norm(image_embed), image_encodings) return EmbeddedImage(l2norm(image_embed), image_encodings)
class CoCaAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return self.clip.dim
@property
def image_size(self):
assert 'image_size' in self.overrides
return self.overrides['image_size']
@property
def image_channels(self):
assert 'image_channels' in self.overrides
return self.overrides['image_channels']
@property
def max_text_len(self):
assert 'max_text_len' in self.overrides
return self.overrides['max_text_len']
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
text_embed, text_encodings = self.clip.embed_text(text)
return EmbeddedText(text_embed, text_encodings, text_mask)
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image_embed, image_encodings = self.clip.embed_image(image)
return EmbeddedImage(image_embed, image_encodings)
class OpenAIClipAdapter(BaseClipAdapter): class OpenAIClipAdapter(BaseClipAdapter):
def __init__( def __init__(
self, self,
@@ -302,7 +225,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
text_embed = self.clip.encode_text(text) text_embed = self.clip.encode_text(text)
text_encodings = self.text_encodings text_encodings = self.text_encodings
del self.text_encodings del self.text_encodings
return EmbeddedText(l2norm(text_embed.float()), text_encodings.float(), text_mask) return EmbeddedText(text_embed.float(), text_encodings.float(), text_mask)
@torch.no_grad() @torch.no_grad()
def embed_image(self, image): def embed_image(self, image):
@@ -310,7 +233,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
image = resize_image_to(image, self.image_size) image = resize_image_to(image, self.image_size)
image = self.clip_normalize(unnormalize_img(image)) image = self.clip_normalize(unnormalize_img(image))
image_embed = self.clip.encode_image(image) image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None) return EmbeddedImage(image_embed.float(), None)
# classifier free guidance functions # classifier free guidance functions
@@ -608,8 +531,7 @@ class Attention(nn.Module):
heads = 8, heads = 8,
dropout = 0., dropout = 0.,
causal = False, causal = False,
post_norm = False, post_norm = False
rotary_emb = None
): ):
super().__init__() super().__init__()
self.scale = dim_head ** -0.5 self.scale = dim_head ** -0.5
@@ -625,8 +547,6 @@ class Attention(nn.Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.rotary_emb = rotary_emb
self.to_out = nn.Sequential( self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False), nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity() LayerNorm(dim) if post_norm else nn.Identity()
@@ -639,12 +559,6 @@ class Attention(nn.Module):
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
q = q * self.scale
# rotary embeddings
if exists(self.rotary_emb):
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
# add null key / value for classifier free guidance in prior net # add null key / value for classifier free guidance in prior net
@@ -652,7 +566,7 @@ class Attention(nn.Module):
k = torch.cat((nk, k), dim = -2) k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2) v = torch.cat((nv, v), dim = -2)
# calculate query / key similarities q = q * self.scale
sim = einsum('b h i d, b j d -> b h i j', q, k) sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -702,18 +616,15 @@ class CausalTransformer(nn.Module):
attn_dropout = 0., attn_dropout = 0.,
ff_dropout = 0., ff_dropout = 0.,
final_proj = True, final_proj = True,
normformer = False, normformer = False
rotary_emb = True
): ):
super().__init__() super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads) self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList([])
for _ in range(depth): for _ in range(depth):
self.layers.append(nn.ModuleList([ self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb), Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer) FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
])) ]))
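Where the left-hand side of the hunks above wires a RotaryEmbedding into each Attention layer, the idea is to rotate query and key features position-wise before the dot product. A minimal sketch of that step, assuming the rotary-embedding-torch API referenced in the removed lines; tensor shapes are illustrative only.
```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# relative positions are injected by rotating query / key features before attention,
# mirroring the rotary_emb lines on the left-hand side of the hunks above
rotary_emb = RotaryEmbedding(dim = 32)   # the transformer used min(32, dim_head)

q = torch.randn(2, 8, 128, 64)           # (batch, heads, seq, dim_head)
k = torch.randn(2, 128, 64)              # single-head keys, matching to_kv above

q, k = map(rotary_emb.rotate_queries_or_keys, (q, k))
# q and k now carry position information; attention proceeds as usual on the rotated tensors
```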
@@ -741,31 +652,10 @@ class DiffusionPriorNetwork(nn.Module):
self, self,
dim, dim,
num_timesteps = None, num_timesteps = None,
num_time_embeds = 1,
num_image_embeds = 1,
num_text_embeds = 1,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.num_time_embeds = num_time_embeds self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.num_image_embeds = num_image_embeds
self.num_text_embeds = num_text_embeds
self.to_text_embeds = nn.Sequential(
nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
Rearrange('b (n d) -> b n d', n = num_text_embeds)
)
self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
)
self.to_image_embeds = nn.Sequential(
nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),
Rearrange('b (n d) -> b n d', n = num_image_embeds)
)
self.learned_query = nn.Parameter(torch.randn(dim)) self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs) self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
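The to_text_embeds / to_time_embeds / to_image_embeds blocks removed above all follow one pattern: expand a single d-dimensional embedding into several tokens so the causal transformer has more surface area to attend over (the matching todo item is also dropped from the README in this diff). A condensed sketch of that projection, with arbitrary sizes:
```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, num_text_embeds = 512, 4   # illustrative sizes

# one embedding in, num_text_embeds tokens out; the original swaps the Linear for an
# Identity when num_text_embeds == 1, so this reduces to a simple unsqueeze in that case
to_text_embeds = nn.Sequential(
    nn.Linear(dim, dim * num_text_embeds),
    Rearrange('b (n d) -> b n d', n = num_text_embeds)
)

text_embed = torch.randn(8, dim)
tokens = to_text_embeds(text_embed)
print(tokens.shape)   # torch.Size([8, 4, 512])
```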
@@ -795,13 +685,10 @@ class DiffusionPriorNetwork(nn.Module):
): ):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
# in section 2.2, last paragraph # in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction" # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed = self.to_text_embeds(text_embed) text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
image_embed = self.to_image_embeds(image_embed)
# make text encodings optional # make text encodings optional
# although the paper seems to suggest it is present <-- # although the paper seems to suggest it is present <--
@@ -821,17 +708,16 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is masked or not depends on the classifier free guidance conditional masking # whether text embedding is masked or not depends on the classifier free guidance conditional masking
keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
mask = torch.cat((mask, keep_mask), dim = 1) mask = torch.cat((mask, keep_mask), dim = 1)
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right # but let's just do it right
if exists(mask): if exists(mask):
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.to_time_embeds(diffusion_timesteps) time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -839,7 +725,6 @@ class DiffusionPriorNetwork(nn.Module):
text_encodings, text_encodings,
text_embed, text_embed,
time_embed, time_embed,
image_embed,
learned_queries learned_queries
), dim = -2) ), dim = -2)
@@ -863,16 +748,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
image_size = None, image_size = None,
image_channels = 3, image_channels = 3,
timesteps = 1000, timesteps = 1000,
cond_drop_prob = 0., cond_drop_prob = 0.2,
loss_type = "l1", loss_type = "l1",
predict_x_start = True, predict_x_start = True,
beta_schedule = "cosine", beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, sampling_clamp_l2norm = False,
training_clamp_l2norm = False,
init_image_embed_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
): ):
super().__init__( super().__init__(
beta_schedule = beta_schedule, beta_schedule = beta_schedule,
@@ -882,9 +764,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
if exists(clip): if exists(clip):
if isinstance(clip, CLIP): if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides) clip = XClipAdapter(clip)
elif isinstance(clip, CoCa):
clip = CoCaAdapter(clip, **clip_adapter_overrides)
assert isinstance(clip, BaseClipAdapter) assert isinstance(clip, BaseClipAdapter)
freeze_model_and_make_eval_(clip) freeze_model_and_make_eval_(clip)
@@ -908,8 +788,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
# whether to force an l2norm, similar to clipping denoised, when sampling # whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool): def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond) pred = self.net(x, t, **text_cond)
@@ -944,16 +822,11 @@ class DiffusionPrior(BaseGaussianDiffusion):
device = self.betas.device device = self.betas.device
b = shape[0] b = shape[0]
image_embed = torch.randn(shape, device=device) img = torch.randn(shape, device=device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long) img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond) return img
return image_embed
def p_losses(self, image_embed, times, text_cond, noise = None): def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed)) noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -967,9 +840,6 @@ class DiffusionPrior(BaseGaussianDiffusion):
**text_cond **text_cond
) )
if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale
target = noise if not self.predict_x_start else image_embed target = noise if not self.predict_x_start else image_embed
loss = self.loss_fn(pred, target) loss = self.loss_fn(pred, target)
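The training_clamp_l2norm branch removed just above re-projects the network's x_start prediction back onto the scaled hypersphere that the l2-normed, scaled image embeddings live on, keeping the prediction and the regression target in the same space. A minimal sketch of that clamp; the scale value is a placeholder, since the actual image_embed_scale default is configured elsewhere in the class and not shown in this hunk.
```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # same style of helper the module relies on: unit-normalize along the feature dim
    return F.normalize(t, dim = -1)

image_embed_scale = 1.0            # placeholder; DiffusionPrior sets the real value

pred = torch.randn(4, 768)         # network's prediction of the denoised image embedding
pred = l2norm(pred) * image_embed_scale   # clamp back onto the scaled unit sphere
```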
@@ -1617,8 +1487,7 @@ class Decoder(BaseGaussianDiffusion):
blur_kernel_size = 3, # cascading ddpm - blur kernel size blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True, clip_denoised = True,
clip_x_start = True, clip_x_start = True
clip_adapter_overrides = dict()
): ):
super().__init__( super().__init__(
beta_schedule = beta_schedule, beta_schedule = beta_schedule,
@@ -1631,9 +1500,7 @@ class Decoder(BaseGaussianDiffusion):
self.clip = None self.clip = None
if exists(clip): if exists(clip):
if isinstance(clip, CLIP): if isinstance(clip, CLIP):
clip = XClipAdapter(clip, **clip_adapter_overrides) clip = XClipAdapter(clip)
elif isinstance(clip, CoCa):
clip = CoCaAdapter(clip, **clip_adapter_overrides)
freeze_model_and_make_eval_(clip) freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter) assert isinstance(clip, BaseClipAdapter)
@@ -1960,4 +1827,3 @@ class DALLE2(nn.Module):
return images[0] return images[0]
return images return images

View File

@@ -111,6 +111,11 @@ class DiffusionPriorTrainer(nn.Module):
# exponential moving average # exponential moving average
self.use_ema = use_ema self.use_ema = use_ema
if use_ema:
has_lazy_linear = any([type(module) == nn.LazyLinear for module in diffusion_prior.modules()])
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
if self.use_ema: if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs) self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
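The guard added on the right-hand side above exists, presumably, because the EMA wrapper needs concrete parameter shapes to copy and average, which an uninitialized nn.LazyLinear cannot provide. Running one forward pass materializes the lazy layer (its class literally becomes nn.Linear), after which the check passes. A small illustration of that behaviour, independent of this trainer:
```python
import torch
from torch import nn

model = nn.Sequential(nn.LazyLinear(64), nn.ReLU(), nn.Linear(64, 1))

print(any(type(m) == nn.LazyLinear for m in model.modules()))   # True: lazy layer not materialized yet

model(torch.randn(2, 16))   # a dummy forward infers in_features and converts LazyLinear -> Linear

print(any(type(m) == nn.LazyLinear for m in model.modules()))   # False: safe to wrap in EMA now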

View File

@@ -3,15 +3,14 @@ import copy
from random import choice from random import choice
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
from PIL import Image
import torch import torch
from torch import nn from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T from PIL import Image
from torchvision.datasets import ImageFolder from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid, save_image from torchvision.utils import make_grid, save_image
from einops import rearrange from einops import rearrange
@@ -100,7 +99,6 @@ class VQGanVAETrainer(nn.Module):
ema_update_after_step = 2000, ema_update_after_step = 2000,
ema_update_every = 10, ema_update_every = 10,
apply_grad_penalty_every = 4, apply_grad_penalty_every = 4,
amp = False
): ):
super().__init__() super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE' assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -122,10 +120,6 @@ class VQGanVAETrainer(nn.Module):
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd) self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd) self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.discr_scaler = GradScaler(enabled = amp)
# create dataset # create dataset
self.ds = ImageDataset(folder, image_size = image_size) self.ds = ImageDataset(folder, image_size = image_size)
@@ -184,22 +178,20 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl) img = next(self.dl)
img = img.to(device) img = img.to(device)
with autocast(enabled = self.amp): loss = self.vae(
loss = self.vae( img,
img, return_loss = True,
return_loss = True, apply_grad_penalty = apply_grad_penalty
apply_grad_penalty = apply_grad_penalty )
)
self.scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
self.scaler.step(self.optim) (loss / self.grad_accum_every).backward()
self.scaler.update()
self.optim.step()
self.optim.zero_grad() self.optim.zero_grad()
# update discriminator # update discriminator
if exists(self.vae.discr): if exists(self.vae.discr):
@@ -208,15 +200,12 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl) img = next(self.dl)
img = img.to(device) img = img.to(device)
with autocast(enabled = self.amp): loss = self.vae(img, return_discr_loss = True)
loss = self.vae(img, return_discr_loss = True)
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every}) accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
self.discr_scaler.step(self.discr_optim) (loss / self.grad_accum_every).backward()
self.discr_scaler.update()
self.discr_optim.step()
self.discr_optim.zero_grad() self.discr_optim.zero_grad()
# log # log
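For context on what the autocast / GradScaler lines stripped from this trainer were doing, here is the standard PyTorch mixed-precision update in isolation, with a toy model standing in for self.vae and no gradient accumulation:
```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

use_amp = torch.cuda.is_available()          # AMP is a no-op on CPU with enabled = False

model = nn.Linear(16, 1).to('cuda' if use_amp else 'cpu')
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)
scaler = GradScaler(enabled = use_amp)

x = torch.randn(8, 16, device = next(model.parameters()).device)

with autocast(enabled = use_amp):            # forward pass (and loss) run in mixed precision
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()                # scale the loss to avoid fp16 gradient underflow
scaler.step(optim)                           # unscales gradients, skips the step on inf/nan
scaler.update()
optim.zero_grad()
```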

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.1', version = '0.0.108',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',
@@ -24,14 +24,12 @@ setup(
install_requires=[ install_requires=[
'click', 'click',
'clip-anytorch', 'clip-anytorch',
'coca-pytorch>=0.0.5',
'einops>=0.4', 'einops>=0.4',
'einops-exts>=0.0.3', 'einops-exts>=0.0.3',
'embedding-reader', 'embedding-reader',
'kornia>=0.5.4', 'kornia>=0.5.4',
'pillow', 'pillow',
'resize-right>=0.0.2', 'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10', 'torch>=1.10',
'torchvision', 'torchvision',
'tqdm', 'tqdm',

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from embedding_reader import EmbeddingReader from embedding_reader import EmbeddingReader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, load_diffusion_model, save_diffusion_model from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
from dalle2_pytorch.optimizer import get_optimizer from dalle2_pytorch.optimizer import get_optimizer
from torch.cuda.amp import autocast,GradScaler from torch.cuda.amp import autocast,GradScaler
@@ -41,56 +41,37 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
avg_loss = (total_loss / total_samples) avg_loss = (total_loss / total_samples)
wandb.log({f'{phase} {loss_type}': avg_loss}) wandb.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device): def save_model(save_path, state_dict):
diffusion_prior.eval() # Saving State Dict
print("====================================== Saving checkpoint ======================================")
torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
cos = nn.CosineSimilarity(dim=1, eps=1e-6) cos = nn.CosineSimilarity(dim=1, eps=1e-6)
tstart = train_set_size tstart = train_set_size+val_set_size
tend = train_set_size+NUM_TEST_EMBEDDINGS tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
text_embed = torch.tensor(embt[0]).to(device)
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed = text_embed)
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
return np.mean(predicted_similarity - original_similarity)
for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend),
image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
# make a copy of the text embeddings for shuffling
text_embed = torch.tensor(embt[0]).to(device)
text_embed_shuffled = text_embed.clone()
# roll the text embeddings to simulate "unrelated" captions
rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
text_embed_shuffled = text_embed_shuffled[rolled_idx]
text_embed_shuffled = text_embed_shuffled / \
text_embed_shuffled.norm(dim=1, keepdim=True)
test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
# prepare the text embedding
text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
test_text_cond = dict(text_embed=text_embed)
# prepare image embeddings
test_image_embeddings = torch.tensor(embi[0]).to(device)
test_image_embeddings = test_image_embeddings / \
test_image_embeddings.norm(dim=1, keepdim=True)
# predict on the unshuffled text embeddings
predicted_image_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
predicted_image_embeddings = predicted_image_embeddings / \
predicted_image_embeddings.norm(dim=1, keepdim=True)
# predict on the shuffled embeddings
predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
(NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
# calculate similarities
original_similarity = cos(
text_embed, test_image_embeddings).cpu().numpy()
predicted_similarity = cos(
text_embed, predicted_image_embeddings).cpu().numpy()
unrelated_similarity = cos(
text_embed, predicted_unrelated_embeddings).cpu().numpy()
predicted_img_similarity = cos(
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity),
"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity),
"CosineSimilarity(orig_image_embed,predicted_image_embed)":np.mean(predicted_img_similarity),
"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(unrelated_similarity),
"Cosine similarity difference":np.mean(predicted_similarity - original_similarity)})
def train(image_embed_dim, def train(image_embed_dim,
image_embed_url, image_embed_url,
@@ -112,15 +93,9 @@ def train(image_embed_dim,
save_interval, save_interval,
save_path, save_path,
device, device,
RESUME,
DPRIOR_PATH,
config,
wandb_entity,
wandb_project,
learning_rate=0.001, learning_rate=0.001,
max_grad_norm=0.5, max_grad_norm=0.5,
weight_decay=0.01, weight_decay=0.01,
dropout=0.05,
amp=False): amp=False):
# DiffusionPriorNetwork # DiffusionPriorNetwork
@@ -129,8 +104,6 @@ def train(image_embed_dim,
depth = dpn_depth, depth = dpn_depth,
dim_head = dpn_dim_head, dim_head = dpn_dim_head,
heads = dpn_heads, heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer).to(device) normformer = dp_normformer).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed # DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -143,21 +116,16 @@ def train(image_embed_dim,
loss_type = dp_loss_type, loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings).to(device) condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
# Load pre-trained model from DPRIOR_PATH
if RESUME:
diffusion_prior=load_diffusion_model(DPRIOR_PATH,device)
wandb.init( entity=wandb_entity, project=wandb_project, config=config)
# Create save_path if it doesn't exist
if not os.path.exists(save_path):
os.makedirs(save_path)
# Get image and text embeddings from the servers # Get image and text embeddings from the servers
print("==============Downloading embeddings - image and text====================") print("==============Downloading embeddings - image and text====================")
image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy") image_reader = EmbeddingReader(embeddings_folder=image_embed_url, file_format="npy")
text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy") text_reader = EmbeddingReader(embeddings_folder=text_embed_url, file_format="npy")
num_data_points = text_reader.count num_data_points = text_reader.count
# Create save_path if it doesn't exist
if not os.path.exists(save_path):
os.makedirs(save_path)
### Training code ### ### Training code ###
scaler = GradScaler(enabled=amp) scaler = GradScaler(enabled=amp)
optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate) optimizer = get_optimizer(diffusion_prior.net.parameters(), wd=weight_decay, lr=learning_rate)
@@ -168,15 +136,12 @@ def train(image_embed_dim,
train_set_size = int(train_percent*num_data_points) train_set_size = int(train_percent*num_data_points)
val_set_size = int(val_percent*num_data_points) val_set_size = int(val_percent*num_data_points)
eval_start = train_set_size
for _ in range(epochs): for _ in range(epochs):
diffusion_prior.train()
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size), for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
text_reader(batch_size=batch_size, start=0, end=train_set_size)): text_reader(batch_size=batch_size, start=0, end=train_set_size)):
diffusion_prior.train()
emb_images_tensor = torch.tensor(emb_images[0]).to(device) emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device) emb_text_tensor = torch.tensor(emb_text[0]).to(device)
@@ -191,13 +156,9 @@ def train(image_embed_dim,
if(int(time.time()-t) >= 60*save_interval): if(int(time.time()-t) >= 60*save_interval):
t = time.time() t = time.time()
save_diffusion_model( save_model(
save_path, save_path,
diffusion_prior, dict(model=diffusion_prior.state_dict(), optimizer=optimizer.state_dict(), scaler=scaler.state_dict()))
optimizer,
scaler,
config,
image_embed_dim)
# Log to wandb # Log to wandb
wandb.log({"Training loss": loss.item(), wandb.log({"Training loss": loss.item(),
@@ -207,22 +168,14 @@ def train(image_embed_dim,
# Use NUM_TEST_EMBEDDINGS samples from the test set each time # Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model # Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0: if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior, diff_cosine_sim = report_cosine_sims(diffusion_prior,
image_reader, image_reader,
text_reader, text_reader,
train_set_size, train_set_size,
val_set_size,
NUM_TEST_EMBEDDINGS, NUM_TEST_EMBEDDINGS,
device) device)
### Evaluate model(validation run) ### wandb.log({"Cosine similarity difference": diff_cosine_sim})
eval_model(diffusion_prior,
device,
image_reader,
text_reader,
eval_start,
eval_start+NUM_TEST_EMBEDDINGS,
NUM_TEST_EMBEDDINGS,
dp_loss_type,
phase="Validation")
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm) nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
@@ -231,6 +184,11 @@ def train(image_embed_dim,
scaler.update() scaler.update()
optimizer.zero_grad() optimizer.zero_grad()
### Evaluate model(validation run) ###
start = train_set_size
end=start+val_set_size
eval_model(diffusion_prior,device,image_reader,text_reader,start,end,batch_size,dp_loss_type,phase="Validation")
### Test run ### ### Test run ###
test_set_size = int(test_percent*train_set_size) test_set_size = int(test_percent*train_set_size)
start=train_set_size+val_set_size start=train_set_size+val_set_size
@@ -242,6 +200,7 @@ def main():
# Logging # Logging
parser.add_argument("--wandb-entity", type=str, default="laion") parser.add_argument("--wandb-entity", type=str, default="laion")
parser.add_argument("--wandb-project", type=str, default="diffusion-prior") parser.add_argument("--wandb-project", type=str, default="diffusion-prior")
parser.add_argument("--wandb-name", type=str, default="laion-dprior")
parser.add_argument("--wandb-dataset", type=str, default="LAION-5B") parser.add_argument("--wandb-dataset", type=str, default="LAION-5B")
parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior") parser.add_argument("--wandb-arch", type=str, default="DiffusionPrior")
# URLs for embeddings # URLs for embeddings
@@ -250,7 +209,6 @@ def main():
# Hyperparameters # Hyperparameters
parser.add_argument("--learning-rate", type=float, default=1.1e-4) parser.add_argument("--learning-rate", type=float, default=1.1e-4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2) parser.add_argument("--weight-decay", type=float, default=6.02e-2)
parser.add_argument("--dropout", type=float, default=5e-2)
parser.add_argument("--max-grad-norm", type=float, default=0.5) parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4) parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5) parser.add_argument("--num-epochs", type=int, default=5)
@@ -268,6 +226,7 @@ def main():
# DiffusionPrior(dp) parameters # DiffusionPrior(dp) parameters
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False) parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100) parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-normformer", type=bool, default=False) parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1) parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
parser.add_argument("--dp-loss-type", type=str, default="l2") parser.add_argument("--dp-loss-type", type=str, default="l2")
@@ -276,40 +235,22 @@ def main():
# Model checkpointing interval(minutes) # Model checkpointing interval(minutes)
parser.add_argument("--save-interval", type=int, default=30) parser.add_argument("--save-interval", type=int, default=30)
parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints") parser.add_argument("--save-path", type=str, default="./diffusion_prior_checkpoints")
# Saved model path
parser.add_argument("--pretrained-model-path", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
config = ({"learning_rate": args.learning_rate, print("Setting up wandb logging... Please wait...")
"architecture": args.wandb_arch,
"dataset": args.wandb_dataset,
"weight_decay":args.weight_decay,
"max_gradient_clipping_norm":args.max_grad_norm,
"batch_size":args.batch_size,
"epochs": args.num_epochs,
"diffusion_prior_network":{"depth":args.dpn_depth,
"dim_head":args.dpn_dim_head,
"heads":args.dpn_heads,
"normformer":args.dp_normformer},
"diffusion_prior":{"condition_on_text_encodings": args.dp_condition_on_text_encodings,
"timesteps": args.dp_timesteps,
"cond_drop_prob":args.dp_cond_drop_prob,
"loss_type":args.dp_loss_type,
"clip":args.clip}
})
RESUME = False wandb.init(
# Check if DPRIOR_PATH exists(saved model path) entity=args.wandb_entity,
DPRIOR_PATH = args.pretrained_model_path project=args.wandb_project,
if(DPRIOR_PATH is not None): config={
RESUME = True "learning_rate": args.learning_rate,
else: "architecture": args.wandb_arch,
wandb.init( "dataset": args.wandb_dataset,
entity=args.wandb_entity, "epochs": args.num_epochs,
project=args.wandb_project, })
config=config)
print("wandb logging setup done!")
# Obtain the utilized device. # Obtain the utilized device.
has_cuda = torch.cuda.is_available() has_cuda = torch.cuda.is_available()
@@ -338,15 +279,9 @@ def main():
args.save_interval, args.save_interval,
args.save_path, args.save_path,
device, device,
RESUME,
DPRIOR_PATH,
config,
args.wandb_entity,
args.wandb_project,
args.learning_rate, args.learning_rate,
args.max_grad_norm, args.max_grad_norm,
args.weight_decay, args.weight_decay,
args.dropout,
args.amp) args.amp)
if __name__ == "__main__": if __name__ == "__main__":