mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-21 00:34:19 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 579d4b42dd | |
| | 473808850a | |
| | d5318aef4f | |
| | f82917e1fd | |
README.md (12 changed lines: 1 addition, 11 deletions)
````diff
@@ -536,6 +536,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
 - [ ] bring in tools to train vqgan-vae
 - [ ] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
+- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
 
 ## Citations
 
````
````diff
@@ -573,17 +574,6 @@ Once built, images will be saved to the same directory the command is invoked
 }
 ```
 
-```bibtex
-@misc{zhang2019root,
-    title   = {Root Mean Square Layer Normalization},
-    author  = {Biao Zhang and Rico Sennrich},
-    year    = {2019},
-    eprint  = {1910.07467},
-    archivePrefix = {arXiv},
-    primaryClass = {cs.LG}
-}
-```
-
 ```bibtex
 @inproceedings{Tu2022MaxViTMV,
     title   = {MaxViT: Multi-Axis Vision Transformer},
````
dalle2_pytorch/cli.py

```diff
@@ -1,9 +1,51 @@
 import click
+import torch
+import torchvision.transforms as T
+from pathlib import Path
 
+from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
+
+def safeget(dictionary, keys, default = None):
+    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
+
+def simple_slugify(text, max_length = 255):
+    return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
+
+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('dalle2_pytorch').version
+
 def main():
     pass
 
 @click.command()
+@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
+@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
 @click.argument('text')
-def dream(text):
-    return 'not ready yet'
+def dream(
+    model,
+    cond_scale,
+    text
+):
+    model_path = Path(model)
+    full_model_path = str(model_path.resolve())
+    assert model_path.exists(), f'model not found at {full_model_path}'
+    loaded = torch.load(str(model_path))
+
+    version = safeget(loaded, 'version')
+    print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
+
+    prior_init_params = safeget(loaded, 'init_params.prior')
+    decoder_init_params = safeget(loaded, 'init_params.decoder')
+    model_params = safeget(loaded, 'model_params')
+
+    prior = DiffusionPrior(**prior_init_params)
+    decoder = Decoder(**decoder_init_params)
+
+    dalle2 = DALLE2(prior, decoder)
+    dalle2.load_state_dict(model_params)
+
+    image = dalle2(text, cond_scale = cond_scale)
+
+    pil_image = T.ToPILImage()(image)
+    return pil_image.save(f'./{simple_slugify(text)}.png')
```

Note that `safeget` walks a nested dict by dotted key path (e.g. `'init_params.prior'`) via `reduce`, but the module never imports it; a `from functools import reduce` at the top is needed for the command to actually run.
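As a quick sanity check of the new command's surface, a hypothetical invocation through click's test runner is sketched below; the module path `dalle2_pytorch.cli`, the missing-`reduce` fix, and the presence of a trained `./dalle2.pt` checkpoint are all assumptions not established by the diff itself.

```python
# Hypothetical smoke test for the reworked `dream` command via click's CliRunner.
from click.testing import CliRunner

from dalle2_pytorch.cli import dream  # assumed module path

runner = CliRunner()
result = runner.invoke(dream, ['--model', './dalle2.pt', '--cond_scale', '2', 'a corgi surfing a wave'])

print(result.exit_code)  # 0 on success; the saved PNG name comes from simple_slugify(text)
print(result.output)     # the version line printed while loading the checkpoint
```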
dalle2_pytorch/dalle2_pytorch.py

```diff
@@ -591,7 +591,7 @@ class DiffusionPrior(nn.Module):
         else:
             x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
 
-        if clip_denoised:
+        if clip_denoised and not self.predict_x0:
             x_recon.clamp_(-1., 1.)
 
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
```
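The new guard follows from the two prediction objectives: when the network predicts noise, `x_recon` is reconstructed from `x_t` and, for pixel data, clamped to the data range; when the prior predicts `x0` directly (here a CLIP image embedding, which is not bounded in [-1, 1]), clamping would distort the prediction. A minimal sketch of the standard DDPM reconstruction identity, with argument names assumed rather than taken from the library:

```python
import torch

# x0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
# (the identity behind predict_start_from_noise; a sketch, not the repo's exact code)
def predict_start_from_noise(x_t, alpha_bar_t, noise):
    return torch.rsqrt(alpha_bar_t) * x_t - torch.sqrt(1. / alpha_bar_t - 1.) * noise
```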
```diff
@@ -1451,6 +1451,7 @@ class DALLE2(nn.Module):
         cond_scale = 1.
     ):
         device = next(self.parameters()).device
+        one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
 
         if isinstance(text, str) or is_list_str(text):
             text = [text] if not isinstance(text, (list, tuple)) else text
@@ -1458,4 +1459,8 @@ class DALLE2(nn.Module):
 
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
         images = self.decoder.sample(image_embed, cond_scale = cond_scale)
+
+        if one_text:
+            return images[0]
+
         return images
```
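With `one_text`, the output shape now follows the input: a single caption yields a single image tensor rather than a batch of one. A hypothetical usage, with `prior` and `decoder` construction elided and assumed to follow the repository README:

```python
# assumes `prior` (DiffusionPrior) and `decoder` (Decoder) are already built and trained
dalle2 = DALLE2(prior = prior, decoder = decoder)

image  = dalle2('cosmic love and attention', cond_scale = 2.)                     # str in -> one (c, h, w) tensor
images = dalle2(['glistening morning dew', 'blue sunflowers'], cond_scale = 2.)   # list in -> (b, c, h, w) batch
```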
dalle2_pytorch/vqgan_vae.py

```diff
@@ -477,7 +477,8 @@ class VQGanVAE(nn.Module):
         img,
         return_loss = False,
         return_discr_loss = False,
-        return_recons = False
+        return_recons = False,
+        add_gradient_penalty = True
     ):
         batch, channels, height, width, device = *img.shape, img.device
         assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
@@ -502,11 +503,11 @@ class VQGanVAE(nn.Module):
 
             fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
 
-            gp = gradient_penalty(img, img_discr_logits)
-
             discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
 
-            loss = discr_loss + gp
+            if add_gradient_penalty:
+                gp = gradient_penalty(img, img_discr_logits)
+                loss = discr_loss + gp
 
             if return_recons:
                 return loss, fmap
```

Two small issues remain in the code as committed: the assert message interpolates `{self.image_size}` without an `f` prefix, so the placeholder prints literally, and when `add_gradient_penalty` is False the penalty branch never assigns `loss`, so the subsequent return raises `UnboundLocalError`; an `else: loss = discr_loss` branch would make the new flag safe to use.
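For context on what the now-optional term computes: the repository's `gradient_penalty` helper is not shown in this diff, but a typical WGAN-GP-style penalty on real inputs has the shape sketched below, with the signature inferred from the call site `gradient_penalty(img, img_discr_logits)`; details may differ from the actual helper. It relies on `img.requires_grad_()` having been called before the discriminator forward pass.

```python
import torch
from torch.autograd import grad as torch_grad

# Sketch of a WGAN-GP-style gradient penalty (assumed implementation, not the repo's).
def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]

    # d(discriminator output) / d(input images), kept differentiable so the
    # penalty itself can be backpropagated through
    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones_like(output),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    gradients = gradients.reshape(batch_size, -1)
    # penalize deviation of the per-sample gradient norm from 1
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
```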