Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 11:34:29 +01:00
Compare commits

10 Commits

- 1cce4225eb
- 5ab0700bab
- b0f2fbaa95
- 51361c2d15
- 42d6e47387
- 1e939153fb
- 1abeb8918e
- b423855483
- c400d8758c
- bece206699

README.md (18 changed lines)
@@ -2,7 +2,9 @@
 ## DALL-E 2 - Pytorch (wip)
 
-Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch. <a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a>
+Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
+
+<a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a> | <a href="https://www.youtube.com/watch?v=F1X4fHzF4mQ">AssemblyAI explainer</a>
 
 The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)
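The two-stage indirection described in that context line is easy to see as data flow. Below is a minimal sketch with stub functions standing in for CLIP, the prior, and the decoder; the embedding width (512), token shape, vocabulary size, and output resolution are illustrative assumptions, not the repository's API.

```python
import torch

def clip_text_embed(tokens):
    # stand-in for CLIP's text encoder: token ids -> (batch, dim) embedding
    return torch.randn(tokens.shape[0], 512)

def diffusion_prior(text_embed):
    # stand-in for the prior: text embedding -> image embedding (the extra indirection)
    return torch.randn_like(text_embed)

def decoder(image_embed):
    # stand-in for the DDPM decoder: image embedding -> pixels
    return torch.randn(image_embed.shape[0], 3, 256, 256)

tokens = torch.randint(0, 49408, (1, 77))  # CLIP-style token ids (hypothetical vocab size)
image  = decoder(diffusion_prior(clip_text_embed(tokens)))
print(image.shape)  # torch.Size([1, 3, 256, 256])
```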
@@ -12,9 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
 Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
 
-Do let me know if anyone is interested in a Jax version https://github.com/lucidrains/DALLE2-pytorch/discussions/8
-
-For all of you emailing me (there is a lot), the best way to contribute is through pull requests. Everything is open sourced after all. All my thoughts are public. This is your moment to participate.
+There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
 
 ## Install
@@ -320,12 +320,12 @@ Offer training wrappers
 - [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
 - [x] make sure it works end to end to produce an output tensor, taking a single gradient step
 - [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much of a difference)
 - [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
-- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
+- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
 - [ ] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
 - [ ] train on a toy task, offer in colab
-- [ ] add attention to unet - apply some personal tricks with efficient attention
-- [ ] figure out the big idea behind latent diffusion and what can be ported over
-- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007
+- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
+- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
+- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
 
 ## Citations
dalle2_pytorch/dalle2_pytorch.py

@@ -11,7 +11,7 @@ from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
 
-from kornia.filters import filter2d
+from kornia.filters.gaussian import GaussianBlur2d
 
 from dalle2_pytorch.tokenizer import tokenizer
@@ -98,6 +98,29 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)
 
+def linear_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    return torch.linspace(beta_start, beta_end, timesteps)
+
+def quadratic_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
+
+def sigmoid_beta_schedule(timesteps):
+    scale = 1000 / timesteps
+    beta_start = scale * 0.0001
+    beta_end = scale * 0.02
+    betas = torch.linspace(-6, 6, timesteps)
+    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
 # diffusion prior
 
 class RMSNorm(nn.Module):
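As a quick sanity check, the three added schedules can be exercised standalone. The snippet below repeats the definitions from the hunk above so it runs on its own, and verifies that every schedule produces valid, monotonically non-decreasing noise variances.

```python
import torch

# the schedules from the hunk above, repeated so this snippet is self-contained

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps)

def quadratic_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start, beta_end = scale * 0.0001, scale * 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start, beta_end = scale * 0.0001, scale * 0.02
    return torch.sigmoid(torch.linspace(-6, 6, timesteps)) * (beta_end - beta_start) + beta_start

for name, fn in (('linear', linear_beta_schedule), ('quadratic', quadratic_beta_schedule), ('sigmoid', sigmoid_beta_schedule)):
    betas = fn(1000)
    assert (betas > 0).all() and (betas < 1).all()  # valid variances
    assert (betas[1:] >= betas[:-1]).all()          # monotonically non-decreasing
    print(f'{name:9s} first = {betas[0]:.6f}, last = {betas[-1]:.6f}')
```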
@@ -427,10 +450,11 @@ class DiffusionPrior(nn.Module):
         net,
         *,
         clip,
-        timesteps = 1000,
-        cond_drop_prob = 0.2,
-        loss_type = 'l1',
-        predict_x0 = True
+        timesteps=1000,
+        cond_drop_prob=0.2,
+        loss_type="l1",
+        predict_x0=True,
+        beta_schedule="cosine",
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -446,11 +470,22 @@ class DiffusionPrior(nn.Module):
         self.predict_x0 = predict_x0 # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
-        betas = cosine_beta_schedule(timesteps)
+        if beta_schedule == "cosine":
+            betas = cosine_beta_schedule(timesteps)
+        elif beta_schedule == "linear":
+            betas = linear_beta_schedule(timesteps)
+        elif beta_schedule == "quadratic":
+            betas = quadratic_beta_schedule(timesteps)
+        elif beta_schedule == "jsd":
+            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
+        elif beta_schedule == "sigmoid":
+            betas = sigmoid_beta_schedule(timesteps)
+        else:
+            raise NotImplementedError()
 
         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
 
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
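The second change in this hunk fixes the padding direction. `alphas_cumprod_prev[t]` is meant to hold `alphas_cumprod[t - 1]`, with a cumulative product of 1 at the notional step before t = 0, so the constant must be prepended; `F.pad` takes `(left, right)` padding on the last dimension, making `(1, 0)` the correct choice. A tiny check on toy values:

```python
import torch
import torch.nn.functional as F

alphas_cumprod = torch.tensor([0.9, 0.8, 0.7])

appended  = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)  # old: [0.9, 0.8, 1.0], wrong position
prepended = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)  # new: [1.0, 0.9, 0.8]

# alphas_cumprod_prev[t] must equal alphas_cumprod[t - 1], with 1. at t = 0
assert torch.equal(prepended, torch.tensor([1.0, 0.9, 0.8]))
```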
@@ -601,6 +636,8 @@ class DiffusionPrior(nn.Module):
             loss = F.l1_loss(to_predict, x_recon)
         elif self.loss_type == 'l2':
             loss = F.mse_loss(to_predict, x_recon)
+        elif self.loss_type == "huber":
+            loss = F.smooth_l1_loss(to_predict, x_recon)
         else:
             raise NotImplementedError()
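The new `"huber"` option maps to `F.smooth_l1_loss`, which is quadratic for small residuals and linear for large ones (with Pytorch's default threshold of 1), so it penalizes outlier predictions less harshly than pure L2. A minimal illustration of the two regimes:

```python
import torch
import torch.nn.functional as F

pred   = torch.zeros(2)
target = torch.tensor([0.5, 10.0])  # one small, one large residual

losses = F.smooth_l1_loss(pred, target, reduction = 'none')

# quadratic branch for |x| < 1: 0.5 * x ** 2 ; linear branch otherwise: |x| - 0.5
assert torch.allclose(losses, torch.tensor([0.125, 9.5]))
```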
@@ -625,17 +662,6 @@ def Upsample(dim):
 def Downsample(dim):
     return nn.Conv2d(dim, dim, 4, 2, 1)
 
-class Blur(nn.Module):
-    def __init__(self):
-        super().__init__()
-        filt = torch.Tensor([1, 2, 1])
-        self.register_buffer('filt', filt)
-
-    def forward(self, x):
-        filt = self.filt
-        filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
-        return filter2d(x, filt, normalized = True)
-
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
@@ -769,11 +795,25 @@ class Unet(nn.Module):
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
+        lowres_cond_upsample_mode = 'bilinear',
+        blur_sigma = 0.1
     ):
         super().__init__()
+
+        # for eventual cascading diffusion
+
+        self.lowres_cond = lowres_cond
+        self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
+        self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
+
         # determine dimensions
 
         self.channels = channels
 
-        dims = [channels, *map(lambda m: dim * m, dim_mults)]
+        init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
+
+        dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
         in_out = list(zip(dims[:-1], dims[1:]))
 
         # time, image embeddings, and optional text encoding
@@ -856,12 +896,30 @@ class Unet(nn.Module):
         time,
         *,
         image_embed,
+        lowres_cond_img = None,
         text_encodings = None,
         cond_drop_prob = 0.
     ):
         batch_size, device = x.shape[0], x.device
+
+        # add low resolution conditioning, if present
+
+        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
+
+        if exists(lowres_cond_img):
+            if self.training:
+                # when training, blur the low resolution conditional image
+                lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
+
+            lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode)
+            x = torch.cat((x, lowres_cond_img), dim = 1)
+
         # time conditioning
 
         time_tokens = self.time_mlp(time)
 
         # conditional dropout
 
         cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
         cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
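The conditioning path added here can be exercised standalone. The sketch below reproduces the blur, upsample, and concatenate sequence on dummy tensors; the shapes are illustrative assumptions, while the 3x3 kernel and 0.1 sigma match the values hardcoded in `__init__` above.

```python
import torch
import torch.nn.functional as F
from kornia.filters.gaussian import GaussianBlur2d

blur = GaussianBlur2d((3, 3), (0.1, 0.1))        # kernel size hardcoded above, sigma = blur_sigma default

x               = torch.randn(1, 3, 128, 128)    # noisy image at the current (higher) resolution
lowres_cond_img = torch.randn(1, 3, 64, 64)      # sample from the previous unet in the cascade

lowres_cond_img = blur(lowres_cond_img)          # train-time blur, per the blur augmentation trick
lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = 'bilinear')
x = torch.cat((x, lowres_cond_img), dim = 1)     # channels double, hence init_channels = channels * 2

print(x.shape)  # torch.Size([1, 6, 128, 128])
```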
@@ -923,9 +981,10 @@ class Decoder(nn.Module):
         net,
         *,
         clip,
-        timesteps = 1000,
-        cond_drop_prob = 0.2,
-        loss_type = 'l1'
+        timesteps=1000,
+        cond_drop_prob=0.2,
+        loss_type="l1",
+        beta_schedule="cosine",
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -937,11 +996,22 @@ class Decoder(nn.Module):
         self.image_size = clip.image_size
         self.cond_drop_prob = cond_drop_prob
 
-        betas = cosine_beta_schedule(timesteps)
+        if beta_schedule == "cosine":
+            betas = cosine_beta_schedule(timesteps)
+        elif beta_schedule == "linear":
+            betas = linear_beta_schedule(timesteps)
+        elif beta_schedule == "quadratic":
+            betas = quadratic_beta_schedule(timesteps)
+        elif beta_schedule == "jsd":
+            betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
+        elif beta_schedule == "sigmoid":
+            betas = sigmoid_beta_schedule(timesteps)
+        else:
+            raise NotImplementedError()
 
         alphas = 1. - betas
         alphas_cumprod = torch.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
+        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
 
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
@@ -1066,6 +1136,8 @@ class Decoder(nn.Module):
             loss = F.l1_loss(noise, x_recon)
         elif self.loss_type == 'l2':
             loss = F.mse_loss(noise, x_recon)
+        elif self.loss_type == "huber":
+            loss = F.smooth_l1_loss(noise, x_recon)
         else:
             raise NotImplementedError()