Compare commits

...

4 Commits

Author SHA1 Message Date
Phil Wang
579d4b42dd does not seem right to clip for the prior diffusion part 2022-04-24 09:51:18 -07:00
Phil Wang
473808850a some outlines to the eventual CLI endpoint 2022-04-24 09:27:15 -07:00
Phil Wang
d5318aef4f todo 2022-04-23 08:23:08 -07:00
Phil Wang
f82917e1fd prepare for turning off gradient penalty, as shown in GAN literature, GP needs to be only applied 1 out of 4 iterations 2022-04-23 07:52:10 -07:00
5 changed files with 57 additions and 19 deletions

View File

@@ -536,6 +536,7 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
## Citations
@@ -573,17 +574,6 @@ Once built, images will be saved to the same directory the command is invoked
}
```
```bibtex
@misc{zhang2019root,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
year = {2019},
eprint = {1910.07467},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},

View File

@@ -1,9 +1,51 @@
import click
import torch
import torchvision.transforms as T
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
def safeget(dictionary, keys, default = None):
    """Fetch a nested value from `dictionary` using a dot-separated key path.

    e.g. safeget(d, 'init_params.prior') == d['init_params']['prior'], returning
    `default` as soon as any intermediate value is missing or is not a dict.
    """
    # `reduce` is not imported anywhere in the visible module top — without this
    # local import the first call raises NameError
    from functools import reduce
    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
def simple_slugify(text, max_length = 255):
    """Turn free text into a filesystem-friendly slug, truncated to `max_length`.

    Hyphens and spaces become underscores, commas are dropped, and '|' becomes
    '--'; leading/trailing '-' and '_' are stripped.
    """
    translated = text.translate(str.maketrans({'-': '_', ',': '', ' ': '_', '|': '--'}))
    return translated.strip('-_')[:max_length]
def get_pkg_version():
    """Return the installed version string of the `dalle2_pytorch` distribution."""
    # imported lazily so that merely importing this CLI module stays cheap
    import pkg_resources
    return pkg_resources.get_distribution('dalle2_pytorch').version
def main():
    """Entry point stub for the eventual CLI group — intentionally a no-op for now."""
    return None
# NOTE(review): this appears to be pre-refactor diff residue — the options declare
# `--model` and `--cond_scale`, but the signature only accepts `text`, so click
# would fail when invoking this command with those options set; a second `dream`
# defined immediately below shadows this one at import time. Confirm against the
# repository's current cli.py before relying on either definition.
@click.command()
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text')
def dream(text):
    # stub: CLI generation not wired up yet
    return 'not ready yet'
def dream(
    model,
    cond_scale,
    text
):
    """Load a saved DALLE2 checkpoint, generate an image for `text`, and save it
    to the working directory as '<slug>.png'.

    `model` is the path to the saved checkpoint; `cond_scale` is the classifier
    free guidance conditioning scale forwarded to the decoder.
    """
    model_path = Path(model)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'model not found at {full_model_path}'

    checkpoint = torch.load(str(model_path))

    version = safeget(checkpoint, 'version')
    print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')

    # rebuild the prior and decoder from the init params stored alongside the weights
    prior = DiffusionPrior(**safeget(checkpoint, 'init_params.prior'))
    decoder = Decoder(**safeget(checkpoint, 'init_params.decoder'))

    dalle2 = DALLE2(prior, decoder)
    dalle2.load_state_dict(safeget(checkpoint, 'model_params'))

    # sample, convert the image tensor to PIL, and write it out
    generated = dalle2(text, cond_scale = cond_scale)
    return T.ToPILImage()(generated).save(f'./{simple_slugify(text)}.png')

View File

@@ -591,7 +591,7 @@ class DiffusionPrior(nn.Module):
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
if clip_denoised:
if clip_denoised and not self.predict_x0:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -1451,6 +1451,7 @@ class DALLE2(nn.Module):
cond_scale = 1.
):
device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text
@@ -1458,4 +1459,8 @@ class DALLE2(nn.Module):
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
if one_text:
return images[0]
return images

View File

@@ -477,7 +477,8 @@ class VQGanVAE(nn.Module):
img,
return_loss = False,
return_discr_loss = False,
return_recons = False
return_recons = False,
add_gradient_penalty = True
):
batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
@@ -502,11 +503,11 @@ class VQGanVAE(nn.Module):
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
gp = gradient_penalty(img, img_discr_logits)
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
loss = discr_loss + gp
if add_gradient_penalty:
gp = gradient_penalty(img, img_discr_logits)
loss = discr_loss + gp
if return_recons:
return loss, fmap

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.38',
version = '0.0.40',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',