Compare commits

...

1 Commit

5 changed files with 159 additions and 4 deletions

README.md
View File

@@ -786,6 +786,68 @@ mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
### Diffusion Prior Training
Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate the optimizer and keep track of an exponentially moving averaged prior.
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data

text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# prior network (with transformer)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

diffusion_prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10
)

loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponentially moving averaged diffusion prior

# after running the three lines above in a loop for many steps
# you can sample from the exponentially moving averaged diffusion prior just as you would from DiffusionPrior

image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponentially moving averaged image embeddings
```
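The forward, `backward`, and `update` calls above constitute one training step and are meant to be repeated in a loop. A minimal sketch of such a loop using the objects defined above (the step count and regenerated mock batches are illustrative assumptions, not part of this change):

```python
# hypothetical training loop; in practice, draw (text, images) batches
# from a real dataloader rather than regenerating mock tensors
num_train_steps = 10000 # assumption: tune for your dataset

for _ in range(num_train_steps):
    text = torch.randint(0, 49408, (4, 256)).cuda()
    images = torch.randn(4, 3, 256, 256).cuda()

    loss = diffusion_prior_trainer(text, images)
    loss.backward()
    diffusion_prior_trainer.update() # optimizer step, then EMA update
```

The `ema_`-prefixed keyword arguments are stripped of their prefix and forwarded to the `EMA` wrapper (see `groupby_prefix_and_trim` in the trainer implementation below).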
### Decoder Dataloaders
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.

dalle2_pytorch/__init__.py
View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

dalle2_pytorch/dalle2_pytorch.py
View File

@@ -845,6 +845,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
        loss = self.loss_fn(pred, target)
        return loss

    @torch.inference_mode()
    @eval_decorator
    def sample_batch_size(self, batch_size, text_cond):
        device = self.betas.device
        shape = (batch_size, self.image_embed_dim)

        img = torch.randn(shape, device = device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
        return img
    @torch.inference_mode()
    @eval_decorator
    def sample(self, text, num_samples_per_batch = 2):
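The new `sample_batch_size` mirrors the reverse-diffusion loop used elsewhere in the class, but starts from pure noise at an explicit batch size rather than deriving the shape from text input. A hypothetical call, assuming `text_cond` is the dict of text conditioning the class passes to `p_sample` (the keys and shapes below are assumptions, not taken from this diff):

```python
# hypothetical usage; text_cond keys/shapes are assumed for illustration
text_cond = dict(text_embed = torch.randn(4, 512).cuda())

image_embeds = diffusion_prior.sample_batch_size(4, text_cond) # (4, 512)
```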

dalle2_pytorch/train.py
View File

@@ -5,7 +5,7 @@ import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer

# helper functions
@@ -89,7 +89,88 @@ class EMA(nn.Module):
    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)

# diffusion prior trainer
class DiffusionPriorTrainer(nn.Module):
    def __init__(
        self,
        diffusion_prior,
        use_ema = True,
        lr = 3e-4,
        wd = 1e-2,
        max_grad_norm = None,
        amp = False,
        **kwargs
    ):
        super().__init__()
        assert isinstance(diffusion_prior, DiffusionPrior)
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.diffusion_prior = diffusion_prior

        # exponential moving average

        self.use_ema = use_ema

        if use_ema:
            has_lazy_linear = any(type(module) == nn.LazyLinear for module in diffusion_prior.modules())
            assert not has_lazy_linear, 'the diffusion prior must not contain any LazyLinear modules (all dimensions must be specified upfront) if you plan on doing automatic exponential moving average'

        if self.use_ema:
            self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)

        # optimizer and mixed precision stuff

        self.amp = amp
        self.scaler = GradScaler(enabled = amp)

        self.optimizer = get_optimizer(
            diffusion_prior.parameters(),
            lr = lr,
            wd = wd,
            **kwargs
        )

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

    def update(self):
        if exists(self.max_grad_norm):
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.use_ema:
            self.ema_diffusion_prior.update()

    @torch.inference_mode()
    def p_sample_loop(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)

    @torch.inference_mode()
    def sample(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)

    @torch.inference_mode()
    def sample_batch_size(self, *args, **kwargs):
        return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)

    def forward(
        self,
        *args,
        divisor = 1, # divides the loss, e.g. for averaging over gradient accumulation steps
        **kwargs
    ):
        with autocast(enabled = self.amp):
            loss = self.diffusion_prior(*args, **kwargs)
        return self.scaler.scale(loss / divisor)
# decoder trainer
class DecoderTrainer(nn.Module):
    def __init__(
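The `divisor` argument on `forward` scales the loss before it is returned, which pairs naturally with gradient accumulation. A minimal sketch under assumed names (`batches` and `accum_steps` are illustrative, not part of this change):

```python
# hypothetical gradient accumulation with the new trainer
accum_steps = 4

for i, (text, images) in enumerate(batches):
    loss = diffusion_prior_trainer(text, images, divisor = accum_steps)
    loss.backward() # gradients accumulate across the accum_steps sub-batches

    if (i + 1) % accum_steps == 0:
        diffusion_prior_trainer.update() # clip, step, zero grads, update EMA
```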

setup.py
View File

@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
  version = '0.0.108',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',