Compare commits

...

4 Commits

Author SHA1 Message Date
Phil Wang
6f941a219a give time tokens a surface area of 2 tokens as default, make it so researcher can customize which unet actually is conditioned on image embeddings and/or text encodings 2022-04-20 10:04:47 -07:00
Phil Wang
ddde8ca1bf fix cosine bbeta schedule, thanks to @Zhengxinyang 2022-04-19 20:54:28 -07:00
Phil Wang
c26b77ad20 todo 2022-04-19 13:07:32 -07:00
Phil Wang
c5b4aab8e5 intent 2022-04-19 11:00:05 -07:00
4 changed files with 35 additions and 17 deletions

View File

@@ -14,7 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Install
@@ -410,8 +410,9 @@ Offer training wrappers
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [x] add efficient attention in unet
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] train on a toy task, offer in colab

View File

@@ -6,4 +6,4 @@ def main():
@click.command()
@click.argument('text')
def dream(text):
return image
return 'not ready yet'

View File

@@ -105,8 +105,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, steps, steps)
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
@@ -820,6 +820,7 @@ class Unet(nn.Module):
image_embed_dim,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
@@ -830,6 +831,8 @@ class Unet(nn.Module):
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
cond_on_image_embeds = False,
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -862,8 +865,8 @@ class Unet(nn.Module):
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
nn.Linear(dim * 4, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
self.image_to_cond = nn.Sequential(
@@ -873,6 +876,12 @@ class Unet(nn.Module):
self.text_to_cond = nn.LazyLinear(cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
self.cond_on_text_encodings = cond_on_text_encodings
self.cond_on_image_embeds = cond_on_image_embeds
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
@@ -982,17 +991,22 @@ class Unet(nn.Module):
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_tokens = self.image_to_cond(image_embed)
image_tokens = None
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
# take care of text encodings (optional)
if exists(text_encodings):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
@@ -1002,12 +1016,15 @@ class Unet(nn.Module):
# main conditioning tokens (c)
c = torch.cat((time_tokens, image_tokens), dim = -2)
c = time_tokens
if exists(image_tokens):
c = torch.cat((c, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.25',
version = '0.0.27',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',