mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be72f8575d | ||
|
|
f2d5b87677 | ||
|
|
461347c171 | ||
|
|
46cef31c86 | ||
|
|
59b1a77d4d | ||
|
|
7f338319fd |
114
README.md
114
README.md
@@ -10,8 +10,6 @@ The main novelty seems to be an extra layer of indirection with the prior networ
|
||||
|
||||
This model is SOTA for text-to-image for now.
|
||||
|
||||
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
@@ -385,6 +383,115 @@ You can also train the decoder on images of greater than the size (say 512x512)
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## Experimental - DALL-E2 with Latent Diffusion
|
||||
|
||||
This repository decides to take the next step and offer DALL-E2 combined with latent diffusion, from Rombach et al.
|
||||
|
||||
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
|
||||
|
||||
# trained clip from step 1
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 1,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 1,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
)
|
||||
|
||||
# 2 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
# 1st unet is doing latent diffusion
|
||||
|
||||
vae1 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 256,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
vae2 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 512,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
sparse_attn = True,
|
||||
sparse_attn_window = 2,
|
||||
dim_mults = (1, 2, 4, 8)
|
||||
)
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False
|
||||
)
|
||||
|
||||
unet3 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False,
|
||||
attend_at_middle = False
|
||||
)
|
||||
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
|
||||
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 512, 512).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
|
||||
with decoder.one_unet_in_gpu(1):
|
||||
loss = decoder(images, unet_number = 1)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(2):
|
||||
loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
|
||||
```bash
|
||||
@@ -412,7 +519,8 @@ Offer training wrappers
|
||||
- [x] add efficient attention in unet
|
||||
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [ ] spend one day cleaning up tech debt in decoder
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] train on a toy task, offer in colab
|
||||
|
||||
@@ -48,6 +48,12 @@ def is_list_str(x):
|
||||
return False
|
||||
return all([type(el) == str for el in x])
|
||||
|
||||
def pad_tuple_to_length(t, length):
|
||||
remain_length = length - len(t)
|
||||
if remain_length <= 0:
|
||||
return t
|
||||
return (*t, *((None,) * remain_length))
|
||||
|
||||
# for controlling freezing of CLIP
|
||||
|
||||
def set_module_requires_grad_(module, requires_grad):
|
||||
@@ -137,23 +143,27 @@ def sigmoid_beta_schedule(timesteps):
|
||||
|
||||
# diffusion prior
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer("beta", torch.zeros(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
||||
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
|
||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||
return x * inv_norm * self.gamma * self.scale
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
class ChanRMSNorm(RMSNorm):
|
||||
def forward(self, x):
|
||||
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
|
||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
@@ -249,10 +259,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
RMSNorm(dim),
|
||||
LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim * 2, bias = False),
|
||||
SwiGLU(),
|
||||
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim, bias = False)
|
||||
)
|
||||
@@ -275,7 +285,8 @@ class Attention(nn.Module):
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.causal = causal
|
||||
self.norm = RMSNorm(dim)
|
||||
self.norm = LayerNorm(dim)
|
||||
self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -331,7 +342,8 @@ class Attention(nn.Module):
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
out = self.to_out(out)
|
||||
return self.post_norm(out)
|
||||
|
||||
class CausalTransformer(nn.Module):
|
||||
def __init__(
|
||||
@@ -344,7 +356,8 @@ class CausalTransformer(nn.Module):
|
||||
ff_mult = 4,
|
||||
norm_out = False,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.
|
||||
ff_dropout = 0.,
|
||||
final_proj = True
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
@@ -356,7 +369,8 @@ class CausalTransformer(nn.Module):
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -371,7 +385,8 @@ class CausalTransformer(nn.Module):
|
||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
out = self.norm(x)
|
||||
return self.project_out(out)
|
||||
|
||||
class DiffusionPriorNetwork(nn.Module):
|
||||
def __init__(
|
||||
@@ -531,12 +546,14 @@ class DiffusionPrior(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_cond(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
@@ -720,7 +737,7 @@ class ConvNextBlock(nn.Module):
|
||||
|
||||
inner_dim = int(dim_out * mult)
|
||||
self.net = nn.Sequential(
|
||||
ChanRMSNorm(dim) if norm else nn.Identity(),
|
||||
ChanLayerNorm(dim) if norm else nn.Identity(),
|
||||
nn.Conv2d(dim, inner_dim, 3, padding = 1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
|
||||
@@ -756,8 +773,8 @@ class CrossAttention(nn.Module):
|
||||
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = RMSNorm(dim)
|
||||
self.norm_context = RMSNorm(context_dim)
|
||||
self.norm = LayerNorm(dim)
|
||||
self.norm_context = LayerNorm(context_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -931,11 +948,16 @@ class Unet(nn.Module):
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def force_lowres_cond(self, lowres_cond):
|
||||
if lowres_cond == self.lowres_cond:
|
||||
def cast_model_parameters(
|
||||
self,
|
||||
*,
|
||||
lowres_cond,
|
||||
channels
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
|
||||
return self.__class__(**updated_kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -1075,14 +1097,14 @@ class LowresConditioner(nn.Module):
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
return cond_fmap
|
||||
|
||||
class Decoder(nn.Module):
|
||||
@@ -1091,6 +1113,7 @@ class Decoder(nn.Module):
|
||||
unet,
|
||||
*,
|
||||
clip,
|
||||
vae = None,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
@@ -1111,11 +1134,25 @@ class Decoder(nn.Module):
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
for ind, one_unet in enumerate(cast_tuple(unet)):
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||
is_first = ind == 0
|
||||
one_unet = one_unet.force_lowres_cond(not is_first)
|
||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||
|
||||
unet_channels = default(latent_dim, self.channels)
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
channels = unet_channels
|
||||
)
|
||||
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
|
||||
|
||||
# unet image sizes
|
||||
|
||||
@@ -1210,10 +1247,12 @@ class Decoder(nn.Module):
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_encodings(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
return text_encodings[:, 1:]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image = resize_image_to(image, self.clip_image_size)
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
@@ -1315,25 +1354,43 @@ class Decoder(nn.Module):
|
||||
|
||||
img = None
|
||||
|
||||
for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
|
||||
for unet, vae, channel, image_size in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes)):
|
||||
with self.one_unet_in_gpu(unet = unet):
|
||||
lowres_cond_img = self.to_lowres_cond(
|
||||
img,
|
||||
target_image_size = image_size
|
||||
) if unet.lowres_cond else None
|
||||
lowres_cond_img = None
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
|
||||
if unet.lowres_cond:
|
||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||
|
||||
if exists(vae):
|
||||
image_size //= (2 ** vae.layers)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
(batch_size, channel, image_size, image_size),
|
||||
shape,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
if exists(vae):
|
||||
img = vae.decode(img)
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||
def forward(
|
||||
self,
|
||||
image,
|
||||
text = None,
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
unet_number = None
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
unet_index = unet_number - 1
|
||||
@@ -1341,6 +1398,7 @@ class Decoder(nn.Module):
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
vae = self.vaes[unet_index]
|
||||
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
@@ -1355,8 +1413,17 @@ class Decoder(nn.Module):
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
ddpm_image = resize_image_to(image, target_image_size)
|
||||
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
if exists(vae):
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
@@ -294,7 +294,7 @@ class VQGanVAE(nn.Module):
|
||||
dim,
|
||||
image_size,
|
||||
channels = 3,
|
||||
num_layers = 4,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
@@ -321,35 +321,37 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
self.image_size = image_size
|
||||
self.channels = channels
|
||||
self.num_layers = num_layers
|
||||
self.fmap_size = image_size // (num_layers ** 2)
|
||||
self.layers = layers
|
||||
self.fmap_size = image_size // (layers ** 2)
|
||||
self.codebook_size = vq_codebook_size
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
|
||||
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
codebook_dim = layer_dims[-1]
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
|
||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (num_layers - 1)), use_attn)
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == num_layers
|
||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
|
||||
@@ -434,13 +436,16 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
return fmap
|
||||
|
||||
def decode(self, fmap):
|
||||
fmap = self.vq(fmap)
|
||||
def decode(self, fmap, return_indices_and_loss = False):
|
||||
fmap, indices, commit_loss = self.vq(fmap)
|
||||
|
||||
for dec in self.decoders:
|
||||
fmap = dec(fmap)
|
||||
|
||||
return fmap
|
||||
if not return_indices_and_loss:
|
||||
return fmap
|
||||
|
||||
return fmap, indices, commit_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -453,9 +458,9 @@ class VQGanVAE(nn.Module):
|
||||
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
|
||||
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
||||
|
||||
fmap, indices, commit_loss = self.encode(img)
|
||||
fmap = self.encode(img)
|
||||
|
||||
fmap = self.decode(fmap)
|
||||
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
|
||||
|
||||
if not return_loss and not return_discr_loss:
|
||||
return fmap
|
||||
|
||||
Reference in New Issue
Block a user