Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2026-02-19 03:34:39 +01:00)

Compare commits

1 commit: 6520d17215

README.md (17 lines changed)
@@ -634,12 +634,10 @@ Alternatively, you can also use <a href="https://github.com/mlfoundations/open_c
 $ pip install open-clip-torch
 ```
 
-Ex. using the <a href="https://laion.ai/blog/large-openclip/">SOTA Open Clip</a> model trained by <a href="https://github.com/rom1504">Romain</a>
-
 ```python
 from dalle2_pytorch import OpenClipAdapter
 
-clip = OpenClipAdapter('ViT-H/14')
+clip = OpenClipAdapter()
 ```
 
 Now you'll just have to worry about training the Prior and the Decoder!
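
Note: the code change in this hunk is only whether a checkpoint name is pinned when the adapter is constructed (the sentence about the LAION-trained Open Clip model is dropped along with it). Both call forms in one minimal sketch; `dim_latent` and `image_size` are the adapter properties touched further down in this same commit:

```python
from dalle2_pytorch import OpenClipAdapter

clip = OpenClipAdapter('ViT-H/14')   # the '-' side pins the LAION ViT-H/14 checkpoint
# clip = OpenClipAdapter()           # the '+' side falls back to the adapter's default model

print(clip.dim_latent, clip.image_size)   # latent width and expected input resolution
```
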
@@ -1068,7 +1066,7 @@ dataloader = create_image_embedding_dataloader(
 )
 for img, emb in dataloader:
     print(img.shape) # torch.Size([32, 3, 256, 256])
-    print(emb["img"].shape) # torch.Size([32, 512])
+    print(emb.shape) # torch.Size([32, 512])
     # Train decoder only as shown above
 
 # Or create a dataset without a loader so you can configure it manually
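
Note: the two sides of this hunk assume different return types from the embedding dataloader, a dict keyed by "img" on one side and a bare tensor on the other. A small compatibility shim (a hypothetical helper, not part of the repo) that tolerates either convention, reusing the `dataloader` from the README snippet:

```python
def embedding_tensor(emb):
    # works with both dataloader conventions shown in the hunk above
    return emb["img"] if isinstance(emb, dict) else emb

for img, emb in dataloader:
    img_emb = embedding_tensor(emb)  # torch.Size([32, 512]) in the README example
```
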
@@ -1128,7 +1126,6 @@ For detailed information on training the diffusion prior, please refer to the [d
 - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
 - [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
 - [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
-- [ ] add simple outpainting, text-guided 2x size the image for starters
 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
 
 ## Citations
@@ -1298,14 +1295,4 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
-```bibtex
-@article{Salimans2022ProgressiveDF,
-    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
-    author  = {Tim Salimans and Jonathan Ho},
-    journal = {ArXiv},
-    year    = {2022},
-    volume  = {abs/2202.00512}
-}
-```
-
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
@@ -1,6 +1,6 @@
 from dalle2_pytorch.version import __version__
 from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
-from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
+from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
 from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
 
 from dalle2_pytorch.vqgan_vae import VQGanVAE
@@ -100,9 +100,6 @@ def eval_decorator(fn):
         return out
     return inner
 
-def is_float_dtype(dtype):
-    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
-
 def is_list_str(x):
     if not isinstance(x, (list, tuple)):
         return False
@@ -317,10 +314,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
         self.eos_id = 49407 # for handling 0 being also '!'
 
         text_attention_final = self.find_layer('ln_final')
-
-        self.dim_latent_ = text_attention_final.weight.shape[0]
         self.handle = text_attention_final.register_forward_hook(self._hook)
 
         self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False
-
@@ -339,7 +333,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
 
     @property
     def dim_latent(self):
-        return self.dim_latent_
+        return 512
 
     @property
     def image_size(self):
@@ -389,8 +383,6 @@ class OpenClipAdapter(BaseClipAdapter):
         self.eos_id = 49407
 
         text_attention_final = self.find_layer('ln_final')
-        self._dim_latent = text_attention_final.weight.shape[0]
-
         self.handle = text_attention_final.register_forward_hook(self._hook)
         self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False
@@ -410,14 +402,11 @@ class OpenClipAdapter(BaseClipAdapter):
 
     @property
     def dim_latent(self):
-        return self._dim_latent
+        return 512
 
     @property
     def image_size(self):
-        image_size = self.clip.visual.image_size
-        if isinstance(image_size, tuple):
-            return max(image_size)
-        return image_size
+        return self.clip.visual.image_size
 
     @property
     def image_channels(self):
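
Note: in both adapters above, the '-' side reads the text latent width off the model's final LayerNorm instead of the hardcoded `return 512` on the '+' side, which matters for anything larger than the base CLIP models. A hedged sketch of where that number comes from, assuming open_clip's `create_model_and_transforms` API; 'ViT-H-14' is only an example and its text width is larger than 512:

```python
import open_clip

# random weights are fine here, we only inspect the text tower's width
model, *_ = open_clip.create_model_and_transforms('ViT-H-14')
print(model.ln_final.weight.shape[0])  # e.g. 1024 for ViT-H-14, which a hardcoded 512 would misreport
```
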
@@ -619,7 +608,7 @@ class NoiseScheduler(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def q_sample(self, x_start, t, noise = None):
+    def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         return (
@@ -627,12 +616,6 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
-    def calculate_v(self, x_start, t, noise = None):
-        return (
-            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
-            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
-        )
-
     def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
         shape = x_from.shape
         noise = default(noise, lambda: torch.randn_like(x_from))
@@ -644,12 +627,6 @@ class NoiseScheduler(nn.Module):
 
         return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
 
-    def predict_start_from_v(self, x_t, t, v):
-        return (
-            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
-            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
-        )
-
     def predict_start_from_noise(self, x_t, t, noise):
         return (
             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
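
Note: the two deleted helpers are the standard v-parameterization relations from the Progressive Distillation paper cited (and also removed) in the README hunk above: v is built from x0 and the noise, and x0 is recovered from x_t and v. A standalone sketch of the same algebra, with the per-timestep sqrt(alpha_bar) factors passed in directly instead of gathered via `extract`:

```python
import torch

def calculate_v(sqrt_alpha_bar, sqrt_one_minus_alpha_bar, x_start, noise):
    # v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0
    return sqrt_alpha_bar * noise - sqrt_one_minus_alpha_bar * x_start

def predict_start_from_v(sqrt_alpha_bar, sqrt_one_minus_alpha_bar, x_t, v):
    # x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v
    return sqrt_alpha_bar * x_t - sqrt_one_minus_alpha_bar * v

# consistency check against q_sample: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps
a, x0, eps = torch.rand(()), torch.randn(8), torch.randn(8)
sa, sb = a.sqrt(), (1 - a).sqrt()
x_t = sa * x0 + sb * eps
v = calculate_v(sa, sb, x0, eps)
assert torch.allclose(predict_start_from_v(sa, sb, x_t, v), x0, atol = 1e-6)
```
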
@@ -985,8 +962,6 @@ class DiffusionPriorNetwork(nn.Module):
             Rearrange('b (n d) -> b n d', n = num_text_embeds)
         )
 
-        self.continuous_embedded_time = not exists(num_timesteps)
-
         self.to_time_embeds = nn.Sequential(
             nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
             Rearrange('b (n d) -> b n d', n = num_time_embeds)
@@ -1095,7 +1070,7 @@ class DiffusionPriorNetwork(nn.Module):
 
         null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
 
-        text_embed = torch.where(
+        text_embeds = torch.where(
             text_keep_mask,
             text_embed,
             null_text_embeds
@@ -1114,9 +1089,6 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-        if self.continuous_embedded_time:
-            diffusion_timesteps = diffusion_timesteps.type(dtype)
-
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -1158,7 +1130,6 @@ class DiffusionPrior(nn.Module):
         image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
-        predict_v = False,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1210,7 +1181,6 @@ class DiffusionPrior(nn.Module):
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
         self.predict_x_start = predict_x_start
-        self.predict_v = predict_v # takes precedence over predict_x_start
 
         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
 
@@ -1240,9 +1210,7 @@ class DiffusionPrior(nn.Module):
 
         pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
-        if self.predict_v:
-            x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
-        elif self.predict_x_start:
+        if self.predict_x_start:
             x_start = pred
         else:
             x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1291,7 +1259,7 @@ class DiffusionPrior(nn.Module):
     def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
         batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
 
-        times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
 
         times = list(reversed(times.int().tolist()))
         time_pairs = list(zip(times[:-1], times[1:]))
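
Note: only the construction of the DDIM time grid changes in the hunk above. A quick illustration of what each variant yields before the reverse-and-pair step that follows in the function (example values with total_timesteps = 1000 and timesteps = 4):

```python
import torch

total_timesteps, timesteps = 1000, 4

minus_side = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
plus_side  = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]

print(minus_side.int().tolist())  # [-1, 249, 499, 749]     -> 3 (time, time_next) pairs, ending at -1
print(plus_side.int().tolist())   # [0, 200, 400, 600, 800] -> 4 pairs, ending at 0
```
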
@@ -1313,16 +1281,12 @@ class DiffusionPrior(nn.Module):
 
             pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
 
-            # derive x0
-
-            if self.predict_v:
-                x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
-            elif self.predict_x_start:
+            if self.predict_x_start:
                 x_start = pred
+                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
             else:
                 x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
+                pred_noise = pred
 
-            # clip x0 before maybe predicting noise
-
             if not self.predict_x_start:
                 x_start.clamp_(-1., 1.)
@@ -1330,17 +1294,6 @@ class DiffusionPrior(nn.Module):
             if self.predict_x_start and self.sampling_clamp_l2norm:
                 x_start = self.l2norm_clamp_embed(x_start)
 
-            # predict noise
-
-            if self.predict_x_start or self.predict_v:
-                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
-            else:
-                pred_noise = pred
-
-            if time_next < 0:
-                image_embed = x_start
-                continue
-
             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
             noise = torch.randn_like(image_embed) if time_next > 0 else 0.
@@ -1390,12 +1343,7 @@ class DiffusionPrior(nn.Module):
         if self.predict_x_start and self.training_clamp_l2norm:
             pred = self.l2norm_clamp_embed(pred)
 
-        if self.predict_v:
-            target = self.noise_scheduler.calculate_v(image_embed, times, noise)
-        elif self.predict_x_start:
-            target = image_embed
-        else:
-            target = noise
+        target = noise if not self.predict_x_start else image_embed
 
         loss = self.noise_scheduler.loss_fn(pred, target)
         return loss
@@ -1465,7 +1413,7 @@ class DiffusionPrior(nn.Module):
         **kwargs
     ):
         assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
-        assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
+        assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
         assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
 
         if exists(image):
@@ -1571,8 +1519,6 @@ class SinusoidalPosEmb(nn.Module):
 
     def forward(self, x):
         dtype, device = x.dtype, x.device
-        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
-
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
@@ -2471,7 +2417,6 @@ class Decoder(nn.Module):
         loss_type = 'l2',
         beta_schedule = None,
         predict_x_start = False,
-        predict_v = False,
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2644,10 +2589,6 @@ class Decoder(nn.Module):
 
         self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
 
-        # predict v
-
-        self.predict_v = cast_tuple(predict_v, len(unets))
-
         # input image range
 
         self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
@@ -2759,16 +2700,14 @@ class Decoder(nn.Module):
         x = x.clamp(-s, s) / s
         return x
 
-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
         model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
 
         pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
 
-        if predict_v:
-            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
-        elif predict_x_start:
+        if predict_x_start:
             x_start = pred
         else:
             x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2795,9 +2734,9 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2812,7 +2751,6 @@ class Decoder(nn.Module):
         image_embed,
         noise_scheduler,
         predict_x_start = False,
-        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2871,7 +2809,6 @@ class Decoder(nn.Module):
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_level = lowres_noise_level,
                 predict_x_start = predict_x_start,
-                predict_v = predict_v,
                 noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
                 clip_denoised = clip_denoised
@@ -2897,7 +2834,6 @@ class Decoder(nn.Module):
         timesteps,
         eta = 1.,
         predict_x_start = False,
-        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2957,27 +2893,16 @@ class Decoder(nn.Module):
 
             pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
-            # predict x0
-
-            if predict_v:
-                x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
-            elif predict_x_start:
+            if predict_x_start:
                 x_start = pred
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
             else:
                 x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+                pred_noise = pred
 
-            # maybe clip x0
-
             if clip_denoised:
                 x_start = self.dynamic_threshold(x_start)
 
-            # predict noise
-
-            if predict_x_start or predict_v:
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
-            else:
-                pred_noise = pred
-
             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
             noise = torch.randn_like(img) if not is_last_timestep else 0.
@@ -3010,7 +2935,7 @@ class Decoder(nn.Module):
 
         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
 
-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
@@ -3055,12 +2980,7 @@ class Decoder(nn.Module):
 
         pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
-        if predict_v:
-            target = noise_scheduler.calculate_v(x_start, times, noise)
-        elif predict_x_start:
-            target = x_start
-        else:
-            target = noise
+        target = noise if not predict_x_start else x_start
 
         loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
         loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -3146,7 +3066,7 @@ class Decoder(nn.Module):
         num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
             if unet_number < start_at_unet_number:
                 continue # It's the easiest way to do it
 
@@ -3182,7 +3102,6 @@ class Decoder(nn.Module):
                     text_encodings = text_encodings,
                     cond_scale = unet_cond_scale,
                     predict_x_start = predict_x_start,
-                    predict_v = predict_v,
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
@@ -3222,7 +3141,6 @@ class Decoder(nn.Module):
         lowres_conditioner = self.lowres_conds[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
-        predict_v = self.predict_v[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
         learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device
@@ -3261,7 +3179,7 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
 
         if not return_lowres_cond_image:
             return losses
@@ -4,13 +4,11 @@ from pydantic import BaseModel, validator, root_validator
 from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
 
 from x_clip import CLIP as XCLIP
-from open_clip import list_pretrained
 from coca_pytorch import CoCa
 
 from dalle2_pytorch.dalle2_pytorch import (
     CoCaAdapter,
     OpenAIClipAdapter,
-    OpenClipAdapter,
     Unet,
     Decoder,
     DiffusionPrior,
@@ -119,10 +117,6 @@ class AdapterConfig(BaseModel):
     def create(self):
         if self.make == "openai":
            return OpenAIClipAdapter(self.model)
-        elif self.make == "open_clip":
-            pretrained = dict(list_pretrained())
-            checkpoint = pretrained[self.model]
-            return OpenClipAdapter(name=self.model, pretrained=checkpoint)
         elif self.make == "x-clip":
             return XClipAdapter(XCLIP(**self.base_model_kwargs))
         elif self.make == "coca":
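
Note: the deleted branch resolves an open_clip checkpoint from a model name before constructing the adapter. A rough standalone sketch of the same lookup, assuming open_clip's `list_pretrained()` returns `(architecture, pretrained_tag)` pairs as the deleted code implies; 'ViT-B-32' is only an example name:

```python
from open_clip import list_pretrained
from dalle2_pytorch.dalle2_pytorch import OpenClipAdapter

pretrained = dict(list_pretrained())   # {architecture: pretrained tag, ...}, later tags win per architecture
checkpoint = pretrained['ViT-B-32']
clip = OpenClipAdapter(name = 'ViT-B-32', pretrained = checkpoint)
```
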
@@ -313,7 +307,6 @@ class DecoderTrainConfig(BaseModel):
     wd: SingularOrIterable[float] = 0.01
     warmup_steps: Optional[SingularOrIterable[int]] = None
     find_unused_parameters: bool = True
-    static_graph: bool = True
     max_grad_norm: SingularOrIterable[float] = 0.5
     save_every_n_samples: int = 100000
     n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
@@ -236,7 +236,7 @@ class DiffusionPriorTrainer(nn.Module):
         )
 
         if exists(cosine_decay_max_steps):
-            self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+            self.scheduler = CosineAnnealingLR(optimizer, T_max = cosine_decay_max_steps)
         else:
             self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
 
@@ -1 +1 @@
-__version__ = '1.11.2'
+__version__ = '1.10.2'
setup.py (1 line changed)

@@ -26,7 +26,6 @@ setup(
   install_requires=[
     'accelerate',
     'click',
-    'open-clip-torch>=2.0.0,<3.0.0',
     'clip-anytorch>=2.4.0',
     'coca-pytorch>=0.0.5',
     'ema-pytorch>=0.0.7',
@@ -156,7 +156,7 @@ def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=No
     if text_embeddings[0] is None:
         # Generate text embeddings from text
         assert clip is not None, "clip is None, but text_embeddings is None"
-        tokenized_texts = tokenize(txts, truncate=True).to(device=device)
+        tokenized_texts = tokenize(txts, truncate=True)
         text_embed, text_encodings = clip.embed_text(tokenized_texts)
         sample_params["text_encodings"] = text_encodings
     else:
@@ -229,8 +229,8 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=Non
         metrics["KID_std"] = kid_std.item()
     if exists(LPIPS):
         # Convert from [0, 1] to [-1, 1]
-        renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
-        renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
+        renorm_real_images = real_images.mul(2).sub(1)
+        renorm_generated_images = generated_images.mul(2).sub(1)
         lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
         lpips.to(device=device)
         lpips.update(renorm_real_images, renorm_generated_images)
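
Note: the only difference here is a final clamp after rescaling images from [0, 1] to [-1, 1]. Torchmetrics' LPIPS expects inputs in [-1, 1] by default, so the clamp guards against values that drift slightly out of range, for example after interpolation or mixed-precision rounding. A minimal sketch of the rescaling with that guard:

```python
def to_lpips_range(images):
    # map [0, 1] images into [-1, 1]; the clamp is exactly what this hunk adds or drops
    return images.mul(2).sub(1).clamp(-1, 1)
```
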
@@ -480,7 +480,7 @@ def train(
             else:
                 # Then we need to pass the text instead
                 assert clip is not None
-                tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
+                tokenized_texts = tokenize(txt, truncate=True)
                 assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
                 text_embed, text_encodings = clip.embed_text(tokenized_texts)
                 forward_params['text_encodings'] = text_encodings
@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     torch.manual_seed(config.seed)
 
     # Set up accelerator for configurable distributed training
-    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
+    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
     init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
 