mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b520dfa85 | ||
|
|
79198c6ae4 | ||
|
|
77a246b1b9 | ||
|
|
f93a3f6ed8 | ||
|
|
8f2a0c7e00 | ||
|
|
863f4ef243 | ||
|
|
fb8a66a2de | ||
|
|
579d4b42dd | ||
|
|
473808850a | ||
|
|
d5318aef4f | ||
|
|
f82917e1fd |
37
README.md
37
README.md
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -246,13 +246,6 @@ loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
|
||||
@@ -529,13 +522,16 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
- [ ] spend one day cleaning up tech debt in decoder
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
- [ ] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
- [ ] experiment with https://arxiv.org/abs/2112.11435 as upsampler, test in https://github.com/lucidrains/lightweight-gan first
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -567,23 +563,12 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2022ACF,
|
||||
title = {A ConvNet for the 2020s},
|
||||
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
|
||||
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{zhang2019root,
|
||||
title = {Root Mean Square Layer Normalization},
|
||||
author = {Biao Zhang and Rico Sennrich},
|
||||
year = {2019},
|
||||
eprint = {1910.07467},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.LG}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
@@ -592,4 +577,14 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Arar2021LearnedQF,
|
||||
title = {Learned Queries for Efficient Local Attention},
|
||||
author = {Moab Arar and Ariel Shamir and Amit H. Bermano},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2112.11435}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
|
||||
@@ -1,9 +1,51 @@
|
||||
import click
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from pathlib import Path
|
||||
|
||||
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
|
||||
|
||||
def safeget(dictionary, keys, default = None):
|
||||
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
|
||||
|
||||
def simple_slugify(text, max_length = 255):
|
||||
return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
|
||||
|
||||
def get_pkg_version():
|
||||
from pkg_resources import get_distribution
|
||||
return get_distribution('dalle2_pytorch').version
|
||||
|
||||
def main():
|
||||
pass
|
||||
|
||||
@click.command()
|
||||
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
|
||||
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
|
||||
@click.argument('text')
|
||||
def dream(text):
|
||||
return 'not ready yet'
|
||||
def dream(
|
||||
model,
|
||||
cond_scale,
|
||||
text
|
||||
):
|
||||
model_path = Path(model)
|
||||
full_model_path = str(model_path.resolve())
|
||||
assert model_path.exists(), f'model not found at {full_model_path}'
|
||||
loaded = torch.load(str(model_path))
|
||||
|
||||
version = safeget(loaded, 'version')
|
||||
print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
|
||||
|
||||
prior_init_params = safeget(loaded, 'init_params.prior')
|
||||
decoder_init_params = safeget(loaded, 'init_params.decoder')
|
||||
model_params = safeget(loaded, 'model_params')
|
||||
|
||||
prior = DiffusionPrior(**prior_init_params)
|
||||
decoder = Decoder(**decoder_init_params)
|
||||
|
||||
dalle2 = DALLE2(prior, decoder)
|
||||
dalle2.load_state_dict(model_params)
|
||||
|
||||
image = dalle2(text, cond_scale = cond_scale)
|
||||
|
||||
pil_image = T.ToPILImage()(image)
|
||||
return pil_image.save(f'./{simple_slugify(text)}.png')
|
||||
|
||||
@@ -483,7 +483,7 @@ class DiffusionPrior(nn.Module):
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = "l1",
|
||||
predict_x0 = True,
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -497,7 +497,7 @@ class DiffusionPrior(nn.Module):
|
||||
self.image_size = clip.image_size
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
self.predict_x0 = predict_x0
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
@@ -584,14 +584,16 @@ class DiffusionPrior(nn.Module):
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
if self.predict_x0:
|
||||
x_recon = self.net(x, t, **text_cond)
|
||||
pred = self.net(x, t, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_recon = pred
|
||||
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
|
||||
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised:
|
||||
if clip_denoised and not self.predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
@@ -637,7 +639,7 @@ class DiffusionPrior(nn.Module):
|
||||
**text_cond
|
||||
)
|
||||
|
||||
to_predict = noise if not self.predict_x0 else image_embed
|
||||
to_predict = noise if not self.predict_x_start else image_embed
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(to_predict, x_recon)
|
||||
@@ -1119,6 +1121,8 @@ class Decoder(nn.Module):
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
beta_schedule = 'cosine',
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
@@ -1167,6 +1171,10 @@ class Decoder(nn.Module):
|
||||
self.image_sizes = image_sizes
|
||||
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||
|
||||
# predict x0 config
|
||||
|
||||
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
|
||||
# cascading ddpm related stuff
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
@@ -1285,34 +1293,47 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
|
||||
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
if clip_denoised:
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised and not predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
torch.full((b,), i, device = device, dtype = torch.long),
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start
|
||||
)
|
||||
|
||||
return img
|
||||
|
||||
@@ -1324,7 +1345,7 @@ class Decoder(nn.Module):
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||
@@ -1338,12 +1359,14 @@ class Decoder(nn.Module):
|
||||
cond_drop_prob = self.cond_drop_prob
|
||||
)
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(noise, x_recon)
|
||||
loss = F.l1_loss(target, x_recon)
|
||||
elif self.loss_type == 'l2':
|
||||
loss = F.mse_loss(noise, x_recon)
|
||||
loss = F.mse_loss(target, x_recon)
|
||||
elif self.loss_type == "huber":
|
||||
loss = F.smooth_l1_loss(noise, x_recon)
|
||||
loss = F.smooth_l1_loss(target, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1358,7 +1381,7 @@ class Decoder(nn.Module):
|
||||
|
||||
img = None
|
||||
|
||||
for unet, vae, channel, image_size in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes)):
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
with self.one_unet_in_gpu(unet = unet):
|
||||
lowres_cond_img = None
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
@@ -1378,6 +1401,7 @@ class Decoder(nn.Module):
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
@@ -1401,6 +1425,7 @@ class Decoder(nn.Module):
|
||||
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
vae = self.vaes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
@@ -1424,7 +1449,7 @@ class Decoder(nn.Module):
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
|
||||
# main class
|
||||
|
||||
@@ -1451,6 +1476,7 @@ class DALLE2(nn.Module):
|
||||
cond_scale = 1.
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
|
||||
|
||||
if isinstance(text, str) or is_list_str(text):
|
||||
text = [text] if not isinstance(text, (list, tuple)) else text
|
||||
@@ -1458,4 +1484,8 @@ class DALLE2(nn.Module):
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
@@ -243,6 +243,112 @@ class ResBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
# attention-based upsampling
|
||||
# from https://arxiv.org/abs/2112.11435
|
||||
|
||||
class QueryAndAttend(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
num_queries = 1,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
window_size = 3
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.window_size = window_size
|
||||
self.num_queries = num_queries
|
||||
|
||||
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
|
||||
|
||||
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
|
||||
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
l - num queries
|
||||
d - head dimension
|
||||
x - height
|
||||
y - width
|
||||
j - source sequence for attending to (kernel size squared in this case)
|
||||
"""
|
||||
|
||||
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
|
||||
batch, _, height, width = x.shape
|
||||
|
||||
is_one_query = self.num_queries == 1
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
q = self.queries * self.scale
|
||||
k, v = self.to_kv(x).chunk(2, dim = 1)
|
||||
|
||||
# similarities
|
||||
|
||||
sim = einsum('h l d, b d x y -> b h l x y', q, k)
|
||||
sim = rearrange(sim, 'b ... x y -> b (...) x y')
|
||||
|
||||
# unfold the similarity scores, with float(-inf) as padding value
|
||||
|
||||
mask_value = -torch.finfo(sim.dtype).max
|
||||
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
|
||||
sim = F.unfold(sim, kernel_size = wsz)
|
||||
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
|
||||
|
||||
# rel pos bias
|
||||
|
||||
sim = sim + self.rel_pos_bias
|
||||
|
||||
# numerically stable attention
|
||||
|
||||
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -3)
|
||||
|
||||
# unfold values
|
||||
|
||||
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
|
||||
v = F.unfold(v, kernel_size = wsz)
|
||||
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
|
||||
|
||||
# combine heads
|
||||
|
||||
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
|
||||
out = self.to_out(out)
|
||||
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
|
||||
|
||||
# return original input if one query
|
||||
|
||||
if is_one_query:
|
||||
out = rearrange(out, 'b 1 ... -> b ...')
|
||||
|
||||
return out
|
||||
|
||||
class QueryAttnUpsample(nn.Module):
|
||||
def __init__(self, dim, **kwargs):
|
||||
super().__init__()
|
||||
self.norm = LayerNormChan(dim)
|
||||
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
out = self.qna(x)
|
||||
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
|
||||
return out
|
||||
|
||||
# vqgan attention layer
|
||||
class VQGanAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -375,7 +481,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(QueryAttnUpsample(dim_out), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
@@ -477,7 +583,8 @@ class VQGanVAE(nn.Module):
|
||||
img,
|
||||
return_loss = False,
|
||||
return_discr_loss = False,
|
||||
return_recons = False
|
||||
return_recons = False,
|
||||
add_gradient_penalty = True
|
||||
):
|
||||
batch, channels, height, width, device = *img.shape, img.device
|
||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||
@@ -502,11 +609,11 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
|
||||
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||
|
||||
loss = discr_loss + gp
|
||||
if add_gradient_penalty:
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
loss = discr_loss + gp
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
Reference in New Issue
Block a user