mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f3c02dba8 | ||
|
|
908088cfea | ||
|
|
8dc8a3de0d | ||
|
|
35f89556ba | ||
|
|
2b55f753b9 | ||
|
|
fc8fce38fb | ||
|
|
a1bfb03ba4 |
19
README.md
19
README.md
@@ -1000,14 +1000,15 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
|
||||
- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
|
||||
- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
|
||||
- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [x] cross embed layers for downsampling, as an option
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
|
||||
- [ ] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
|
||||
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
|
||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
|
||||
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
|
||||
@@ -1015,6 +1016,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||
- [ ] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
|
||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||
- [ ] decoder needs one day worth of refactor for tech debt
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -1092,4 +1094,15 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{wang2021crossformer,
|
||||
title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
|
||||
author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
|
||||
year = {2021},
|
||||
eprint = {2108.00154},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||
|
||||
@@ -303,7 +303,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, timesteps, steps)
|
||||
x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
@@ -314,21 +314,21 @@ def linear_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start, beta_end, timesteps)
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
|
||||
|
||||
|
||||
def quadratic_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
|
||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
||||
|
||||
|
||||
def sigmoid_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
betas = torch.linspace(-6, 6, timesteps)
|
||||
betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
|
||||
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
||||
|
||||
|
||||
@@ -368,17 +368,21 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
self.loss_type = loss_type
|
||||
self.loss_fn = loss_fn
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
# register buffer helper function to cast double back to float
|
||||
|
||||
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
||||
|
||||
register_buffer('betas', betas)
|
||||
register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
@@ -386,13 +390,13 @@ class BaseGaussianDiffusion(nn.Module):
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
@@ -1228,6 +1232,33 @@ class LinearAttention(nn.Module):
|
||||
out = self.nonlin(out)
|
||||
return self.to_out(out)
|
||||
|
||||
class CrossEmbedLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
kernel_sizes,
|
||||
dim_out = None,
|
||||
stride = 2
|
||||
):
|
||||
super().__init__()
|
||||
assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
|
||||
dim_out = default(dim_out, dim_in)
|
||||
|
||||
kernel_sizes = sorted(kernel_sizes)
|
||||
num_scales = len(kernel_sizes)
|
||||
|
||||
# calculate the dimension at each scale
|
||||
dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
|
||||
dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
|
||||
|
||||
self.convs = nn.ModuleList([])
|
||||
for kernel, dim_scale in zip(kernel_sizes, dim_scales):
|
||||
self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
|
||||
|
||||
def forward(self, x):
|
||||
fmaps = tuple(map(lambda conv: conv(x), self.convs))
|
||||
return torch.cat(fmaps, dim = 1)
|
||||
|
||||
class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1252,6 +1283,9 @@ class Unet(nn.Module):
|
||||
init_dim = None,
|
||||
init_conv_kernel_size = 7,
|
||||
resnet_groups = 8,
|
||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||
cross_embed_downsample = False,
|
||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
@@ -1270,10 +1304,9 @@ class Unet(nn.Module):
|
||||
self.channels = channels
|
||||
|
||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||
init_dim = default(init_dim, dim // 2)
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
|
||||
assert (init_conv_kernel_size % 2) == 1
|
||||
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
|
||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
@@ -1333,6 +1366,12 @@ class Unet(nn.Module):
|
||||
|
||||
assert len(resnet_groups) == len(in_out)
|
||||
|
||||
# downsample klass
|
||||
|
||||
downsample_klass = Downsample
|
||||
if cross_embed_downsample:
|
||||
downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
@@ -1348,7 +1387,7 @@ class Unet(nn.Module):
|
||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
|
||||
mid_dim = dims[-1]
|
||||
@@ -1382,12 +1421,13 @@ class Unet(nn.Module):
|
||||
*,
|
||||
lowres_cond,
|
||||
channels,
|
||||
cond_on_image_embeds
|
||||
cond_on_image_embeds,
|
||||
cond_on_text_encodings
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
|
||||
return self
|
||||
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
|
||||
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
|
||||
return self.__class__(**{**self._locals, **updated_kwargs})
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -1583,7 +1623,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
clip_denoised = True,
|
||||
clip_x_start = True,
|
||||
clip_adapter_overrides = dict()
|
||||
clip_adapter_overrides = dict(),
|
||||
unconditional = False
|
||||
):
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
@@ -1591,6 +1632,9 @@ class Decoder(BaseGaussianDiffusion):
|
||||
loss_type = loss_type
|
||||
)
|
||||
|
||||
self.unconditional = unconditional
|
||||
assert not (condition_on_text_encodings and unconditional), 'unconditional decoder image generation cannot be set to True if conditioning on text is present'
|
||||
|
||||
assert exists(clip) ^ exists(image_size), 'either CLIP is supplied, or you must give the image_size and channels (usually 3 for RGB)'
|
||||
|
||||
self.clip = None
|
||||
@@ -1632,7 +1676,8 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
cond_on_image_embeds = is_first,
|
||||
cond_on_image_embeds = is_first and not unconditional,
|
||||
cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
|
||||
channels = unet_channels
|
||||
)
|
||||
|
||||
@@ -1767,12 +1812,16 @@ class Decoder(BaseGaussianDiffusion):
|
||||
@eval_decorator
|
||||
def sample(
|
||||
self,
|
||||
image_embed,
|
||||
image_embed = None,
|
||||
text = None,
|
||||
batch_size = 1,
|
||||
cond_scale = 1.,
|
||||
stop_at_unet_number = None
|
||||
):
|
||||
batch_size = image_embed.shape[0]
|
||||
assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'
|
||||
|
||||
if not self.unconditional:
|
||||
batch_size = image_embed.shape[0]
|
||||
|
||||
text_encodings = text_mask = None
|
||||
if exists(text):
|
||||
@@ -1782,10 +1831,11 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||
|
||||
img = None
|
||||
is_cuda = next(self.parameters()).is_cuda
|
||||
|
||||
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
|
||||
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||
context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
|
||||
|
||||
with context:
|
||||
lowres_cond_img = None
|
||||
|
||||
3
setup.py
3
setup.py
@@ -10,11 +10,12 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.2.5',
|
||||
version = '0.2.9',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
author_email = 'lucidrains@gmail.com',
|
||||
long_description_content_type = 'text/markdown',
|
||||
url = 'https://github.com/lucidrains/dalle2-pytorch',
|
||||
keywords = [
|
||||
'artificial intelligence',
|
||||
|
||||
Reference in New Issue
Block a user