Compare commits


8 Commits

5 changed files with 68 additions and 19 deletions


@@ -14,6 +14,16 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord
 There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
 
+## Status
+
+- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. They will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases the variety of generations.
+
+- The decoder is now verified working for unconditional generation on my experimental setup with Oxford flowers. Two researchers have also confirmed that the decoder works for them.
+
+<img src="./samples/oxford.png" width="600px" />
+
+*ongoing at 21k steps*
+
 ## Install
 
 ```bash
@@ -814,8 +824,8 @@ clip = CLIP(
 # mock data
 
-text = torch.randint(0, 49408, (32, 256)).cuda()
-images = torch.randn(32, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()
 
 # prior networks (with transformer)
@@ -848,7 +858,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
 # after much of the above three lines in a loop
 # you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
 
-image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
 ```
 
 ## Bonus
@@ -861,7 +871,7 @@ ex.
 ```python
 import torch
-from dalle2_pytorch import Unet, Decoder
+from dalle2_pytorch import Unet, Decoder, DecoderTrainer
 
 # unet for the cascading ddpm
@@ -884,20 +894,24 @@ decoder = Decoder(
     unconditional = True
 ).cuda()
 
-# mock images (get a lot of this)
+# decoder trainer
+
+decoder_trainer = DecoderTrainer(decoder)
+
+# images (get a lot of this)
 
 images = torch.randn(1, 3, 512, 512).cuda()
 
 # feed images into decoder
 
 for i in (1, 2):
-    loss = decoder(images, unet_number = i)
-    loss.backward()
+    loss = decoder_trainer(images, unet_number = i)
+    decoder_trainer.update(unet_number = i)
 
-# do the above for many many many many steps
+# do the above for many many many many images
 # then it will learn to generate images
 
-images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
+images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
 ```
 
 ## Dataloaders


@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
 def get_optimizer(
     params,
-    lr = 2e-5,
+    lr = 1e-4,
     wd = 1e-2,
     betas = (0.9, 0.999),
     eps = 1e-8,
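For context, the new default learning rate pairs with a standard AdamW setup. A minimal sketch, assuming the hyperparameters are forwarded as-is (the hunk header suggests the real function also separates out weight-decayable params via `separate_weight_decayable_params`):

```python
from torch.optim import AdamW

# minimal sketch with the new defaults; assumes the hyperparameters are simply
# forwarded to AdamW (the actual get_optimizer also splits params into
# weight-decayable and non-decayable groups)
def get_optimizer(params, lr = 1e-4, wd = 1e-2, betas = (0.9, 0.999), eps = 1e-8):
    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
```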


@@ -47,6 +47,14 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs
 
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
 # decorators
 
 def cast_torch_tensor(fn):
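For reference, the new `num_to_groups` helper splits a count into full groups of `divisor` plus an optional smaller remainder group:

```python
# num_to_groups exactly as added above
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

print(num_to_groups(36, 4))  # [4, 4, 4, 4, 4, 4, 4, 4, 4]
print(num_to_groups(10, 4))  # [4, 4, 2] - the remainder becomes a final smaller group
```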
@@ -227,6 +235,16 @@ class EMA(nn.Module):
 
 # diffusion prior trainer
 
+def prior_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
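The effect of the new decorator, sketched standalone: when `max_batch_size` is given, the input batch is split into chunks of at most that size, the wrapped sampler runs once per chunk, and the chunk outputs are concatenated along the batch dimension. A minimal illustration with a hypothetical `sample_in_chunks` helper and a stand-in sampler (the library version splits args and kwargs generically via `split_args_and_kwargs`):

```python
import torch

# hypothetical standalone helper mirroring prior_sample_in_chunks
def sample_in_chunks(sample_fn, text, max_batch_size = None):
    if max_batch_size is None:
        return sample_fn(text)
    # split the batch dimension into chunks of at most max_batch_size,
    # sample each chunk, then concatenate the results
    chunks = text.split(max_batch_size, dim = 0)
    return torch.cat([sample_fn(chunk) for chunk in chunks], dim = 0)

fake_sample = lambda t: torch.randn(t.shape[0], 512)  # stand-in for the EMA prior's sample
text = torch.randint(0, 49408, (512, 256))
image_embeds = sample_in_chunks(fake_sample, text, max_batch_size = 4)
print(image_embeds.shape)  # torch.Size([512, 512])
```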
@@ -287,11 +305,13 @@ class DiffusionPriorTrainer(nn.Module):
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@@ -322,15 +342,31 @@ class DiffusionPriorTrainer(nn.Module):
 
 # decoder trainer
 
+def decoder_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        if self.decoder.unconditional:
+            batch_size = kwargs.get('batch_size')
+            batch_sizes = num_to_groups(batch_size, max_batch_size)
+            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
+        else:
+            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
         decoder,
         use_ema = True,
-        lr = 2e-5,
+        lr = 1e-4,
         wd = 1e-2,
         eps = 1e-8,
-        max_grad_norm = None,
+        max_grad_norm = 0.5,
         amp = False,
         **kwargs
     ):
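The unconditional branch above has no input tensor to split, so it partitions the `batch_size` keyword itself via `num_to_groups`. An illustrative sketch (`fake_sample` is a hypothetical stand-in for `decoder.sample`):

```python
import torch

def num_to_groups(num, divisor):
    groups, remainder = divmod(num, divisor)
    return [divisor] * groups + ([remainder] if remainder > 0 else [])

# fake_sample stands in for decoder.sample in the unconditional case:
# it takes only a batch_size, so the keyword itself is what gets partitioned
fake_sample = lambda batch_size: torch.randn(batch_size, 3, 64, 64)

sub_batches = num_to_groups(36, 4)  # [4, 4, 4, 4, 4, 4, 4, 4, 4]
images = torch.cat([fake_sample(b) for b in sub_batches], dim = 0)
print(images.shape)  # torch.Size([36, 3, 64, 64])
```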
@@ -411,18 +447,17 @@ class DecoderTrainer(nn.Module):
 
     @torch.no_grad()
     @cast_torch_tensor
+    @decoder_sample_in_chunks
     def sample(self, *args, **kwargs):
-        if kwargs.pop('use_non_ema', False):
+        if kwargs.pop('use_non_ema', False) or not self.use_ema:
             return self.decoder.sample(*args, **kwargs)
 
-        if self.use_ema:
-            trainable_unets = self.decoder.unets
-            self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+        trainable_unets = self.decoder.unets
+        self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
 
         output = self.decoder.sample(*args, **kwargs)
 
-        if self.use_ema:
-            self.decoder.unets = trainable_unets # restore original training unets
+        self.decoder.unets = trainable_unets # restore original training unets
 
         # cast the ema_model unets back to original device
         for ema in self.ema_unets:
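One note on the simplified method: since it now returns early when EMA is disabled, the swap/restore around sampling runs unconditionally. A hedged sketch of the same pattern written with try/finally, so the training unets are restored even if sampling raises (hypothetical names, not the library's code):

```python
# hedged sketch of the swap-in / swap-out pattern above; `decoder` and
# `ema_unets` are stand-ins, not the library's actual objects
def sample_with_ema(decoder, ema_unets, *args, **kwargs):
    trainable_unets = decoder.unets
    decoder.unets = ema_unets            # swap in EMA unets for sampling
    try:
        return decoder.sample(*args, **kwargs)
    finally:
        decoder.unets = trainable_unets  # always restore the training unets
```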

samples/oxford.png (new binary file, 985 KiB)


@@ -10,7 +10,7 @@ setup(
             'dream = dalle2_pytorch.cli:dream'
         ],
     },
-    version = '0.2.44',
+    version = '0.3.2',
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',