Compare commits

..

10 Commits

5 changed files with 160 additions and 61 deletions

View File

@@ -383,9 +383,11 @@ You can also train the decoder on images of greater than the size (say 512x512)
 For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
-## Experimental - DALL-E2 with Latent Diffusion
-This repository decides to take the next step and offer DALL-E2 combined with latent diffusion, from Rombach et al.
+## Experimental
+### DALL-E2 with Latent Diffusion
+This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
 You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
@@ -409,9 +411,10 @@ clip = CLIP(
     visual_heads = 8
 )
-# 2 unets for the decoder (a la cascading DDPM)
-# 1st unet is doing latent diffusion
+# 3 unets for the decoder (a la cascading DDPM)
+# first two unets are doing latent diffusion
+# vqgan-vae must be trained before hand
 vae1 = VQGanVAE(
     dim = 32,
@@ -469,7 +472,7 @@ decoder = Decoder(
 # mock images (get a lot of this)
-images = torch.randn(1, 3, 512, 512).cuda()
+images = torch.randn(1, 3, 1024, 1024).cuda()
 # feed images into decoder, specifying which unet you want to train
 # each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
@@ -482,6 +485,10 @@ with decoder.one_unet_in_gpu(2):
     loss = decoder(images, unet_number = 2)
     loss.backward()
+with decoder.one_unet_in_gpu(3):
+    loss = decoder(images, unet_number = 3)
+    loss.backward()
 # do the above for many steps for both unets
 # then it will learn to generate images based on the CLIP image embeddings
@@ -492,7 +499,11 @@ mock_image_embed = torch.randn(1, 512).cuda()
 images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
 ```
-## CLI Usage (work in progress)
+## Training wrapper (wip)
+
+Offer training wrappers
+
+## CLI (wip)
 ```bash
 $ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -500,9 +511,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
 Once built, images will be saved to the same directory the command is invoked
-## Training wrapper (wip)
-Offer training wrappers
+<a href="https://github.com/lucidrains/big-sleep">template</a>
 ## Training CLI (wip)
@@ -520,11 +529,14 @@ Offer training wrappers
 - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
 - [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
 - [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
+- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
 - [ ] spend one day cleaning up tech debt in decoder
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] train on a toy task, offer in colab
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
+- [ ] bring in tools to train vqgan-vae
+- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
 ## Citations
@@ -562,17 +574,6 @@ Offer training wrappers
 }
 ```
-```bibtex
-@misc{zhang2019root,
-    title = {Root Mean Square Layer Normalization},
-    author = {Biao Zhang and Rico Sennrich},
-    year = {2019},
-    eprint = {1910.07467},
-    archivePrefix = {arXiv},
-    primaryClass = {cs.LG}
-}
-```
 ```bibtex
 @inproceedings{Tu2022MaxViTMV,
     title = {MaxViT: Multi-Axis Vision Transformer},
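The latent-diffusion changes above move the early unets from pixel space into a VQGAN-VAE feature-map space, which shrinks by a factor of 2 per VAE layer. A minimal sketch of the resulting resolutions, using the same `image_size // (2 ** layers)` rule as the repo but with hypothetical sizes and layer counts (not values from this diff):

```python
# Sketch only: how latent resolutions work out for a 3-unet cascade
# where the first two unets do latent diffusion and the last is pixel-space.

def encoded_fmap_size(image_size, vae_layers):
    # each VAE downsampling layer halves the spatial resolution;
    # 0 layers means no VAE, i.e. the unet sees raw pixels
    return image_size // (2 ** vae_layers)

image_sizes = (256, 512, 1024)  # hypothetical cascade resolutions
vae_layers  = (3, 3, 0)         # hypothetical per-unet VAE depths

latent_sizes = [encoded_fmap_size(s, l) for s, l in zip(image_sizes, vae_layers)]
print(latent_sizes)  # [32, 64, 1024]
```

So under these assumed settings the first unet denoises 32x32 latents rather than 256x256 images, which is where the memory savings of latent diffusion come from.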

View File

@@ -1,9 +1,51 @@
 import click
+import torch
+import torchvision.transforms as T
+from functools import reduce
+from pathlib import Path
+
+from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
+
+def safeget(dictionary, keys, default = None):
+    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
+
+def simple_slugify(text, max_length = 255):
+    return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
+
+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('dalle2_pytorch').version
+
 def main():
     pass

 @click.command()
+@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
+@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
 @click.argument('text')
-def dream(text):
-    return 'not ready yet'
+def dream(
+    model,
+    cond_scale,
+    text
+):
+    model_path = Path(model)
+    full_model_path = str(model_path.resolve())
+    assert model_path.exists(), f'model not found at {full_model_path}'
+    loaded = torch.load(str(model_path))
+
+    version = safeget(loaded, 'version')
+    print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
+
+    prior_init_params = safeget(loaded, 'init_params.prior')
+    decoder_init_params = safeget(loaded, 'init_params.decoder')
+    model_params = safeget(loaded, 'model_params')
+
+    prior = DiffusionPrior(**prior_init_params)
+    decoder = Decoder(**decoder_init_params)
+
+    dalle2 = DALLE2(prior, decoder)
+    dalle2.load_state_dict(model_params)
+
+    image = dalle2(text, cond_scale = cond_scale)
+
+    pil_image = T.ToPILImage()(image)
+    return pil_image.save(f'./{simple_slugify(text)}.png')
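The new CLI reads checkpoint fields through the two small helpers added at the top of the file. Copied here in isolation (with the `functools.reduce` import they rely on) so their behavior is easy to see:

```python
from functools import reduce

# verbatim copies of the helpers from cli.py above

def safeget(dictionary, keys, default = None):
    # walk a nested dict with a dotted path, returning default on any miss
    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)

def simple_slugify(text, max_length = 255):
    # turn a prompt into a filesystem-safe filename
    return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]

checkpoint = {'init_params': {'prior': {'dim': 512}}}  # hypothetical checkpoint layout
print(safeget(checkpoint, 'init_params.prior'))        # {'dim': 512}
print(safeget(checkpoint, 'init_params.missing'))      # None
print(simple_slugify('sharing a sunset, at dusk'))     # sharing_a_sunset_at_dusk
```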

View File

@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
 from kornia.filters import gaussian_blur2d
 from dalle2_pytorch.tokenizer import tokenizer
+from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE

 # use x-clip
@@ -48,11 +49,11 @@ def is_list_str(x):
         return False
     return all([type(el) == str for el in x])

-def pad_tuple_to_length(t, length):
+def pad_tuple_to_length(t, length, fillvalue = None):
     remain_length = length - len(t)
     if remain_length <= 0:
         return t
-    return (*t, *((None,) * remain_length))
+    return (*t, *((fillvalue,) * remain_length))

 # for controlling freezing of CLIP
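The widened `pad_tuple_to_length` is what lets the `Decoder` backfill missing VAEs with `NullVQGanVAE` instances later in the diff. The function in isolation, showing the new `fillvalue` argument:

```python
# verbatim copy of the updated helper from the hunk above

def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

print(pad_tuple_to_length((1, 2), 4))                 # (1, 2, None, None)
print(pad_tuple_to_length((1, 2), 4, fillvalue = 0))  # (1, 2, 0, 0)
print(pad_tuple_to_length((1, 2, 3), 2))              # (1, 2, 3) - already long enough
```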
@@ -583,14 +584,16 @@ class DiffusionPrior(nn.Module):
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
+        pred = self.net(x, t, **text_cond)
+
         if self.predict_x0:
-            x_recon = self.net(x, t, **text_cond)
+            x_recon = pred
             # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
             # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
-            x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
+            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)

-        if clip_denoised:
+        if clip_denoised and not self.predict_x0:
             x_recon.clamp_(-1., 1.)

         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
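The `predict_x0` branch above changes which quantity the network is read as predicting. A scalar sanity check of the two standard DDPM identities involved (a sketch with made-up numbers, not code from the repo):

```python
import math

# Standard DDPM identities underlying q_sample / predict_start_from_noise:
#   x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
#   x0  = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)
# With predict_x0 = True the second identity is skipped entirely and the
# network output is used as the reconstruction directly.

abar_t = 0.7          # hypothetical cumulative alpha-bar at timestep t
x0, eps = 0.25, -1.3  # scalar stand-ins for the image and noise tensors

x_t = math.sqrt(abar_t) * x0 + math.sqrt(1 - abar_t) * eps

# epsilon-prediction path: recover x0 from a (here, exact) noise estimate
x0_recon = (x_t - math.sqrt(1 - abar_t) * eps) / math.sqrt(abar_t)
assert abs(x0_recon - x0) < 1e-9  # round-trips to the original x0
```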
@@ -1118,6 +1121,7 @@ class Decoder(nn.Module):
         cond_drop_prob = 0.2,
         loss_type = 'l1',
         beta_schedule = 'cosine',
+        predict_x0 = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
         lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
@@ -1135,12 +1139,15 @@ class Decoder(nn.Module):
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet

         unets = cast_tuple(unet)
-        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
+        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))

         self.unets = nn.ModuleList([])
         self.vaes = nn.ModuleList([])

         for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
+            assert isinstance(one_unet, Unet)
+            assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
+
             is_first = ind == 0
             latent_dim = one_vae.encoded_dim if exists(one_vae) else None
@@ -1152,7 +1159,7 @@ class Decoder(nn.Module):
             )

             self.unets.append(one_unet)
-            self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
+            self.vaes.append(one_vae.copy_for_eval())

         # unet image sizes
@@ -1163,6 +1170,10 @@ class Decoder(nn.Module):
         self.image_sizes = image_sizes
         self.sample_channels = cast_tuple(self.channels, len(image_sizes))

+        # predict x0 config
+
+        self.predict_x0 = cast_tuple(predict_x0, len(unets))
+
         # cascading ddpm related stuff

         lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
@@ -1281,34 +1292,47 @@ class Decoder(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
-        pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
-        x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
+    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x0 = False, cond_scale = 1.):
+        pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)

-        if clip_denoised:
+        if predict_x0:
+            x_recon = pred
+        else:
+            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
+
+        if clip_denoised and not predict_x0:
             x_recon.clamp_(-1., 1.)

         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance

     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x0 = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x0 = predict_x0)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x0 = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
         device = self.betas.device

         b = shape[0]
         img = torch.randn(shape, device = device)

         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
-            img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+            img = self.p_sample(
+                unet,
+                img,
+                torch.full((b,), i, device = device, dtype = torch.long),
+                image_embed = image_embed,
+                text_encodings = text_encodings,
+                cond_scale = cond_scale,
+                lowres_cond_img = lowres_cond_img,
+                predict_x0 = predict_x0
+            )
         return img
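`forward_with_cond_scale`, called at the top of `p_mean_variance`, is outside this diff. By the usual classifier-free guidance convention (Ho & Salimans), it blends a conditioned and an unconditioned prediction; this stand-in is an assumption about that convention, with made-up scalar values, not the repo's implementation:

```python
# Hypothetical sketch of classifier-free guidance:
# cond_scale = 1 recovers the plain conditioned prediction,
# cond_scale > 1 pushes further in the direction of the condition.
def guided(cond_pred, null_pred, cond_scale):
    return null_pred + (cond_pred - null_pred) * cond_scale

print(guided(1.0, 0.0, 1.0))  # 1.0 (no extra guidance)
print(guided(1.0, 0.0, 2.0))  # 2.0 (extrapolated past the conditioned output)
```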
@@ -1320,7 +1344,7 @@ class Decoder(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
+    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x0 = False, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -1334,12 +1358,14 @@ class Decoder(nn.Module):
             cond_drop_prob = self.cond_drop_prob
         )

+        target = noise if not predict_x0 else x_start
+
         if self.loss_type == 'l1':
-            loss = F.l1_loss(noise, x_recon)
+            loss = F.l1_loss(target, x_recon)
         elif self.loss_type == 'l2':
-            loss = F.mse_loss(noise, x_recon)
+            loss = F.mse_loss(target, x_recon)
         elif self.loss_type == "huber":
-            loss = F.smooth_l1_loss(noise, x_recon)
+            loss = F.smooth_l1_loss(target, x_recon)
         else:
             raise NotImplementedError()
@@ -1354,7 +1380,7 @@ class Decoder(nn.Module):
         img = None

-        for unet, vae, channel, image_size in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes)):
+        for unet, vae, channel, image_size, predict_x0 in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x0)):
             with self.one_unet_in_gpu(unet = unet):
                 lowres_cond_img = None
                 shape = (batch_size, channel, image_size, image_size)
@@ -1362,12 +1388,11 @@ class Decoder(nn.Module):
                 if unet.lowres_cond:
                     lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)

-                if exists(vae):
-                    image_size //= (2 ** vae.layers)
-                    shape = (batch_size, vae.encoded_dim, image_size, image_size)
+                image_size = vae.get_encoded_fmap_size(image_size)
+                shape = (batch_size, vae.encoded_dim, image_size, image_size)

                 if exists(lowres_cond_img):
                     lowres_cond_img = vae.encode(lowres_cond_img)
                 img = self.p_sample_loop(
                     unet,
@@ -1375,11 +1400,11 @@ class Decoder(nn.Module):
                     image_embed = image_embed,
                     text_encodings = text_encodings,
                     cond_scale = cond_scale,
+                    predict_x0 = predict_x0,
                     lowres_cond_img = lowres_cond_img
                 )

-                if exists(vae):
-                    img = vae.decode(img)
+                img = vae.decode(img)

         return img
@@ -1399,6 +1424,7 @@ class Decoder(nn.Module):
         target_image_size = self.image_sizes[unet_index]
         vae = self.vaes[unet_index]
+        predict_x0 = self.predict_x0[unet_index]

         b, c, h, w, device, = *image.shape, image.device
@@ -1415,15 +1441,14 @@ class Decoder(nn.Module):
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
         image = resize_image_to(image, target_image_size)

-        if exists(vae):
-            vae.eval()
-            with torch.no_grad():
-                image = vae.encode(image)
+        vae.eval()
+        with torch.no_grad():
+            image = vae.encode(image)

-                if exists(lowres_cond_img):
-                    lowres_cond_img = vae.encode(lowres_cond_img)
+            if exists(lowres_cond_img):
+                lowres_cond_img = vae.encode(lowres_cond_img)

-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x0 = predict_x0)

 # main class
@@ -1450,6 +1475,7 @@ class DALLE2(nn.Module):
         cond_scale = 1.
     ):
         device = next(self.parameters()).device
+        one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)

         if isinstance(text, str) or is_list_str(text):
             text = [text] if not isinstance(text, (list, tuple)) else text
@@ -1457,4 +1483,8 @@ class DALLE2(nn.Module):
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
         images = self.decoder.sample(image_embed, cond_scale = cond_scale)
+
+        if one_text:
+            return images[0]
+
         return images

View File

@@ -287,6 +287,28 @@ class VQGanAttention(nn.Module):
         return out + residual

+class NullVQGanVAE(nn.Module):
+    def __init__(
+        self,
+        *,
+        channels
+    ):
+        super().__init__()
+        self.encoded_dim = channels
+        self.layers = 0
+
+    def get_encoded_fmap_size(self, size):
+        return size
+
+    def copy_for_eval(self):
+        return self
+
+    def encode(self, x):
+        return x
+
+    def decode(self, x):
+        return x
+
 class VQGanVAE(nn.Module):
     def __init__(
         self,
@@ -407,6 +429,9 @@ class VQGanVAE(nn.Module):
         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss

+    def get_encoded_fmap_size(self, image_size):
+        return image_size // (2 ** self.layers)
+
     def copy_for_eval(self):
         device = next(self.parameters()).device
         vae_copy = copy.deepcopy(self.cpu())
@@ -452,7 +477,8 @@ class VQGanVAE(nn.Module):
         img,
         return_loss = False,
         return_discr_loss = False,
-        return_recons = False
+        return_recons = False,
+        add_gradient_penalty = True
     ):
         batch, channels, height, width, device = *img.shape, img.device
         assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
@@ -477,11 +503,11 @@ class VQGanVAE(nn.Module):
             fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

-            gp = gradient_penalty(img, img_discr_logits)
-
             discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

-            loss = discr_loss + gp
+            if add_gradient_penalty:
+                gp = gradient_penalty(img, img_discr_logits)
+                loss = discr_loss + gp

             if return_recons:
                 return loss, fmap
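The `NullVQGanVAE` added above is what lets `Decoder` drop its `exists(vae)` branches: every unet now has a VAE, and for pixel-space unets that VAE is a no-op. A dependency-free sketch of the same contract (the `NullVAE` class name here is hypothetical, and `nn.Module` inheritance is omitted for brevity):

```python
# Sketch of the null-object pattern NullVQGanVAE implements:
# encode/decode are identities and the feature-map size is unchanged,
# so the same Decoder code path works with or without a real VAE.
class NullVAE:
    def __init__(self, channels):
        self.encoded_dim = channels  # latent dim == image channels
        self.layers = 0              # no downsampling layers

    def get_encoded_fmap_size(self, size):
        return size                  # resolution passes through untouched

    def encode(self, x):
        return x

    def decode(self, x):
        return x

vae = NullVAE(channels = 3)
print(vae.get_encoded_fmap_size(512))  # 512
print(vae.decode(vae.encode('img')))   # 'img' - round trip is the identity
```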

View File

@@ -10,7 +10,7 @@ setup(
             'dream = dalle2_pytorch.cli:dream'
         ],
     },
-    version = '0.0.37',
+    version = '0.0.41',
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',