Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-12 11:34:29 +01:00)
Compare commits
9 Commits
| SHA1 |
|---|
| db0642c4cd |
| bb86ab2404 |
| ae056dd67c |
| 033d6b0ce8 |
| c7ea8748db |
| 13382885d9 |
| c3d4a7ffe4 |
| 164d9be444 |
| 5562ec6be2 |
README.md: 32 changed lines
@@ -14,6 +14,16 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord
 
 There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
 
+## Status
+
+- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
+
+- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
+
+<img src="./samples/oxford.png" width="600px" />
+
+*ongoing at 21k steps*
+
 ## Install
 
 ```bash
@@ -814,8 +824,8 @@ clip = CLIP(
 
 # mock data
 
-text = torch.randint(0, 49408, (32, 256)).cuda()
-images = torch.randn(32, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()
 
 # prior networks (with transformer)
 
@@ -848,7 +858,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
 # after much of the above three lines in a loop
 # you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
 
-image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
 ```
 
 ## Bonus
@@ -861,7 +871,7 @@ ex.
 
 ```python
 import torch
-from dalle2_pytorch import Unet, Decoder
+from dalle2_pytorch import Unet, Decoder, DecoderTrainer
 
 # unet for the cascading ddpm
 
@@ -884,20 +894,24 @@ decoder = Decoder(
     unconditional = True
 ).cuda()
 
-# mock images (get a lot of this)
+# decoder trainer
+
+decoder_trainer = DecoderTrainer(decoder)
+
+# images (get a lot of this)
 
 images = torch.randn(1, 3, 512, 512).cuda()
 
 # feed images into decoder
 
 for i in (1, 2):
-    loss = decoder(images, unet_number = i)
-    loss.backward()
+    loss = decoder_trainer(images, unet_number = i)
+    decoder_trainer.update(unet_number = i)
 
-# do the above for many many many many steps
+# do the above for many many many many images
 # then it will learn to generate images
 
-images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
+images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
 ```
 
 ## Dataloaders
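The hunk above swaps the bare `Decoder` training loop for the `DecoderTrainer` wrapper. As a companion to the mock-data snippet, here is a rough sketch of the same calls driven by a dataloader; the `DataLoader` and the random `TensorDataset` are illustrative assumptions, and only the `decoder_trainer(...)`, `decoder_trainer.update(...)` and `decoder_trainer.sample(...)` calls come from the README change itself.

```python
# Illustrative training loop around the DecoderTrainer API shown above.
# The random TensorDataset stands in for a real image dataset and is not repository code.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 3, 512, 512))  # random stand-in images
dataloader = DataLoader(dataset, batch_size = 4, shuffle = True)

for (images,) in dataloader:
    images = images.cuda()
    for unet_number in (1, 2):
        loss = decoder_trainer(images, unet_number = unet_number)  # forward pass, returns the loss for that unet
        decoder_trainer.update(unet_number = unet_number)          # optimizer step for that unet

# sample in chunks of 4 to bound memory
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4)  # (36, 3, 512, 512)
```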
@@ -1697,7 +1697,8 @@ class Decoder(BaseGaussianDiffusion):
         clip_adapter_overrides = dict(),
         learned_variance = True,
         vb_loss_weight = 0.001,
-        unconditional = False
+        unconditional = False,
+        auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
     ):
         super().__init__(
             beta_schedule = beta_schedule,
@@ -1806,6 +1807,10 @@ class Decoder(BaseGaussianDiffusion):
         self.clip_denoised = clip_denoised
         self.clip_x_start = clip_x_start
 
+        # normalize and unnormalize image functions
+
+        self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
+        self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
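The two helpers chosen here map pixel ranges between [0, 1] and [-1, 1], as the `auto_normalize_img` comment in the earlier hunk describes. Their definitions are not part of this diff; a minimal sketch of what they are assumed to compute:

```python
# Assumed definitions, for orientation only (the actual helpers live elsewhere in the repository)
def normalize_neg_one_to_one(img):
    return img * 2 - 1             # [0, 1] -> [-1, 1]

def unnormalize_zero_to_one(normed_img):
    return (normed_img + 1) * 0.5  # [-1, 1] -> [0, 1]

def identity(t, *args, **kwargs):
    return t                       # used when auto_normalize_img = False
```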
@@ -1877,7 +1882,7 @@ class Decoder(BaseGaussianDiffusion):
         img = torch.randn(shape, device = device)
 
         if not is_latent_diffusion:
-            lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
@@ -1894,7 +1899,7 @@ class Decoder(BaseGaussianDiffusion):
                 clip_denoised = clip_denoised
             )
 
-        unnormalize_img = unnormalize_zero_to_one(img)
+        unnormalize_img = self.unnormalize_img(img)
         return unnormalize_img
 
     def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
@@ -1903,8 +1908,8 @@ class Decoder(BaseGaussianDiffusion):
         # normalize to [-1, 1]
 
         if not is_latent_diffusion:
-            x_start = normalize_neg_one_to_one(x_start)
-            lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+            x_start = self.normalize_img(x_start)
+            lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
         # get x_t
 
@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):
 
 def get_optimizer(
     params,
-    lr = 2e-5,
+    lr = 1e-4,
     wd = 1e-2,
     betas = (0.9, 0.999),
     eps = 1e-8,
@@ -47,6 +47,14 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs
 
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
 # decorators
 
 def cast_torch_tensor(fn):
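`num_to_groups` is what later lets the decoder trainer split a requested sample batch into bounded sub-batches. A few worked examples of the helper exactly as added above:

```python
num_to_groups(36, 4)  # [4, 4, 4, 4, 4, 4, 4, 4, 4]  -> nine full groups
num_to_groups(10, 4)  # [4, 4, 2]                    -> remainder becomes a final smaller group
num_to_groups(3, 4)   # [3]                          -> a single partial group
```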
@@ -227,6 +235,16 @@ class EMA(nn.Module):
 
 # diffusion prior trainer
 
+def prior_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
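`prior_sample_in_chunks` defers the actual splitting to `split_args_and_kwargs`, which presumably chops every batched argument into `max_batch_size`-sized pieces. A self-contained toy version of the same pattern, reduced to a single tensor argument, may make the control flow easier to follow (illustrative only, not repository code):

```python
import torch
from functools import wraps

def sample_in_chunks(fn):
    # split the leading batch dimension into sub-batches, call fn per chunk, concatenate the results
    @wraps(fn)
    def inner(x, max_batch_size = None):
        if max_batch_size is None:
            return fn(x)
        chunks = x.split(max_batch_size, dim = 0)
        return torch.cat([fn(chunk) for chunk in chunks], dim = 0)
    return inner

@sample_in_chunks
def fake_sample(x):
    return x * 2  # stand-in for an expensive sampling call

out = fake_sample(torch.randn(512, 512), max_batch_size = 4)  # 128 sub-calls of batch size 4, output (512, 512)
```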
@@ -287,11 +305,13 @@ class DiffusionPriorTrainer(nn.Module):
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
 
     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
 
@@ -322,15 +342,31 @@ class DiffusionPriorTrainer(nn.Module):
 
 # decoder trainer
 
+def decoder_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        if self.decoder.unconditional:
+            batch_size = kwargs.get('batch_size')
+            batch_sizes = num_to_groups(batch_size, max_batch_size)
+            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
+        else:
+            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
         decoder,
         use_ema = True,
-        lr = 2e-5,
+        lr = 1e-4,
         wd = 1e-2,
         eps = 1e-8,
-        max_grad_norm = None,
+        max_grad_norm = 0.5,
         amp = False,
         **kwargs
     ):
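For an unconditional decoder there is no batched tensor argument to split, so the decorator above chunks the requested `batch_size` instead via `num_to_groups`. Tying this back to the README example earlier in this diff:

```python
# num_to_groups(36, 4) == [4] * 9, so this issues nine sub-calls of batch size 4
# and concatenates their outputs along the batch dimension.
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4)  # (36, 3, 512, 512)
```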
@@ -411,18 +447,17 @@ class DecoderTrainer(nn.Module):
 
     @torch.no_grad()
     @cast_torch_tensor
+    @decoder_sample_in_chunks
     def sample(self, *args, **kwargs):
-        if kwargs.pop('use_non_ema', False):
+        if kwargs.pop('use_non_ema', False) or not self.use_ema:
             return self.decoder.sample(*args, **kwargs)
 
-        if self.use_ema:
-            trainable_unets = self.decoder.unets
-            self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+        trainable_unets = self.decoder.unets
+        self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
 
         output = self.decoder.sample(*args, **kwargs)
 
-        if self.use_ema:
-            self.decoder.unets = trainable_unets # restore original training unets
+        self.decoder.unets = trainable_unets # restore original training unets
 
         # cast the ema_model unets back to original device
         for ema in self.ema_unets:
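With the rewritten guard above, the EMA unets are swapped in whenever the trainer was built with `use_ema = True`, while `use_non_ema = True` (or constructing the trainer with `use_ema = False`) samples straight from the trainable unets. Illustrative call sites, assuming the trainer from the README example:

```python
images = decoder_trainer.sample(batch_size = 4)                      # uses the exponential moving averaged unets
images = decoder_trainer.sample(batch_size = 4, use_non_ema = True)  # skips the EMA swap, uses the training unets
```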
samples/oxford.png (new binary file, 985 KiB; binary content not shown)