Compare commits

...

4 Commits

Author    | SHA1       | Message | Date
Phil Wang | faebf4c8b8 | from my vision transformer experience, dimension of attention head of 32 is sufficient for image feature maps | 2022-04-20 11:40:32 -07:00
Phil Wang | b8e8d3c164 | thoughts | 2022-04-20 11:34:51 -07:00
Phil Wang | 8e2416b49b | commit to generalizing latent diffusion to one model | 2022-04-20 11:27:42 -07:00
Phil Wang | f37c26e856 | cleanup and DRY a little | 2022-04-20 10:56:32 -07:00
3 changed files with 33 additions and 23 deletions

README.md

@@ -412,7 +412,7 @@ Offer training wrappers
 - [x] add efficient attention in unet
 - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
 - [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
-- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
+- [ ] build out latent diffusion architecture, make it completely optional (additional autoencoder + some regularizations [kl and vq regs]) (figure out if latent diffusion + cascading ddpm can be used in conjunction)
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab

dalle2_pytorch/dalle2_pytorch.py

@@ -464,11 +464,11 @@ class DiffusionPrior(nn.Module):
         net,
         *,
         clip,
-        timesteps=1000,
-        cond_drop_prob=0.2,
-        loss_type="l1",
-        predict_x0=True,
-        beta_schedule="cosine",
+        timesteps = 1000,
+        cond_drop_prob = 0.2,
+        loss_type = "l1",
+        predict_x0 = True,
+        beta_schedule = "cosine",
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -825,6 +825,8 @@ class Unet(nn.Module):
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        attn_dim_head = 32,
+        attn_heads = 8,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_cond_upsample_mode = 'bilinear',
         blur_sigma = 0.1,
@@ -888,6 +890,10 @@ class Unet(nn.Module):
         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
         self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
+        # attention related params
+        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
         # layers
         self.downs = nn.ModuleList([])
@@ -901,7 +907,7 @@ class Unet(nn.Module):
             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, norm = ind != 0),
-                Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
+                Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -909,7 +915,7 @@ class Unet(nn.Module):
         mid_dim = dims[-1]
         self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
-        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
+        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
         self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -918,7 +924,7 @@ class Unet(nn.Module):
             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
-                Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
+                Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
                 Upsample(dim_in)
             ]))
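The Unet changes above gather the two new constructor arguments, `attn_heads` and `attn_dim_head`, into a single `attn_kwargs` dict that is splatted into each attention layer. A minimal sketch of that pattern, with a toy attention module standing in for the library's `GridAttention`/`Attention` (the 8-head, 32-dims-per-head values follow the defaults added above):

```python
from torch import nn

# toy stand-in for the library's attention blocks; only the constructor
# signature matters for illustrating how the kwargs are threaded through
class ToyAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 32):
        super().__init__()
        inner_dim = heads * dim_head
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

# the two new arguments are gathered once...
attn_heads, attn_dim_head = 8, 32
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)

# ...and splatted into every attention layer, so the whole unet shares
# one head configuration
down_attn = ToyAttention(128, **attn_kwargs)
mid_attn = ToyAttention(512, **attn_kwargs)
```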
@@ -1142,19 +1148,24 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
-    @contextmanager
-    def one_unet_in_gpu(self, unet_number):
+    def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
+        return self.unets[index]
+    @contextmanager
+    def one_unet_in_gpu(self, unet_number = None, unet = None):
+        assert exists(unet_number) ^ exists(unet)
+        if exists(unet_number):
+            unet = self.get_unet(unet_number)
         self.cuda()
         self.unets.cpu()
-        unet = self.unets[index]
         unet.cuda()
         yield
-        self.unets.cpu()
+        unet.cpu()
     def get_text_encodings(self, text):
         text_encodings = self.clip.text_transformer(text)
@@ -1261,8 +1272,8 @@ class Decoder(nn.Module):
         img = None
-        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
-            with self.one_unet_in_gpu(ind + 1):
+        for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
+            with self.one_unet_in_gpu(unet = unet):
                 shape = (batch_size, channels, image_size, image_size)
                 img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
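The `Decoder` refactor above splits unet lookup (`get_unet`) from device placement (`one_unet_in_gpu`), and the context manager now accepts either a 1-based `unet_number` or the unet module itself, which is what the updated sampling loop passes. A self-contained sketch of the same pattern, under the assumption of a simplified stand-in class (`ToyCascade` is hypothetical, and the real method also calls `self.cuda()`):

```python
import torch
from contextlib import contextmanager
from torch import nn

def exists(val):
    return val is not None

class ToyCascade(nn.Module):
    # minimal stand-in for Decoder, only to show the selection pattern
    def __init__(self, *unets):
        super().__init__()
        self.unets = nn.ModuleList(unets)

    def get_unet(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        return self.unets[unet_number - 1]

    @contextmanager
    def one_unet_in_gpu(self, unet_number = None, unet = None):
        assert exists(unet_number) ^ exists(unet)  # exactly one selector
        if exists(unet_number):
            unet = self.get_unet(unet_number)
        self.unets.cpu()                  # park every unet on the CPU
        if torch.cuda.is_available():
            unet.cuda()                   # move only the selected one to GPU
        yield
        unet.cpu()                        # bring it back afterwards

cascade = ToyCascade(nn.Linear(4, 4), nn.Linear(4, 4))

# both call styles select the same module, mirroring the new sampling loop
with cascade.one_unet_in_gpu(unet_number = 1):
    pass
for unet in cascade.unets:
    with cascade.one_unet_in_gpu(unet = unet):
        pass
```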
@@ -1271,11 +1282,10 @@ class Decoder(nn.Module):
     def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
-        assert 1 <= unet_number <= len(self.unets)
-        index = unet_number - 1
-        unet = self.unets[index]
-        target_image_size = self.image_sizes[index]
+        unet = self.get_unet(unet_number)
+        target_image_size = self.image_sizes[unet_number - 1]
         b, c, h, w, device, = *image.shape, image.device
@@ -1289,7 +1299,7 @@ class Decoder(nn.Module):
         text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
-        lowres_cond_img = image if index > 0 else None
+        lowres_cond_img = image if unet_number > 1 else None
         ddpm_image = resize_image_to(image, target_image_size)
         return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
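The `forward` changes keep the same 1-based `unet_number` convention: `get_unet` resolves the module, `image_sizes[unet_number - 1]` picks the target resolution, and only unets after the first receive a low-resolution conditioning image. A toy walk-through of that indexing, with made-up modules and image sizes:

```python
from torch import nn

# made-up cascade: three unets, each trained at a larger image size
unets = nn.ModuleList([nn.Identity(), nn.Identity(), nn.Identity()])
image_sizes = (64, 128, 256)

def get_unet(unet_number):
    # 1-based, matching the unet_number argument of Decoder.forward
    assert 0 < unet_number <= len(unets)
    return unets[unet_number - 1]

for unet_number in (1, 2, 3):
    unet = get_unet(unet_number)
    target_image_size = image_sizes[unet_number - 1]
    # only unets after the first get a low-resolution conditioning image
    uses_lowres_cond = unet_number > 1
    print(unet_number, target_image_size, uses_lowres_cond)
```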

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.28',
+  version = '0.0.31',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',