Compare commits

...

3 Commits

5 changed files with 231 additions and 12 deletions

View File

@@ -786,6 +786,68 @@ mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
### Diffusion Prior Training
Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
diffusion_prior_trainer = DiffusionPriorTrainer(
diffusion_prior,
lr = 3e-4,
wd = 1e-2,
ema_beta = 0.99,
ema_update_after_step = 1000,
ema_update_every = 10,
)
loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
```
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
@@ -985,4 +1047,14 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@article{Yu2022CoCaCC,
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
journal = {ArXiv},
year = {2022},
volume = {abs/2205.01917}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -23,9 +23,14 @@ from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from resize_right import resize
# rotary embeddings
from rotary_embedding_torch import RotaryEmbedding
# use x-clip
from x_clip import CLIP
from coca_pytorch import CoCa
# helper functions
@@ -113,9 +118,10 @@ EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings', 't
EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
class BaseClipAdapter(nn.Module):
def __init__(self, clip):
def __init__(self, clip, **kwargs):
super().__init__()
self.clip = clip
self.overrides = kwargs
@property
def dim_latent(self):
@@ -173,6 +179,39 @@ class XClipAdapter(BaseClipAdapter):
image_embed = self.clip.to_visual_latent(image_cls)
return EmbeddedImage(l2norm(image_embed), image_encodings)
class CoCaAdapter(BaseClipAdapter):
@property
def dim_latent(self):
return self.clip.dim
@property
def image_size(self):
assert 'image_size' in self.overrides
return self.overrides['image_size']
@property
def image_channels(self):
assert 'image_channels' in self.overrides
return self.overrides['image_channels']
@property
def max_text_len(self):
assert 'max_text_len' in self.overrides
return self.overrides['max_text_len']
@torch.no_grad()
def embed_text(self, text):
text = text[..., :self.max_text_len]
text_mask = text != 0
text_embed, text_encodings = self.clip.embed_text(text)
return EmbeddedText(text_embed, text_encodings, text_mask)
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image_embed, image_encodings = self.clip.embed_image(image)
return EmbeddedImage(image_embed, image_encodings)
class OpenAIClipAdapter(BaseClipAdapter):
def __init__(
self,
@@ -531,7 +570,8 @@ class Attention(nn.Module):
heads = 8,
dropout = 0.,
causal = False,
post_norm = False
post_norm = False,
rotary_emb = None
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -547,6 +587,8 @@ class Attention(nn.Module):
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.rotary_emb = rotary_emb
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias = False),
LayerNorm(dim) if post_norm else nn.Identity()
@@ -559,6 +601,12 @@ class Attention(nn.Module):
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
q = q * self.scale
# rotary embeddings
if exists(self.rotary_emb):
q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
# add null key / value for classifier free guidance in prior net
@@ -566,7 +614,7 @@ class Attention(nn.Module):
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
# calculate query / key similarities
sim = einsum('b h i d, b j d -> b h i j', q, k)
@@ -616,15 +664,18 @@ class CausalTransformer(nn.Module):
attn_dropout = 0.,
ff_dropout = 0.,
final_proj = True,
normformer = False
normformer = False,
rotary_emb = True
):
super().__init__()
self.rel_pos_bias = RelPosBias(heads = heads)
rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer),
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
]))
@@ -755,6 +806,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__(
beta_schedule = beta_schedule,
@@ -764,7 +816,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
if exists(clip):
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
clip = CoCaAdapter(clip, **clip_adapter_overrides)
assert isinstance(clip, BaseClipAdapter)
freeze_model_and_make_eval_(clip)
@@ -845,6 +899,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
loss = self.loss_fn(pred, target)
return loss
@torch.inference_mode()
@eval_decorator
def sample_batch_size(self, batch_size, text_cond):
device = self.betas.device
shape = (batch_size, self.image_embed_dim)
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.inference_mode()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):
@@ -1475,7 +1541,8 @@ class Decoder(BaseGaussianDiffusion):
blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True,
clip_x_start = True
clip_x_start = True,
clip_adapter_overrides = dict()
):
super().__init__(
beta_schedule = beta_schedule,
@@ -1488,7 +1555,9 @@ class Decoder(BaseGaussianDiffusion):
self.clip = None
if exists(clip):
if isinstance(clip, CLIP):
clip = XClipAdapter(clip)
clip = XClipAdapter(clip, **clip_adapter_overrides)
elif isinstance(clip, CoCa):
clip = CoCaAdapter(clip, **clip_adapter_overrides)
freeze_model_and_make_eval_(clip)
assert isinstance(clip, BaseClipAdapter)

View File

@@ -5,7 +5,7 @@ import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
# helper functions
@@ -89,7 +89,83 @@ class EMA(nn.Module):
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# trainers
# diffusion prior trainer
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
**kwargs
)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.use_ema:
self.ema_diffusion_prior.update()
@torch.inference_mode()
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.inference_mode()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
def forward(
self,
*args,
divisor = 1,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor)
# decoder trainer
class DecoderTrainer(nn.Module):
def __init__(

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.107',
version = '0.1.1',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -24,12 +24,14 @@ setup(
install_requires=[
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',
'kornia>=0.5.4',
'pillow',
'resize-right>=0.0.2',
'rotary-embedding-torch',
'torch>=1.10',
'torchvision',
'tqdm',