mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-24 00:04:20 +01:00

Compare commits

1 commit
| Author | SHA1 | Date |
|---|---|---|
|  | a0e41267f8 |  |

README.md (12 changed lines)
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
 
 Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication with the <a href="https://laion.ai/">LAION</a> community | <a href="https://www.youtube.com/watch?v=AIOE1l1W0Tw">Yannic Interview</a>
 
-As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lucidrains/imagen-pytorch">here</a>. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
+There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
 
 ## Status
 
@@ -26,7 +26,7 @@ As of 5/23/22, it is no longer SOTA. SOTA will be <a href="https://github.com/lu
 
 ## Pre-Trained Models
 - LAION is training prior models. Checkpoints are available on <a href="https://huggingface.co/zenglishuci/conditioned-prior">🤗huggingface</a> and the training statistics are available on <a href="https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx">🐝WANDB</a>.
-- Decoder - <a href="https://wandb.ai/veldrovive/dalle2_train_decoder/runs/jkrtg0so?workspace=user-veldrovive">In-progress test run</a> 🚧
+- Decoder 🚧
 - DALL-E 2 🚧
 
 ## Install
@@ -1195,12 +1195,4 @@ This library would not have gotten to this working state without the help of
 }
 ```
 
-```bibtex
-@misc{Saharia2022,
-    title   = {Imagen: unprecedented photorealism × deep level of language understanding},
-    author  = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
-    year    = {2022}
-}
-```
-
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
dalle2_pytorch/dalle2_pytorch.py

@@ -890,7 +890,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         )
 
         if exists(clip):
-            assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
+            assert image_channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
 
             if isinstance(clip, CLIP):
                 clip = XClipAdapter(clip, **clip_adapter_overrides)
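Note on the hunk above: the change touches only the f-string inside the assert message, swapping `channels` and `image_channels`. The distinction matters because Python builds the message of an `assert` only when the condition fails, so a name that is not bound in that scope (as `channels` would be, if the parameter had been renamed to `image_channels`) turns the intended diagnostic into a `NameError` at exactly the moment it is needed. A minimal, self-contained illustration; the names here are ours, not the repository's:

```python
image_channels = 3

def check_channels(clip_image_channels):
    # the message f-string below is only evaluated if the condition is False
    assert image_channels == clip_image_channels, \
        f'channels of image ({image_channels}) should equal the channels CLIP accepts ({clip_image_channels})'

check_channels(3)    # passes; the message is never constructed
# check_channels(1)  # would raise AssertionError with a readable message
```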
@@ -1704,8 +1704,6 @@ class Decoder(BaseGaussianDiffusion):
         vb_loss_weight = 0.001,
         unconditional = False,
         auto_normalize_img = True,  # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
-        use_dynamic_thres = False,  # from the Imagen paper
-        dynamic_thres_percentile = 0.9
     ):
         super().__init__(
             beta_schedule = beta_schedule,
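For context, the two keyword arguments removed here toggle Imagen-style dynamic thresholding at sampling time. A construction sketch, assuming the `Unet` / `Decoder` shapes shown in this repository's README (the exact signature may differ at this commit):

```python
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

clip = OpenAIClipAdapter()  # downloads OpenAI CLIP weights on first use

unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    use_dynamic_thres = True,        # the flag removed in this diff
    dynamic_thres_percentile = 0.9   # its companion percentile setting
)
```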
@@ -1828,11 +1826,6 @@ class Decoder(BaseGaussianDiffusion):
         self.clip_denoised = clip_denoised
         self.clip_x_start = clip_x_start
 
-        # dynamic thresholding settings, if clipping denoised during sampling
-
-        self.use_dynamic_thres = use_dynamic_thres
-        self.dynamic_thres_percentile = dynamic_thres_percentile
-
         # normalize and unnormalize image functions
 
         self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
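The surviving context wires `normalize_img` to either a helper or the identity, depending on `auto_normalize_img`. Those helpers are one-liners defined elsewhere in the file; a sketch of what they plausibly look like, matching the names used here (assumed, not copied from this diff):

```python
def normalize_neg_one_to_one(img):
    # map an image tensor from [0, 1] to [-1, 1], the range the diffusion model works in
    return img * 2 - 1

def unnormalize_zero_to_one(normed_img):
    # inverse map from [-1, 1] back to [0, 1] for viewing or saving
    return (normed_img + 1) * 0.5

def identity(t, *args, **kwargs):
    # no-op used when the caller supplies already-normalized images
    return t
```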
@@ -1875,21 +1868,7 @@ class Decoder(BaseGaussianDiffusion):
         x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
 
         if clip_denoised:
-            # s is the threshold amount
-            # static thresholding would just be s = 1
-            s = 1.
-            if self.use_dynamic_thres:
-                s = torch.quantile(
-                    rearrange(x_recon, 'b ... -> b (...)').abs(),
-                    self.dynamic_thres_percentile,
-                    dim = -1
-                )
-
-                s.clamp_(min = 1.)
-                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
-
-            # clip by threshold, depending on whether static or dynamic
-            x_recon = x_recon.clamp(-s, s) / s
+            x_recon.clamp_(-1., 1.)
 
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
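The removed branch is the substance of this diff: dynamic thresholding, from the Imagen paper, picks the clamp bound `s` per sample as a high percentile of `|x_recon|` instead of a fixed 1, then rescales, which keeps saturated pixels from being crushed at high guidance weights. Distilled from the removed lines into a standalone function; a sketch assuming only `torch` and `einops`:

```python
import torch
from einops import rearrange

def threshold_x_start(x_recon, use_dynamic_thres = False, dynamic_thres_percentile = 0.9):
    if not use_dynamic_thres:
        # static thresholding: s = 1
        return x_recon.clamp(-1., 1.)

    # dynamic thresholding: s is the chosen percentile of |x_recon|, per sample
    s = torch.quantile(
        rearrange(x_recon, 'b ... -> b (...)').abs(),
        dynamic_thres_percentile,
        dim = -1
    )

    s.clamp_(min = 1.)                            # never clamp tighter than [-1, 1]
    s = s.view(-1, *((1,) * (x_recon.ndim - 1)))  # broadcast s over non-batch dims

    # clamp to [-s, s], then rescale back into [-1, 1]
    return x_recon.clamp(-s, s) / s

# toy check: predictions outside [-1, 1] come back in range
x = torch.randn(2, 3, 8, 8) * 2
assert threshold_x_start(x, use_dynamic_thres = True).abs().max() <= 1.
```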