mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 14:14:21 +01:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
740d644050 |
62
README.md
62
README.md
@@ -786,6 +786,68 @@ mock_image_embed = torch.randn(4, 512).cuda()
|
|||||||
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Diffusion Prior Training
|
||||||
|
|
||||||
|
Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP
|
||||||
|
|
||||||
|
clip = CLIP(
|
||||||
|
dim_text = 512,
|
||||||
|
dim_image = 512,
|
||||||
|
dim_latent = 512,
|
||||||
|
num_text_tokens = 49408,
|
||||||
|
text_enc_depth = 6,
|
||||||
|
text_seq_len = 256,
|
||||||
|
text_heads = 8,
|
||||||
|
visual_enc_depth = 6,
|
||||||
|
visual_image_size = 256,
|
||||||
|
visual_patch_size = 32,
|
||||||
|
visual_heads = 8
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# mock data
|
||||||
|
|
||||||
|
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||||
|
images = torch.randn(4, 3, 256, 256).cuda()
|
||||||
|
|
||||||
|
# prior networks (with transformer)
|
||||||
|
|
||||||
|
prior_network = DiffusionPriorNetwork(
|
||||||
|
dim = 512,
|
||||||
|
depth = 6,
|
||||||
|
dim_head = 64,
|
||||||
|
heads = 8
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
diffusion_prior = DiffusionPrior(
|
||||||
|
net = prior_network,
|
||||||
|
clip = clip,
|
||||||
|
timesteps = 100,
|
||||||
|
cond_drop_prob = 0.2
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
diffusion_prior_trainer = DiffusionPriorTrainer(
|
||||||
|
diffusion_prior,
|
||||||
|
lr = 3e-4,
|
||||||
|
wd = 1e-2,
|
||||||
|
ema_beta = 0.99,
|
||||||
|
ema_update_after_step = 1000,
|
||||||
|
ema_update_every = 10,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = diffusion_prior_trainer(text, images)
|
||||||
|
loss.backward()
|
||||||
|
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||||
|
|
||||||
|
# after running the above three lines many times in a training loop
|
||||||
|
# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
|
||||||
|
|
||||||
|
image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
|
||||||
|
```
|
||||||
|
|
||||||
### Decoder Dataloaders
|
### Decoder Dataloaders
|
||||||
|
|
||||||
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||||
from dalle2_pytorch.train import DecoderTrainer
|
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
|
||||||
|
|
||||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||||
from x_clip import CLIP
|
from x_clip import CLIP
|
||||||
|
|||||||
@@ -845,6 +845,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
loss = self.loss_fn(pred, target)
|
loss = self.loss_fn(pred, target)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
@eval_decorator
|
||||||
|
def sample_batch_size(self, batch_size, text_cond):
    """Run the reverse sampling loop for `batch_size` image embeddings.

    Starts from Gaussian noise of shape `(batch_size, self.image_embed_dim)`
    on the device of `self.betas`, then calls `self.p_sample` once per
    timestep, counting down from `self.num_timesteps - 1` to 0.

    `text_cond` is forwarded unchanged to every `p_sample` call.
    Returns whatever the final `p_sample` call produced.
    """
    device = self.betas.device
    embed = torch.randn((batch_size, self.image_embed_dim), device = device)

    # walk the timesteps backwards: num_timesteps - 1, ..., 1, 0
    countdown = reversed(range(0, self.num_timesteps))

    for step in tqdm(countdown, desc = 'sampling loop time step', total = self.num_timesteps):
        # every element of the batch is at the same timestep
        times = torch.full((batch_size,), step, device = device, dtype = torch.long)
        embed = self.p_sample(embed, times, text_cond = text_cond)

    return embed
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@eval_decorator
|
@eval_decorator
|
||||||
def sample(self, text, num_samples_per_batch = 2):
|
def sample(self, text, num_samples_per_batch = 2):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp import autocast, GradScaler
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
|
|
||||||
from dalle2_pytorch.dalle2_pytorch import Decoder
|
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
|
||||||
from dalle2_pytorch.optimizer import get_optimizer
|
from dalle2_pytorch.optimizer import get_optimizer
|
||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
@@ -89,7 +89,88 @@ class EMA(nn.Module):
|
|||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.ema_model(*args, **kwargs)
|
return self.ema_model(*args, **kwargs)
|
||||||
|
|
||||||
# trainers
|
# diffusion prior trainer
|
||||||
|
|
||||||
|
class DiffusionPriorTrainer(nn.Module):
    """Training wrapper around a `DiffusionPrior`.

    Holds the optimizer, the `GradScaler` used for mixed precision, an
    optional gradient-clipping threshold and, when `use_ema` is true, an
    `EMA` wrapper around the prior.

    Intended loop (as shown by the methods below): call the trainer to get
    a scaled loss, call `.backward()` on it, then call `update()`.
    """

    def __init__(
        self,
        diffusion_prior,
        use_ema = True,
        lr = 3e-4,
        wd = 1e-2,
        max_grad_norm = None,
        amp = False,
        **kwargs
    ):
        """
        Args:
            diffusion_prior: the `DiffusionPrior` instance to train.
            use_ema: whether to build an `EMA` wrapper of the prior.
            lr: forwarded to `get_optimizer`.
            wd: forwarded to `get_optimizer`.
            max_grad_norm: when set, gradients are clipped to this norm
                inside `update()`.
            amp: enables `autocast` in `forward` and enables the `GradScaler`.
            **kwargs: the entries selected by
                `groupby_prefix_and_trim('ema_', kwargs)` go to `EMA`;
                the remaining entries go to `get_optimizer`.
        """
        super().__init__()
        assert isinstance(diffusion_prior, DiffusionPrior)
        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)

        self.diffusion_prior = diffusion_prior

        # exponential moving average

        self.use_ema = use_ema

        if use_ema:
            # NOTE(review): the message below mentions u-nets although this trainer wraps a diffusion prior - confirm wording
            has_lazy_linear = any(isinstance(module, nn.LazyLinear) for module in diffusion_prior.modules())
            assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'

            self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)

        # optimizer and mixed precision stuff

        self.amp = amp

        self.scaler = GradScaler(enabled = amp)

        self.optimizer = get_optimizer(
            diffusion_prior.parameters(),
            lr = lr,
            wd = wd,
            **kwargs
        )

        # gradient clipping if needed

        self.max_grad_norm = max_grad_norm

    def update(self):
        """Step the optimizer through the scaler, zero the gradients, then update the EMA wrapper if enabled."""
        if exists(self.max_grad_norm):
            # gradients are unscaled first so the clipping threshold applies to their true magnitude
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.use_ema:
            self.ema_diffusion_prior.update()

    def _sampling_prior(self):
        """Return the model to sample from: the EMA model when `use_ema` is on, otherwise the prior being trained."""
        # `ema_diffusion_prior` is only created in __init__ when use_ema is true,
        # so it must not be touched otherwise
        if self.use_ema:
            return self.ema_diffusion_prior.ema_model

        return self.diffusion_prior

    @torch.inference_mode()
    def p_sample_loop(self, *args, **kwargs):
        """Forward to `p_sample_loop` of the sampling model (see `_sampling_prior`)."""
        return self._sampling_prior().p_sample_loop(*args, **kwargs)

    @torch.inference_mode()
    def sample(self, *args, **kwargs):
        """Forward to `sample` of the sampling model (see `_sampling_prior`)."""
        return self._sampling_prior().sample(*args, **kwargs)

    @torch.inference_mode()
    def sample_batch_size(self, *args, **kwargs):
        """Forward to `sample_batch_size` of the sampling model (see `_sampling_prior`)."""
        return self._sampling_prior().sample_batch_size(*args, **kwargs)

    def forward(
        self,
        *args,
        divisor = 1,
        **kwargs
    ):
        """Compute the prior's loss under autocast and return it divided by `divisor` and scaled by the `GradScaler`.

        All positional and remaining keyword arguments are passed straight
        to the wrapped `DiffusionPrior`.
        """
        with autocast(enabled = self.amp):
            loss = self.diffusion_prior(*args, **kwargs)
            return self.scaler.scale(loss / divisor)
|
||||||
|
|
||||||
|
# decoder trainer
|
||||||
|
|
||||||
class DecoderTrainer(nn.Module):
|
class DecoderTrainer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user