mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 17:54:20 +01:00
give time tokens a surface area of 2 tokens by default; make it so the researcher can customize which unet is actually conditioned on image embeddings and/or text encodings
README.md
@@ -410,9 +410,9 @@ Offer training wrappers
 - [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
 - [x] add efficient attention in unet
-- [ ] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
+- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
-- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unet separately)
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
+- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unet separately)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab
dalle2_pytorch/cli.py
@@ -6,4 +6,4 @@ def main():
 @click.command()
 @click.argument('text')
 def dream(text):
-    return image
+    return 'not ready yet'
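The old `return image` referenced a name that was never defined; the command now returns a placeholder string until generation is wired up. A minimal sketch of exercising the stub, assuming the dream command is importable from dalle2_pytorch.cli and click 8.x (whose test runner exposes the callback's return value when standalone mode is disabled):

from click.testing import CliRunner
from dalle2_pytorch.cli import dream  # assumed import path for the command above

runner = CliRunner()

# standalone_mode = False makes click hand back the function's return value
result = runner.invoke(dream, ['a corgi wearing a beret'], standalone_mode = False)
print(result.return_value)  # 'not ready yet'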
dalle2_pytorch/dalle2_pytorch.py
@@ -820,6 +820,7 @@ class Unet(nn.Module):
         image_embed_dim,
         cond_dim = None,
         num_image_tokens = 4,
+        num_time_tokens = 2,
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
@@ -830,6 +831,8 @@ class Unet(nn.Module):
         sparse_attn = False,
         sparse_attn_window = 8, # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
+        cond_on_text_encodings = False,
+        cond_on_image_embeds = False,
     ):
         super().__init__()
         # save locals to take care of some hyperparameters for cascading DDPM
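Together with num_time_tokens above, the two new flags give per-unet control over conditioning, which is the point of this commit. A sketch of how a cascade might be configured, assuming the constructor shown in these hunks (dimension values are illustrative only):

from dalle2_pytorch import Unet

# base unet: conditioned on both the image embedding and the text encodings
unet1 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_time_tokens = 2,             # new default: time conditioning occupies 2 tokens
    cond_on_image_embeds = True,
    cond_on_text_encodings = True
)

# later super-resolution unet: keeps the image embedding but skips text entirely
unet2 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    dim_mults = (1, 2, 4, 8),
    cond_on_image_embeds = True,
    cond_on_text_encodings = False
)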
@@ -862,8 +865,8 @@ class Unet(nn.Module):
             SinusoidalPosEmb(dim),
             nn.Linear(dim, dim * 4),
             nn.GELU(),
-            nn.Linear(dim * 4, cond_dim),
-            Rearrange('b d -> b 1 d')
+            nn.Linear(dim * 4, cond_dim * num_time_tokens),
+            Rearrange('b (r d) -> b r d', r = num_time_tokens)
         )

         self.image_to_cond = nn.Sequential(
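This hunk is where the "surface area of 2 tokens" in the commit message lands: the final projection now emits num_time_tokens * cond_dim features, and the Rearrange splits them into num_time_tokens conditioning tokens of width cond_dim instead of a single token. A self-contained shape check of just those two layers (the sinusoidal embedding and GELU are omitted for brevity, and dim here is an arbitrary stand-in):

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, cond_dim, num_time_tokens, batch = 64, 128, 2, 4

to_time_tokens = nn.Sequential(
    nn.Linear(dim, cond_dim * num_time_tokens),         # (b, r * d)
    Rearrange('b (r d) -> b r d', r = num_time_tokens)  # (b, r, d)
)

time_tokens = to_time_tokens(torch.randn(batch, dim))
print(time_tokens.shape)  # torch.Size([4, 2, 128])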
@@ -873,6 +876,12 @@ class Unet(nn.Module):

         self.text_to_cond = nn.LazyLinear(cond_dim)

+        # finer control over whether to condition on image embeddings and text encodings
+        # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
+
+        self.cond_on_text_encodings = cond_on_text_encodings
+        self.cond_on_image_embeds = cond_on_image_embeds
+
         # for classifier free guidance

         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
@@ -982,6 +991,9 @@ class Unet(nn.Module):
         # mask out image embedding depending on condition dropout
         # for classifier free guidance

+        image_tokens = None
+
+        if self.cond_on_image_embeds:
             image_tokens = self.image_to_cond(image_embed)

             image_tokens = torch.where(
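For context, the torch.where guarded by the new flag implements conditioning dropout for classifier-free guidance: per sample, the image tokens are sometimes replaced wholesale by a learned null embedding. A toy sketch of the broadcasting involved (names mirror the diff; the mask construction here is illustrative, not the repo's exact code):

import torch

batch, num_image_tokens, cond_dim = 4, 4, 128

image_tokens = torch.randn(batch, num_image_tokens, cond_dim)
null_image_embed = torch.randn(1, num_image_tokens, cond_dim)  # an nn.Parameter in the repo

# keep conditioning for roughly 90% of the samples, null it out for the rest
cond_prob_mask = torch.rand(batch, 1, 1) < 0.9

# the (b, 1, 1) mask broadcasts over tokens and channels; the (1, n, d) null broadcasts over the batch
image_tokens = torch.where(cond_prob_mask, image_tokens, null_image_embed)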
@@ -992,7 +1004,9 @@ class Unet(nn.Module):
         # take care of text encodings (optional)

-        if exists(text_encodings):
+        text_tokens = None
+
+        if exists(text_encodings) and self.cond_on_text_encodings:
             text_tokens = self.text_to_cond(text_encodings)

             text_tokens = torch.where(
                 cond_prob_mask,
@@ -1002,12 +1016,15 @@ class Unet(nn.Module):

         # main conditioning tokens (c)

-        c = torch.cat((time_tokens, image_tokens), dim = -2)
+        c = time_tokens
+
+        if exists(image_tokens):
+            c = torch.cat((c, image_tokens), dim = -2)

         # text and image conditioning tokens (mid_c)
         # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet

-        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
+        mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)

         # go through the layers of the unet, down and up
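Putting the forward-pass changes together: time tokens are now the only mandatory conditioning, image tokens are appended to the main token sequence only when cond_on_image_embeds is set, and text tokens only ever reach mid_c, the sequence cross-attended to at the innermost layers. A condensed sketch of the assembly, with the repo's exists() helper inlined:

import torch

def exists(val):
    return val is not None

def build_conditioning(time_tokens, image_tokens = None, text_tokens = None):
    c = time_tokens  # (b, num_time_tokens, cond_dim), always present

    if exists(image_tokens):
        c = torch.cat((c, image_tokens), dim = -2)

    # text tokens join only the tokens used for cross attention at the innermost layers
    mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
    return c, mid_c

# e.g. a super-resolution unet conditioned on image embeds but not on text
c, mid_c = build_conditioning(
    time_tokens  = torch.randn(4, 2, 128),
    image_tokens = torch.randn(4, 4, 128)
)
print(c.shape, mid_c.shape)  # both torch.Size([4, 6, 128])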