Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-14 18:44:22 +01:00)
Compare commits
11 Commits
| SHA1 |
|---|
| bb86ab2404 |
| ae056dd67c |
| 033d6b0ce8 |
| c7ea8748db |
| 13382885d9 |
| c3d4a7ffe4 |
| 164d9be444 |
| 5562ec6be2 |
| 89ff04cfe2 |
| f4016f6302 |
| 1212f7058d |
README.md (32 changed lines)
@@ -14,6 +14,16 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord

 There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.

+## Status
+
+- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and <a href="https://github.com/crowsonkb">Katherine's</a> own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
+
+- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
+
+<img src="./samples/oxford.png" width="600px" />
+
+*ongoing at 21k steps*
+
 ## Install

 ```bash
@@ -814,8 +824,8 @@ clip = CLIP(
 # mock data

-text = torch.randint(0, 49408, (32, 256)).cuda()
-images = torch.randn(32, 3, 256, 256).cuda()
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()

 # prior networks (with transformer)
@@ -848,7 +858,7 @@ diffusion_prior_trainer.update() # this will update the optimizer as well as th
 # after much of the above three lines in a loop
 # you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior

-image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - exponential moving averaged image embeddings
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
 ```

 ## Bonus
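Passing `max_batch_size` caps how many rows are denoised per forward pass; the full batch is processed in chunks and the results concatenated. Below is a rough, illustrative equivalent of the chunked call above; it assumes the `diffusion_prior_trainer` and the 512-row `text` tensor from the example, and is not the library's actual code path (that lives in the trainer decorators further down in this change set).

```python
# illustration only: manual chunking that mirrors max_batch_size = 4
chunks = [diffusion_prior_trainer.sample(text_chunk) for text_chunk in text.split(4, dim = 0)]
image_embeds = torch.cat(chunks, dim = 0)  # (512, 512)
```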
@@ -861,7 +871,7 @@ ex.

 ```python
 import torch
-from dalle2_pytorch import Unet, Decoder
+from dalle2_pytorch import Unet, Decoder, DecoderTrainer

 # unet for the cascading ddpm

@@ -884,20 +894,24 @@ decoder = Decoder(
 unconditional = True
 ).cuda()

-# mock images (get a lot of this)
+# decoder trainer
+
+decoder_trainer = DecoderTrainer(decoder)
+
+# images (get a lot of this)

 images = torch.randn(1, 3, 512, 512).cuda()

 # feed images into decoder

 for i in (1, 2):
-    loss = decoder(images, unet_number = i)
-    loss.backward()
+    loss = decoder_trainer(images, unet_number = i)
+    decoder_trainer.update(unet_number = i)

-# do the above for many many many many steps
+# do the above for many many many many images
 # then it will learn to generate images

-images = decoder.sample(batch_size = 2) # (2, 3, 512, 512)
+images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
 ```

 ## Dataloaders
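To make the mock example above concrete, here is a minimal, hypothetical training-loop sketch that feeds real images to the `DecoderTrainer`, using the simple image-folder dataloader added later in this change set. The folder path and step count are placeholders, and `decoder_trainer` is the object constructed in the README snippet above.

```python
from dalle2_pytorch.dataloaders.simple_image_only_dataloader import get_images_dataloader

# cycled dataloader over a folder of images (the path is a placeholder)
dataloader = get_images_dataloader('./path/to/images', batch_size = 4, image_size = 512)

for step in range(100000):
    images = next(dataloader).cuda()
    for unet_number in (1, 2):
        loss = decoder_trainer(images, unet_number = unet_number)
        decoder_trainer.update(unet_number = unet_number)

        if step % 100 == 0:
            print(f'step {step} | unet {unet_number} | loss {loss}')
```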
@@ -1870,13 +1870,14 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
         device = self.betas.device

         b = shape[0]
         img = torch.randn(shape, device = device)

-        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+        if not is_latent_diffusion:
+            lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)

         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             img = self.p_sample(
@@ -1896,13 +1897,14 @@ class Decoder(BaseGaussianDiffusion):
         unnormalize_img = unnormalize_zero_to_one(img)
         return unnormalize_img

-    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
+    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))

         # normalize to [-1, 1]

-        x_start = normalize_neg_one_to_one(x_start)
-        lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)
+        if not is_latent_diffusion:
+            x_start = normalize_neg_one_to_one(x_start)
+            lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img)

         # get x_t

@@ -1980,7 +1982,7 @@ class Decoder(BaseGaussianDiffusion):
         batch_size = image_embed.shape[0]

         if exists(text) and not exists(text_encodings) and not self.unconditional:
-            assert exist(self.clip)
+            assert exists(self.clip)
             _, text_encodings, text_mask = self.clip.embed_text(text)

         assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
@@ -2016,7 +2018,8 @@ class Decoder(BaseGaussianDiffusion):
                 predict_x_start = predict_x_start,
                 learned_variance = learned_variance,
                 clip_denoised = not is_latent_diffusion,
-                lowres_cond_img = lowres_cond_img
+                lowres_cond_img = lowres_cond_img,
+                is_latent_diffusion = is_latent_diffusion
             )

             img = vae.decode(img)
@@ -2075,12 +2078,14 @@ class Decoder(BaseGaussianDiffusion):
             image = aug(image)
             lowres_cond_img = aug(lowres_cond_img, params = aug._params)

+        is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
+
         vae.eval()
         with torch.no_grad():
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)

 # main class

dalle2_pytorch/dataloaders/simple_image_only_dataloader.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+from pathlib import Path
+
+import torch
+from torch.utils import data
+from torchvision import transforms, utils
+
+from PIL import Image
+
+# helpers functions
+
+def cycle(dl):
+    while True:
+        for data in dl:
+            yield data
+
+# dataset and dataloader
+
+class Dataset(data.Dataset):
+    def __init__(
+        self,
+        folder,
+        image_size,
+        exts = ['jpg', 'jpeg', 'png']
+    ):
+        super().__init__()
+        self.folder = folder
+        self.image_size = image_size
+        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+
+        self.transform = transforms.Compose([
+            transforms.Resize(image_size),
+            transforms.RandomHorizontalFlip(),
+            transforms.CenterCrop(image_size),
+            transforms.ToTensor()
+        ])
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = self.paths[index]
+        img = Image.open(path)
+        return self.transform(img)
+
+def get_images_dataloader(
+    folder,
+    *,
+    batch_size,
+    image_size,
+    shuffle = True,
+    cycle_dl = True,
+    pin_memory = True
+):
+    ds = Dataset(folder, image_size)
+    dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
+
+    if cycle_dl:
+        dl = cycle(dl)
+    return dl
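A hypothetical usage sketch for the new dataloader. The folder path is a placeholder and the import path simply mirrors the file location above; with `cycle_dl = False` the underlying `DataLoader` is iterated directly instead of pulling batches with `next()`.

```python
from dalle2_pytorch.dataloaders.simple_image_only_dataloader import get_images_dataloader

# plain, non-cycled iteration over one pass of the folder (path is a placeholder)
dl = get_images_dataloader('./oxford-flowers', batch_size = 16, image_size = 256, cycle_dl = False)

for images in dl:
    print(images.shape)  # (16, 3, 256, 256) for RGB inputs, values in [0, 1]
    break
```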
@@ -7,7 +7,7 @@ def separate_weight_decayable_params(params):

 def get_optimizer(
     params,
-    lr = 2e-5,
+    lr = 1e-4,
     wd = 1e-2,
     betas = (0.9, 0.999),
     eps = 1e-8,
@@ -47,6 +47,14 @@ def groupby_prefix_and_trim(prefix, d):
     kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
     return kwargs_without_prefix, kwargs

+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
 # decorators

 def cast_torch_tensor(fn):
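A quick worked example of the new `num_to_groups` helper; the definition is copied from the hunk above so the snippet runs standalone, and the inputs are illustrative values.

```python
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

print(num_to_groups(36, 4))  # [4, 4, 4, 4, 4, 4, 4, 4, 4] -> nine full groups
print(num_to_groups(10, 4))  # [4, 4, 2] -> the remainder becomes a smaller final group
```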
@@ -179,8 +187,8 @@ class EMA(nn.Module):
         self.online_model = model
         self.ema_model = copy.deepcopy(model)

-        self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
         self.update_every = update_every
+        self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0

         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0.]))
@@ -189,14 +197,21 @@ class EMA(nn.Module):
         device = self.initted.device
         self.ema_model.to(device)

+    def copy_params_from_model_to_ema(self):
+        self.ema_model.state_dict(self.online_model.state_dict())
+
     def update(self):
         self.step += 1

-        if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
+        if (self.step % self.update_every) != 0:
+            return
+
+        if self.step <= self.update_after_step:
+            self.copy_params_from_model_to_ema()
             return

         if not self.initted:
-            self.ema_model.state_dict(self.online_model.state_dict())
+            self.copy_params_from_model_to_ema()
             self.initted.data.copy_(torch.Tensor([True]))

         self.update_moving_average(self.ema_model, self.online_model)
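To make the revised gating concrete, here is a small standalone simulation of the arithmetic only; it does not touch the EMA class itself, and `update_after_step = 1000` with `update_every = 10` are assumed example values (stored as `1000 // 10 = 100` by the new `__init__`).

```python
update_every = 10
update_after_step = 1000 // update_every  # stored as 100, per the change above

for step in range(1, 121):
    if (step % update_every) != 0:
        continue  # call skipped entirely
    if step <= update_after_step:
        print(step, 'copy online weights into the EMA model')
    else:
        print(step, 'exponential moving average update')  # first reached at step 110
```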
@@ -220,6 +235,16 @@ class EMA(nn.Module):

 # diffusion prior trainer

+def prior_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
@@ -280,11 +305,13 @@ class DiffusionPriorTrainer(nn.Module):

     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def p_sample_loop(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)

     @torch.no_grad()
     @cast_torch_tensor
+    @prior_sample_in_chunks
     def sample(self, *args, **kwargs):
         return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)

@@ -315,15 +342,31 @@ class DiffusionPriorTrainer(nn.Module):

 # decoder trainer

+def decoder_sample_in_chunks(fn):
+    @wraps(fn)
+    def inner(self, *args, max_batch_size = None, **kwargs):
+        if not exists(max_batch_size):
+            return fn(self, *args, **kwargs)
+
+        if self.decoder.unconditional:
+            batch_size = kwargs.get('batch_size')
+            batch_sizes = num_to_groups(batch_size, max_batch_size)
+            outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
+        else:
+            outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+
+        return torch.cat(outputs, dim = 0)
+    return inner
+
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
         decoder,
         use_ema = True,
-        lr = 2e-5,
+        lr = 1e-4,
         wd = 1e-2,
         eps = 1e-8,
-        max_grad_norm = None,
+        max_grad_norm = 0.5,
         amp = False,
         **kwargs
     ):
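One usage note on the unconditional branch above: `batch_size` is read out of `kwargs`, so it must be passed as a keyword argument for the chunking to apply. A sketch, assuming the `decoder_trainer` from the README example:

```python
# unconditional sampling, internally split into num_to_groups(36, 4) == [4] * 9 sub-batches
images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4)  # (36, 3, 512, 512)
```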
@@ -404,15 +447,17 @@ class DecoderTrainer(nn.Module):

     @torch.no_grad()
     @cast_torch_tensor
+    @decoder_sample_in_chunks
     def sample(self, *args, **kwargs):
-        if self.use_ema:
-            trainable_unets = self.decoder.unets
-            self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+        if kwargs.pop('use_non_ema', False) or not self.use_ema:
+            return self.decoder.sample(*args, **kwargs)
+
+        trainable_unets = self.decoder.unets
+        self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling

         output = self.decoder.sample(*args, **kwargs)

-        if self.use_ema:
-            self.decoder.unets = trainable_unets # restore original training unets
+        self.decoder.unets = trainable_unets # restore original training unets

         # cast the ema_model unets back to original device
         for ema in self.ema_unets:
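The new `use_non_ema` flag bypasses the EMA unet swap, which is handy for comparing online and EMA samples during training. A sketch, assuming a `DecoderTrainer` instance named `decoder_trainer` as in the README:

```python
images_ema    = decoder_trainer.sample(batch_size = 4)                      # EMA unets (default when use_ema = True)
images_online = decoder_trainer.sample(batch_size = 4, use_non_ema = True)  # current training unets
```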
samples/oxford.png (new binary file, 985 KiB; not shown)