Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-15 02:24:32 +01:00

Compare commits

6 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 3b520dfa85 | |
| | 79198c6ae4 | |
| | 77a246b1b9 | |
| | f93a3f6ed8 | |
| | 8f2a0c7e00 | |
| | 863f4ef243 | |
README.md (23 changed lines)
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
 
 Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
 
-There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
+There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
 
 ## Install
 
@@ -246,13 +246,6 @@ loss = decoder(images, unet_number = 2)
 loss.backward()
 
 # do the above for many steps for both unets
 
-# then it will learn to generate images based on the CLIP image embeddings
-
-# chaining the unets from lowest resolution to highest resolution (thus cascading)
-
-mock_image_embed = torch.randn(1, 512).cuda()
-
-images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
 ```
 
 Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
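As a minimal sketch of how those trained pieces compose (the `DALLE2` wrapper and its `prior` / `decoder` arguments are taken from this README; `diffusion_prior` and `decoder` stand in for the already-trained modules, and `cond_scale` is illustrative):

```python
from dalle2_pytorch import DALLE2

# assumes `diffusion_prior` and `decoder` have already been trained as shown above
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

# text -> CLIP image embedding via the prior, then embedding -> pixels via the decoder
images = dalle2(
    ['a butterfly trying to escape a tornado'],
    cond_scale = 2.  # classifier-free guidance strength
)
```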
@@ -533,10 +526,12 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] spend one day cleaning up tech debt in decoder
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
+- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
 - [ ] bring in tools to train vqgan-vae
 - [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
+- [ ] experiment with https://arxiv.org/abs/2112.11435 as upsampler, test in https://github.com/lucidrains/lightweight-gan first
 
 ## Citations
 
@@ -568,7 +563,7 @@ Once built, images will be saved to the same directory the command is invoked
 
 ```bibtex
 @inproceedings{Liu2022ACF,
-    title   = {A ConvNet for the 2020s},
+    title   = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
     author  = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
     year    = {2022}
 }
@@ -582,4 +577,14 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
+```bibtex
+@article{Arar2021LearnedQF,
+    title   = {Learned Queries for Efficient Local Attention},
+    author  = {Moab Arar and Ariel Shamir and Amit H. Bermano},
+    journal = {ArXiv},
+    year    = {2021},
+    volume  = {abs/2112.11435}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
@@ -483,7 +483,7 @@ class DiffusionPrior(nn.Module):
         timesteps = 1000,
         cond_drop_prob = 0.2,
         loss_type = "l1",
-        predict_x0 = True,
+        predict_x_start = True,
         beta_schedule = "cosine",
     ):
         super().__init__()
@@ -497,7 +497,7 @@ class DiffusionPrior(nn.Module):
         self.image_size = clip.image_size
         self.cond_drop_prob = cond_drop_prob
 
-        self.predict_x0 = predict_x0
+        self.predict_x_start = predict_x_start
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
         if beta_schedule == "cosine":
@@ -586,14 +586,14 @@ class DiffusionPrior(nn.Module):
     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
         pred = self.net(x, t, **text_cond)
 
-        if self.predict_x0:
+        if self.predict_x_start:
             x_recon = pred
             # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
             # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
             x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
 
-        if clip_denoised and not self.predict_x0:
+        if clip_denoised and not self.predict_x_start:
             x_recon.clamp_(-1., 1.)
 
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
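For readers tracking the two branches above: when the network predicts noise rather than x_start, the clean signal is first recovered with the standard DDPM identity before the posterior is computed. A minimal standalone sketch of that conversion (buffer handling simplified; the per-timestep coefficients are passed in directly here):

```python
def predict_start_from_noise(x_t, noise, sqrt_recip_alphas_cumprod_t, sqrt_recipm1_alphas_cumprod_t):
    # standard DDPM identity: x_0 = x_t / sqrt(alpha_bar_t) - sqrt(1 / alpha_bar_t - 1) * eps
    return sqrt_recip_alphas_cumprod_t * x_t - sqrt_recipm1_alphas_cumprod_t * noise

# with predict_x_start = True the network outputs x_start directly,
# so the conversion above is skipped and `pred` feeds q_posterior as-is
```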
@@ -639,7 +639,7 @@ class DiffusionPrior(nn.Module):
             **text_cond
         )
 
-        to_predict = noise if not self.predict_x0 else image_embed
+        to_predict = noise if not self.predict_x_start else image_embed
 
         if self.loss_type == 'l1':
             loss = F.l1_loss(to_predict, x_recon)
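The rename does not change the training objective itself; the flag only selects what the prior regresses against. A small illustrative restatement (function and argument names here are made up for clarity, not taken from the library):

```python
import torch.nn.functional as F

def prior_loss(pred, noise, image_embed, predict_x_start, loss_type = 'l1'):
    # predict_x_start = True  -> the prior regresses the clean CLIP image embedding (x_start)
    # predict_x_start = False -> it regresses the noise that q_sample added
    target = image_embed if predict_x_start else noise
    return F.l1_loss(pred, target) if loss_type == 'l1' else F.mse_loss(pred, target)
```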
@@ -1121,7 +1121,8 @@ class Decoder(nn.Module):
         cond_drop_prob = 0.2,
         loss_type = 'l1',
         beta_schedule = 'cosine',
-        predict_x0 = False,
+        predict_x_start = False,
+        predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
         lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
@@ -1172,7 +1173,7 @@ class Decoder(nn.Module):
 
         # predict x0 config
 
-        self.predict_x0 = cast_tuple(predict_x0, len(unets))
+        self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
 
         # cascading ddpm related stuff
 
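To make the new `predict_x_start_for_latent_diffusion` behavior concrete: when the flag is set, x_start prediction is enabled exactly for the unets whose vae is a `VQGanVAE`, i.e. the stages diffusing in latent space. The helper below is hypothetical and only restates the one-liner above (the `VQGanVAE` import path is an assumption):

```python
from dalle2_pytorch.vqgan_vae import VQGanVAE  # assumed module path

def resolve_predict_x_start(predict_x_start, vaes, predict_x_start_for_latent_diffusion):
    if predict_x_start_for_latent_diffusion:
        # force x_start prediction for latent-space stages only
        return tuple(isinstance(vae, VQGanVAE) for vae in vaes)
    # otherwise broadcast the user setting across all unets
    if isinstance(predict_x_start, tuple):
        return predict_x_start
    return (predict_x_start,) * len(vaes)
```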
@@ -1292,31 +1293,31 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x0 = False, cond_scale = 1.):
+    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
         pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
 
-        if predict_x0:
+        if predict_x_start:
             x_recon = pred
         else:
             x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
 
-        if clip_denoised and not predict_x0:
+        if clip_denoised and not predict_x_start:
             x_recon.clamp_(-1., 1.)
 
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x0 = False, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x0 = predict_x0)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x0 = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
         device = self.betas.device
 
         b = shape[0]
@@ -1331,7 +1332,7 @@ class Decoder(nn.Module):
                 text_encodings = text_encodings,
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
-                predict_x0 = predict_x0
+                predict_x_start = predict_x_start
             )
 
         return img
@@ -1344,7 +1345,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
-    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x0 = False, noise = None):
+    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -1358,7 +1359,7 @@ class Decoder(nn.Module):
             cond_drop_prob = self.cond_drop_prob
         )
 
-        target = noise if not predict_x0 else x_start
+        target = noise if not predict_x_start else x_start
 
         if self.loss_type == 'l1':
             loss = F.l1_loss(target, x_recon)
@@ -1380,7 +1381,7 @@ class Decoder(nn.Module):
 
         img = None
 
-        for unet, vae, channel, image_size, predict_x0 in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x0)):
+        for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
             with self.one_unet_in_gpu(unet = unet):
                 lowres_cond_img = None
                 shape = (batch_size, channel, image_size, image_size)
@@ -1400,7 +1401,7 @@ class Decoder(nn.Module):
                     image_embed = image_embed,
                     text_encodings = text_encodings,
                     cond_scale = cond_scale,
-                    predict_x0 = predict_x0,
+                    predict_x_start = predict_x_start,
                     lowres_cond_img = lowres_cond_img
                 )
 
@@ -1424,7 +1425,7 @@ class Decoder(nn.Module):
 
         target_image_size = self.image_sizes[unet_index]
         vae = self.vaes[unet_index]
-        predict_x0 = self.predict_x0[unet_index]
+        predict_x_start = self.predict_x_start[unet_index]
 
         b, c, h, w, device, = *image.shape, image.device
 
@@ -1448,7 +1449,7 @@ class Decoder(nn.Module):
         if exists(lowres_cond_img):
             lowres_cond_img = vae.encode(lowres_cond_img)
 
-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x0 = predict_x0)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
 
 # main class
 
@@ -243,6 +243,112 @@ class ResBlock(nn.Module):
     def forward(self, x):
         return self.net(x) + x
 
+# attention-based upsampling
+# from https://arxiv.org/abs/2112.11435
+
+class QueryAndAttend(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        num_queries = 1,
+        dim_head = 32,
+        heads = 8,
+        window_size = 3
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        inner_dim = dim_head * heads
+        self.heads = heads
+        self.dim_head = dim_head
+        self.window_size = window_size
+        self.num_queries = num_queries
+
+        self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
+
+        self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
+        self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
+        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+
+    def forward(self, x):
+        """
+        einstein notation
+        b - batch
+        h - heads
+        l - num queries
+        d - head dimension
+        x - height
+        y - width
+        j - source sequence for attending to (kernel size squared in this case)
+        """
+
+        wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
+        batch, _, height, width = x.shape
+
+        is_one_query = self.num_queries == 1
+
+        # queries, keys, values
+
+        q = self.queries * self.scale
+        k, v = self.to_kv(x).chunk(2, dim = 1)
+
+        # similarities
+
+        sim = einsum('h l d, b d x y -> b h l x y', q, k)
+        sim = rearrange(sim, 'b ... x y -> b (...) x y')
+
+        # unfold the similarity scores, with float(-inf) as padding value
+
+        mask_value = -torch.finfo(sim.dtype).max
+        sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
+        sim = F.unfold(sim, kernel_size = wsz)
+        sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
+
+        # rel pos bias
+
+        sim = sim + self.rel_pos_bias
+
+        # numerically stable attention
+
+        sim = sim - sim.amax(dim = -3, keepdim = True).detach()
+        attn = sim.softmax(dim = -3)
+
+        # unfold values
+
+        v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
+        v = F.unfold(v, kernel_size = wsz)
+        v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
+
+        # aggregate values
+
+        out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
+
+        # combine heads
+
+        out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
+        out = self.to_out(out)
+        out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
+
+        # return original input if one query
+
+        if is_one_query:
+            out = rearrange(out, 'b 1 ... -> b ...')
+
+        return out
+
+class QueryAttnUpsample(nn.Module):
+    def __init__(self, dim, **kwargs):
+        super().__init__()
+        self.norm = LayerNormChan(dim)
+        self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
+
+    def forward(self, x):
+        x = self.norm(x)
+        out = self.qna(x)
+        out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
+        return out
+
+# vqgan attention layer
+
 class VQGanAttention(nn.Module):
     def __init__(
         self,
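A quick shape-check sketch of the upsampler added above: with `num_queries = 4`, each position attends to its 3x3 neighbourhood and emits a 2x2 block of output pixels, so the module doubles height and width. The import path below is an assumption (the classes are added alongside `VQGanAttention` and `VQGanVAE`):

```python
import torch
from dalle2_pytorch.vqgan_vae import QueryAttnUpsample  # assumed module path

upsample = QueryAttnUpsample(dim = 64)  # QnA with num_queries = 4 -> one 2x2 output block per position
feats = torch.randn(1, 64, 16, 16)

out = upsample(feats)
print(out.shape)  # expected: torch.Size([1, 64, 32, 32])
```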
@@ -375,7 +481,7 @@ class VQGanVAE(nn.Module):
 
         for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
             append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
-            prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
+            prepend(self.decoders, nn.Sequential(QueryAttnUpsample(dim_out), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
 
             if layer_use_attn:
                 prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
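Since `QueryAttnUpsample` also doubles spatial resolution, the swap above is a drop-in replacement for the bilinear stage it removes; a minimal check under that assumption (import path assumed as before, and the `leaky_relu()` slope shown here is illustrative):

```python
import torch
from torch import nn
from dalle2_pytorch.vqgan_vae import QueryAttnUpsample  # assumed module path

dim_out, dim_in = 128, 64
feats = torch.randn(1, dim_out, 8, 8)

# old decoder stage: fixed bilinear 2x upsample, then 3x3 conv down to dim_in
bilinear_stage = nn.Sequential(
    nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False),
    nn.Conv2d(dim_out, dim_in, 3, padding = 1),
    nn.LeakyReLU(0.1)
)

# new decoder stage: learned query-and-attend 2x upsample in place of bilinear
qna_stage = nn.Sequential(
    QueryAttnUpsample(dim_out),
    nn.Conv2d(dim_out, dim_in, 3, padding = 1),
    nn.LeakyReLU(0.1)
)

print(bilinear_stage(feats).shape)  # torch.Size([1, 64, 16, 16])
print(qna_stage(feats).shape)       # torch.Size([1, 64, 16, 16])
```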