mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 16:55:39 +01:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6647050c33 | ||
|
|
b4c3e5b854 | ||
|
|
b7f9607258 | ||
|
|
2219348a6e | ||
|
|
9eea9b9862 | ||
|
|
5d958713c0 | ||
|
|
0f31980362 | ||
|
|
bee5bf3815 | ||
|
|
350a3d6045 | ||
|
|
1a81670718 | ||
|
|
934c9728dc | ||
|
|
ce4b0107c1 | ||
|
|
64c2f9c4eb | ||
|
|
22cc613278 | ||
|
|
83517849e5 | ||
|
|
708809ed6c | ||
|
|
9cc475f6e7 | ||
|
|
ffd342e9d0 | ||
|
|
f8bfd3493a | ||
|
|
9025345e29 | ||
|
|
8cc278447e | ||
|
|
38cd62010c | ||
|
|
1cc288af39 | ||
|
|
a851168633 | ||
|
|
1ffeecd0ca | ||
|
|
3df899f7a4 | ||
|
|
09534119a1 | ||
|
|
6f8b90d4d7 |
14
README.md
14
README.md
@@ -943,7 +943,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
|||||||
|
|
||||||
# Create a dataloader directly.
|
# Create a dataloader directly.
|
||||||
dataloader = create_image_embedding_dataloader(
|
dataloader = create_image_embedding_dataloader(
|
||||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@@ -1097,7 +1097,7 @@ This library would not have gotten to this working state without the help of
|
|||||||
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
|
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
|
||||||
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
|
||||||
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
|
||||||
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
- [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
|
||||||
- [ ] decoder needs one day worth of refactor for tech debt
|
- [ ] decoder needs one day worth of refactor for tech debt
|
||||||
- [ ] allow for unet to be able to condition non-cross attention style as well
|
- [ ] allow for unet to be able to condition non-cross attention style as well
|
||||||
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89
|
||||||
@@ -1207,4 +1207,14 @@ This library would not have gotten to this working state without the help of
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{Choi2022PerceptionPT,
|
||||||
|
title = {Perception Prioritized Training of Diffusion Models},
|
||||||
|
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
|
||||||
|
journal = {ArXiv},
|
||||||
|
year = {2022},
|
||||||
|
volume = {abs/2204.00227}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
*Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.
|
|||||||
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
|
||||||
| Option | Required | Default | Description |
|
| Option | Required | Default | Description |
|
||||||
| ------ | -------- | ------- | ----------- |
|
| ------ | -------- | ------- | ----------- |
|
||||||
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
|
||||||
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
|
||||||
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
|
||||||
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
|
import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from inspect import isfunction
|
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
|||||||
from torch import nn, einsum
|
from torch import nn, einsum
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat, reduce
|
||||||
from einops.layers.torch import Rearrange
|
from einops.layers.torch import Rearrange
|
||||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||||
from einops_exts.torch import EinopsToAndFrom
|
from einops_exts.torch import EinopsToAndFrom
|
||||||
@@ -56,7 +56,7 @@ def maybe(fn):
|
|||||||
def default(val, d):
|
def default(val, d):
|
||||||
if exists(val):
|
if exists(val):
|
||||||
return val
|
return val
|
||||||
return d() if isfunction(d) else d
|
return d() if callable(d) else d
|
||||||
|
|
||||||
def cast_tuple(val, length = 1):
|
def cast_tuple(val, length = 1):
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
@@ -313,11 +313,6 @@ def extract(a, t, x_shape):
|
|||||||
out = a.gather(-1, t)
|
out = a.gather(-1, t)
|
||||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||||
|
|
||||||
def noise_like(shape, device, repeat=False):
|
|
||||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
|
||||||
noise = lambda: torch.randn(shape, device=device)
|
|
||||||
return repeat_noise() if repeat else noise()
|
|
||||||
|
|
||||||
def meanflat(x):
|
def meanflat(x):
|
||||||
return x.mean(dim = tuple(range(1, len(x.shape))))
|
return x.mean(dim = tuple(range(1, len(x.shape))))
|
||||||
|
|
||||||
@@ -372,7 +367,7 @@ def quadratic_beta_schedule(timesteps):
|
|||||||
scale = 1000 / timesteps
|
scale = 1000 / timesteps
|
||||||
beta_start = scale * 0.0001
|
beta_start = scale * 0.0001
|
||||||
beta_end = scale * 0.02
|
beta_end = scale * 0.02
|
||||||
return torch.linspace(beta_start**2, beta_end**2, timesteps, dtype = torch.float64) ** 2
|
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
|
||||||
|
|
||||||
|
|
||||||
def sigmoid_beta_schedule(timesteps):
|
def sigmoid_beta_schedule(timesteps):
|
||||||
@@ -384,7 +379,7 @@ def sigmoid_beta_schedule(timesteps):
|
|||||||
|
|
||||||
|
|
||||||
class BaseGaussianDiffusion(nn.Module):
|
class BaseGaussianDiffusion(nn.Module):
|
||||||
def __init__(self, *, beta_schedule, timesteps, loss_type):
|
def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if beta_schedule == "cosine":
|
if beta_schedule == "cosine":
|
||||||
@@ -449,6 +444,11 @@ class BaseGaussianDiffusion(nn.Module):
|
|||||||
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||||
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||||
|
|
||||||
|
# p2 loss reweighting
|
||||||
|
|
||||||
|
self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
|
||||||
|
register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
|
||||||
|
|
||||||
def q_posterior(self, x_start, x_t, t):
|
def q_posterior(self, x_start, x_t, t):
|
||||||
posterior_mean = (
|
posterior_mean = (
|
||||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||||
@@ -945,10 +945,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
|
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = torch.randn_like(x)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
@@ -1084,8 +1084,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
def Upsample(dim):
|
def Upsample(dim):
|
||||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||||
|
|
||||||
def Downsample(dim):
|
def Downsample(dim, *, dim_out = None):
|
||||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
dim_out = default(dim_out, dim)
|
||||||
|
return nn.Conv2d(dim, dim_out, 4, 2, 1)
|
||||||
|
|
||||||
class SinusoidalPosEmb(nn.Module):
|
class SinusoidalPosEmb(nn.Module):
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
@@ -1351,6 +1352,7 @@ class Unet(nn.Module):
|
|||||||
init_cross_embed_kernel_sizes = (3, 7, 15),
|
init_cross_embed_kernel_sizes = (3, 7, 15),
|
||||||
cross_embed_downsample = False,
|
cross_embed_downsample = False,
|
||||||
cross_embed_downsample_kernel_sizes = (2, 4),
|
cross_embed_downsample_kernel_sizes = (2, 4),
|
||||||
|
memory_efficient = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1370,7 +1372,7 @@ class Unet(nn.Module):
|
|||||||
self.channels_out = default(channels_out, channels)
|
self.channels_out = default(channels_out, channels)
|
||||||
|
|
||||||
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
|
||||||
init_dim = default(init_dim, dim // 3 * 2)
|
init_dim = default(init_dim, dim)
|
||||||
|
|
||||||
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1)
|
||||||
|
|
||||||
@@ -1427,6 +1429,7 @@ class Unet(nn.Module):
|
|||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
|
|
||||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||||
|
self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
|
||||||
|
|
||||||
self.max_text_len = max_text_len
|
self.max_text_len = max_text_len
|
||||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||||
@@ -1460,10 +1463,11 @@ class Unet(nn.Module):
|
|||||||
layer_cond_dim = cond_dim if not is_first else None
|
layer_cond_dim = cond_dim if not is_first else None
|
||||||
|
|
||||||
self.downs.append(nn.ModuleList([
|
self.downs.append(nn.ModuleList([
|
||||||
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
|
||||||
|
ResnetBlock(dim_out if memory_efficient else dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
|
||||||
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||||
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
|
||||||
downsample_klass(dim_out) if not is_last else nn.Identity()
|
downsample_klass(dim_out) if not is_last and not memory_efficient else None
|
||||||
]))
|
]))
|
||||||
|
|
||||||
mid_dim = dims[-1]
|
mid_dim = dims[-1]
|
||||||
@@ -1472,7 +1476,9 @@ class Unet(nn.Module):
|
|||||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
|
||||||
|
|
||||||
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
up_in_out_slice = slice(1 if not memory_efficient else None, None)
|
||||||
|
|
||||||
|
for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks) in enumerate(zip(reversed(in_out[up_in_out_slice]), reversed(resnet_groups), reversed(num_resnet_blocks))):
|
||||||
is_last = ind >= (num_resolutions - 2)
|
is_last = ind >= (num_resolutions - 2)
|
||||||
layer_cond_dim = cond_dim if not is_last else None
|
layer_cond_dim = cond_dim if not is_last else None
|
||||||
|
|
||||||
@@ -1483,8 +1489,10 @@ class Unet(nn.Module):
|
|||||||
Upsample(dim_in)
|
Upsample(dim_in)
|
||||||
]))
|
]))
|
||||||
|
|
||||||
|
final_dim_in = dim * (1 if memory_efficient else 2)
|
||||||
|
|
||||||
self.final_conv = nn.Sequential(
|
self.final_conv = nn.Sequential(
|
||||||
ResnetBlock(dim, dim, groups = resnet_groups[0]),
|
ResnetBlock(final_dim_in, dim, groups = resnet_groups[0]),
|
||||||
nn.Conv2d(dim, self.channels_out, 1)
|
nn.Conv2d(dim, self.channels_out, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1564,19 +1572,28 @@ class Unet(nn.Module):
|
|||||||
time_tokens = self.to_time_tokens(time_hiddens)
|
time_tokens = self.to_time_tokens(time_hiddens)
|
||||||
t = self.to_time_cond(time_hiddens)
|
t = self.to_time_cond(time_hiddens)
|
||||||
|
|
||||||
# image embedding to be summed to time embedding
|
|
||||||
# discovered by @mhh0318 in the paper
|
|
||||||
|
|
||||||
if exists(image_embed) and exists(self.to_image_hiddens):
|
|
||||||
image_hiddens = self.to_image_hiddens(image_embed)
|
|
||||||
t = t + image_hiddens
|
|
||||||
|
|
||||||
# conditional dropout
|
# conditional dropout
|
||||||
|
|
||||||
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
|
||||||
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
|
||||||
|
|
||||||
image_keep_mask, text_keep_mask = rearrange_many((image_keep_mask, text_keep_mask), 'b -> b 1 1')
|
text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
|
||||||
|
|
||||||
|
# image embedding to be summed to time embedding
|
||||||
|
# discovered by @mhh0318 in the paper
|
||||||
|
|
||||||
|
if exists(image_embed) and exists(self.to_image_hiddens):
|
||||||
|
image_hiddens = self.to_image_hiddens(image_embed)
|
||||||
|
image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
|
||||||
|
null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
|
||||||
|
|
||||||
|
image_hiddens = torch.where(
|
||||||
|
image_keep_mask_hidden,
|
||||||
|
image_hiddens,
|
||||||
|
null_image_hiddens
|
||||||
|
)
|
||||||
|
|
||||||
|
t = t + image_hiddens
|
||||||
|
|
||||||
# mask out image embedding depending on condition dropout
|
# mask out image embedding depending on condition dropout
|
||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
@@ -1584,11 +1601,12 @@ class Unet(nn.Module):
|
|||||||
image_tokens = None
|
image_tokens = None
|
||||||
|
|
||||||
if self.cond_on_image_embeds:
|
if self.cond_on_image_embeds:
|
||||||
|
image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
|
||||||
image_tokens = self.image_to_tokens(image_embed)
|
image_tokens = self.image_to_tokens(image_embed)
|
||||||
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
|
||||||
|
|
||||||
image_tokens = torch.where(
|
image_tokens = torch.where(
|
||||||
image_keep_mask,
|
image_keep_mask_embed,
|
||||||
image_tokens,
|
image_tokens,
|
||||||
null_image_embed
|
null_image_embed
|
||||||
)
|
)
|
||||||
@@ -1643,7 +1661,10 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
hiddens = []
|
hiddens = []
|
||||||
|
|
||||||
for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
|
for pre_downsample, init_block, sparse_attn, resnet_blocks, post_downsample in self.downs:
|
||||||
|
if exists(pre_downsample):
|
||||||
|
x = pre_downsample(x)
|
||||||
|
|
||||||
x = init_block(x, c, t)
|
x = init_block(x, c, t)
|
||||||
x = sparse_attn(x)
|
x = sparse_attn(x)
|
||||||
|
|
||||||
@@ -1651,7 +1672,9 @@ class Unet(nn.Module):
|
|||||||
x = resnet_block(x, c, t)
|
x = resnet_block(x, c, t)
|
||||||
|
|
||||||
hiddens.append(x)
|
hiddens.append(x)
|
||||||
x = downsample(x)
|
|
||||||
|
if exists(post_downsample):
|
||||||
|
x = post_downsample(x)
|
||||||
|
|
||||||
x = self.mid_block1(x, mid_c, t)
|
x = self.mid_block1(x, mid_c, t)
|
||||||
|
|
||||||
@@ -1661,7 +1684,7 @@ class Unet(nn.Module):
|
|||||||
x = self.mid_block2(x, mid_c, t)
|
x = self.mid_block2(x, mid_c, t)
|
||||||
|
|
||||||
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
|
||||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||||
x = init_block(x, c, t)
|
x = init_block(x, c, t)
|
||||||
x = sparse_attn(x)
|
x = sparse_attn(x)
|
||||||
|
|
||||||
@@ -1670,13 +1693,16 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
x = upsample(x)
|
x = upsample(x)
|
||||||
|
|
||||||
|
if len(hiddens) > 0:
|
||||||
|
x = torch.cat((x, hiddens.pop()), dim = 1)
|
||||||
|
|
||||||
return self.final_conv(x)
|
return self.final_conv(x)
|
||||||
|
|
||||||
class LowresConditioner(nn.Module):
|
class LowresConditioner(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
downsample_first = True,
|
downsample_first = True,
|
||||||
blur_sigma = 0.1,
|
blur_sigma = (0.1, 0.2),
|
||||||
blur_kernel_size = 3,
|
blur_kernel_size = 3,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -1700,6 +1726,18 @@ class LowresConditioner(nn.Module):
|
|||||||
# when training, blur the low resolution conditional image
|
# when training, blur the low resolution conditional image
|
||||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||||
|
|
||||||
|
# allow for drawing a random sigma between lo and hi float values
|
||||||
|
if isinstance(blur_sigma, tuple):
|
||||||
|
blur_sigma = tuple(map(float, blur_sigma))
|
||||||
|
blur_sigma = random.uniform(*blur_sigma)
|
||||||
|
|
||||||
|
# allow for drawing a random kernel size between lo and hi int values
|
||||||
|
if isinstance(blur_kernel_size, tuple):
|
||||||
|
blur_kernel_size = tuple(map(int, blur_kernel_size))
|
||||||
|
kernel_size_lo, kernel_size_hi = blur_kernel_size
|
||||||
|
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
|
||||||
|
|
||||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||||
|
|
||||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||||
@@ -1725,23 +1763,28 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
|
||||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||||
clip_denoised = True,
|
clip_denoised = True,
|
||||||
clip_x_start = True,
|
clip_x_start = True,
|
||||||
clip_adapter_overrides = dict(),
|
clip_adapter_overrides = dict(),
|
||||||
learned_variance = True,
|
learned_variance = True,
|
||||||
|
learned_variance_constrain_frac = False,
|
||||||
vb_loss_weight = 0.001,
|
vb_loss_weight = 0.001,
|
||||||
unconditional = False,
|
unconditional = False,
|
||||||
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
|
||||||
use_dynamic_thres = False, # from the Imagen paper
|
use_dynamic_thres = False, # from the Imagen paper
|
||||||
dynamic_thres_percentile = 0.9
|
dynamic_thres_percentile = 0.9,
|
||||||
|
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
|
||||||
|
p2_loss_weight_k = 1
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
timesteps = timesteps,
|
timesteps = timesteps,
|
||||||
loss_type = loss_type
|
loss_type = loss_type,
|
||||||
|
p2_loss_weight_gamma = p2_loss_weight_gamma,
|
||||||
|
p2_loss_weight_k = p2_loss_weight_k
|
||||||
)
|
)
|
||||||
|
|
||||||
self.unconditional = unconditional
|
self.unconditional = unconditional
|
||||||
@@ -1792,6 +1835,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
|
||||||
self.learned_variance = learned_variance
|
self.learned_variance = learned_variance
|
||||||
|
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
|
||||||
self.vb_loss_weight = vb_loss_weight
|
self.vb_loss_weight = vb_loss_weight
|
||||||
|
|
||||||
# construct unets and vaes
|
# construct unets and vaes
|
||||||
@@ -1932,16 +1976,19 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
max_log = extract(torch.log(self.betas), t, x.shape)
|
max_log = extract(torch.log(self.betas), t, x.shape)
|
||||||
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
|
||||||
|
|
||||||
|
if self.learned_variance_constrain_frac:
|
||||||
|
var_interp_frac = var_interp_frac.sigmoid()
|
||||||
|
|
||||||
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
|
||||||
posterior_variance = posterior_log_variance.exp()
|
posterior_variance = posterior_log_variance.exp()
|
||||||
|
|
||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
|
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = torch.randn_like(x)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
@@ -2005,7 +2052,13 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
target = noise if not predict_x_start else x_start
|
target = noise if not predict_x_start else x_start
|
||||||
|
|
||||||
loss = self.loss_fn(pred, target)
|
loss = self.loss_fn(pred, target, reduction = 'none')
|
||||||
|
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
||||||
|
|
||||||
|
if self.has_p2_loss_reweighting:
|
||||||
|
loss = loss * extract(self.p2_loss_weight, times, loss.shape)
|
||||||
|
|
||||||
|
loss = loss.mean()
|
||||||
|
|
||||||
if not learned_variance:
|
if not learned_variance:
|
||||||
# return simple loss if not using learned variance
|
# return simple loss if not using learned variance
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
|
|||||||
|
|
||||||
# Create a dataloader directly.
|
# Create a dataloader directly.
|
||||||
dataloader = create_image_embedding_dataloader(
|
dataloader = create_image_embedding_dataloader(
|
||||||
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
|
||||||
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
from torch.optim import AdamW, Adam
|
from torch.optim import AdamW, Adam
|
||||||
|
|
||||||
def separate_weight_decayable_params(params):
|
def separate_weight_decayable_params(params):
|
||||||
no_wd_params = set([param for param in params if param.ndim < 2])
|
wd_params, no_wd_params = [], []
|
||||||
wd_params = set(params) - no_wd_params
|
for param in params:
|
||||||
|
param_list = no_wd_params if param.ndim < 2 else wd_params
|
||||||
|
param_list.append(param)
|
||||||
return wd_params, no_wd_params
|
return wd_params, no_wd_params
|
||||||
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
params,
|
params,
|
||||||
lr = 1e-4,
|
lr = 1e-4,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
betas = (0.9, 0.999),
|
betas = (0.9, 0.99),
|
||||||
eps = 1e-8,
|
eps = 1e-8,
|
||||||
filter_by_requires_grad = False,
|
filter_by_requires_grad = False,
|
||||||
group_wd_params = True,
|
group_wd_params = True,
|
||||||
@@ -25,8 +27,8 @@ def get_optimizer(
|
|||||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||||
|
|
||||||
params = [
|
params = [
|
||||||
{'params': list(wd_params)},
|
{'params': wd_params},
|
||||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
{'params': no_wd_params, 'weight_decay': 0},
|
||||||
]
|
]
|
||||||
|
|
||||||
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# to give users a quick easy start to training DALL-E without doing BPE
|
# to give users a quick easy start to training DALL-E without doing BPE
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import youtokentome as yttm
|
|
||||||
|
|
||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
@@ -11,6 +10,8 @@ import regex as re
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dalle2_pytorch.utils import import_or_print_error
|
||||||
|
|
||||||
# OpenAI simple tokenizer
|
# OpenAI simple tokenizer
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
@@ -156,7 +157,9 @@ class YttmTokenizer:
|
|||||||
bpe_path = Path(bpe_path)
|
bpe_path = Path(bpe_path)
|
||||||
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
|
||||||
|
|
||||||
tokenizer = yttm.BPE(model = str(bpe_path))
|
self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
|
||||||
|
|
||||||
|
tokenizer = self.yttm.BPE(model = str(bpe_path))
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.vocab_size = tokenizer.vocab_size()
|
self.vocab_size = tokenizer.vocab_size()
|
||||||
|
|
||||||
@@ -167,7 +170,7 @@ class YttmTokenizer:
|
|||||||
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
|
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
|
||||||
|
|
||||||
def encode(self, texts):
|
def encode(self, texts):
|
||||||
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID)
|
encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
|
||||||
return list(map(torch.tensor, encoded))
|
return list(map(torch.tensor, encoded))
|
||||||
|
|
||||||
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
def tokenize(self, texts, context_length = 256, truncate_text = False):
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from itertools import zip_longest
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from dalle2_pytorch.utils import import_or_print_error
|
||||||
|
|
||||||
# constants
|
# constants
|
||||||
|
|
||||||
DEFAULT_DATA_PATH = './.tracker-data'
|
DEFAULT_DATA_PATH = './.tracker-data'
|
||||||
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
|
|||||||
def exists(val):
|
def exists(val):
|
||||||
return val is not None
|
return val is not None
|
||||||
|
|
||||||
def import_or_print_error(pkg_name, err_str = None):
|
|
||||||
try:
|
|
||||||
return importlib.import_module(pkg_name)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
if exists(err_str):
|
|
||||||
print(err_str)
|
|
||||||
exit()
|
|
||||||
|
|
||||||
# load state dict functions
|
# load state dict functions
|
||||||
|
|
||||||
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
def load_wandb_state_dict(run_path, file_path, **kwargs):
|
||||||
|
|||||||
@@ -58,8 +58,15 @@ def num_to_groups(num, divisor):
|
|||||||
arr.append(remainder)
|
arr.append(remainder)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
def get_pkg_version():
|
def clamp(value, min_value = None, max_value = None):
|
||||||
return __version__
|
assert exists(min_value) or exists(max_value)
|
||||||
|
if exists(min_value):
|
||||||
|
value = max(value, min_value)
|
||||||
|
|
||||||
|
if exists(max_value):
|
||||||
|
value = min(value, max_value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
# decorators
|
# decorators
|
||||||
|
|
||||||
@@ -175,12 +182,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
|
|||||||
# exponential moving average wrapper
|
# exponential moving average wrapper
|
||||||
|
|
||||||
class EMA(nn.Module):
|
class EMA(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements exponential moving average shadowing for your model.
|
||||||
|
|
||||||
|
Utilizes an inverse decay schedule to manage longer term training runs.
|
||||||
|
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
|
||||||
|
|
||||||
|
@crowsonkb's notes on EMA Warmup:
|
||||||
|
|
||||||
|
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
|
||||||
|
good values for models you plan to train for a million or more steps (reaches decay
|
||||||
|
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
|
||||||
|
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
|
||||||
|
215.4k steps).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||||
|
power (float): Exponential factor of EMA warmup. Default: 1.
|
||||||
|
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.9999,
|
beta = 0.9999,
|
||||||
update_after_step = 1000,
|
update_after_step = 10000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
|
inv_gamma = 1.0,
|
||||||
|
power = 2/3,
|
||||||
|
min_value = 0.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
@@ -188,7 +217,11 @@ class EMA(nn.Module):
|
|||||||
self.ema_model = copy.deepcopy(model)
|
self.ema_model = copy.deepcopy(model)
|
||||||
|
|
||||||
self.update_every = update_every
|
self.update_every = update_every
|
||||||
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0
|
self.update_after_step = update_after_step
|
||||||
|
|
||||||
|
self.inv_gamma = inv_gamma
|
||||||
|
self.power = power
|
||||||
|
self.min_value = min_value
|
||||||
|
|
||||||
self.register_buffer('initted', torch.Tensor([False]))
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
self.register_buffer('step', torch.tensor([0]))
|
self.register_buffer('step', torch.tensor([0]))
|
||||||
@@ -198,37 +231,51 @@ class EMA(nn.Module):
|
|||||||
self.ema_model.to(device)
|
self.ema_model.to(device)
|
||||||
|
|
||||||
def copy_params_from_model_to_ema(self):
|
def copy_params_from_model_to_ema(self):
|
||||||
self.ema_model.state_dict(self.online_model.state_dict())
|
for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
|
||||||
|
ma_param.data.copy_(current_param.data)
|
||||||
|
|
||||||
|
for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
|
||||||
|
ma_buffer.data.copy_(current_buffer.data)
|
||||||
|
|
||||||
|
def get_current_decay(self):
|
||||||
|
epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
|
||||||
|
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
|
||||||
|
|
||||||
|
if epoch <= 0:
|
||||||
|
return 0.
|
||||||
|
|
||||||
|
return clamp(value, min_value = self.min_value, max_value = self.beta)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
|
step = self.step.item()
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
if (self.step % self.update_every) != 0:
|
if (step % self.update_every) != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.step <= self.update_after_step:
|
if step <= self.update_after_step:
|
||||||
self.copy_params_from_model_to_ema()
|
self.copy_params_from_model_to_ema()
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.initted:
|
if not self.initted.item():
|
||||||
self.copy_params_from_model_to_ema()
|
self.copy_params_from_model_to_ema()
|
||||||
self.initted.data.copy_(torch.Tensor([True]))
|
self.initted.data.copy_(torch.Tensor([True]))
|
||||||
|
|
||||||
self.update_moving_average(self.ema_model, self.online_model)
|
self.update_moving_average(self.ema_model, self.online_model)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def update_moving_average(self, ma_model, current_model):
|
def update_moving_average(self, ma_model, current_model):
|
||||||
def calculate_ema(beta, old, new):
|
current_decay = self.get_current_decay()
|
||||||
if not exists(old):
|
|
||||||
return new
|
|
||||||
return old * beta + (1 - beta) * new
|
|
||||||
|
|
||||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
|
||||||
old_weight, up_weight = ma_params.data, current_params.data
|
difference = ma_params.data - current_params.data
|
||||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
difference.mul_(1.0 - current_decay)
|
||||||
|
ma_params.sub_(difference)
|
||||||
|
|
||||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
|
||||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
difference = ma_buffer - current_buffer
|
||||||
ma_buffer.copy_(new_buffer_value)
|
difference.mul_(1.0 - current_decay)
|
||||||
|
ma_buffer.sub_(difference)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.ema_model(*args, **kwargs)
|
return self.ema_model(*args, **kwargs)
|
||||||
@@ -488,7 +535,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
loaded_obj = torch.load(str(path))
|
loaded_obj = torch.load(str(path))
|
||||||
|
|
||||||
if version.parse(__version__) != loaded_obj['version']:
|
if version.parse(__version__) != loaded_obj['version']:
|
||||||
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
|
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
|
||||||
|
|
||||||
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
self.decoder.load_state_dict(loaded_obj['model'], strict = strict)
|
||||||
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
|
||||||
|
|||||||
@@ -17,3 +17,13 @@ class Timer:
|
|||||||
def print_ribbon(s, symbol = '=', repeat = 40):
|
def print_ribbon(s, symbol = '=', repeat = 40):
|
||||||
flank = symbol * repeat
|
flank = symbol * repeat
|
||||||
return f'{flank} {s} {flank}'
|
return f'{flank} {s} {flank}'
|
||||||
|
|
||||||
|
# import helpers
|
||||||
|
|
||||||
|
def import_or_print_error(pkg_name, err_str = None):
|
||||||
|
try:
|
||||||
|
return importlib.import_module(pkg_name)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
if exists(err_str):
|
||||||
|
print(err_str)
|
||||||
|
exit()
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.6.2'
|
__version__ = '0.9.0'
|
||||||
|
|||||||
@@ -68,8 +68,8 @@ def group_dict_by_key(cond, d):
|
|||||||
return_val[ind][key] = d[key]
|
return_val[ind][key] = d[key]
|
||||||
return (*return_val,)
|
return (*return_val,)
|
||||||
|
|
||||||
def string_begins_with(prefix, str):
|
def string_begins_with(prefix, string_input):
|
||||||
return str.startswith(prefix)
|
return string_input.startswith(prefix)
|
||||||
|
|
||||||
def group_by_key_prefix(prefix, d):
|
def group_by_key_prefix(prefix, d):
|
||||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -32,6 +32,7 @@ setup(
|
|||||||
'embedding-reader',
|
'embedding-reader',
|
||||||
'kornia>=0.5.4',
|
'kornia>=0.5.4',
|
||||||
'numpy',
|
'numpy',
|
||||||
|
'packaging',
|
||||||
'pillow',
|
'pillow',
|
||||||
'pydantic',
|
'pydantic',
|
||||||
'resize-right>=0.0.2',
|
'resize-right>=0.0.2',
|
||||||
@@ -41,7 +42,6 @@ setup(
|
|||||||
'tqdm',
|
'tqdm',
|
||||||
'vector-quantize-pytorch',
|
'vector-quantize-pytorch',
|
||||||
'x-clip>=0.4.4',
|
'x-clip>=0.4.4',
|
||||||
'youtokentome',
|
|
||||||
'webdataset>=0.2.5',
|
'webdataset>=0.2.5',
|
||||||
'fsspec>=2022.1.0',
|
'fsspec>=2022.1.0',
|
||||||
'torchmetrics[image]>=0.8.0'
|
'torchmetrics[image]>=0.8.0'
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
|||||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
from dalle2_pytorch.utils import Timer, print_ribbon
|
from dalle2_pytorch.utils import Timer, print_ribbon
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import resize_image_to
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torch
|
import torch
|
||||||
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
|
|||||||
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
|
||||||
"""
|
"""
|
||||||
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
|
||||||
|
|
||||||
|
real_image_size = real_images[0].shape[-1]
|
||||||
|
generated_image_size = generated_images[0].shape[-1]
|
||||||
|
|
||||||
|
# training images may be larger than the generated one
|
||||||
|
if real_image_size > generated_image_size:
|
||||||
|
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
|
||||||
|
|
||||||
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
|
||||||
return grid_images, captions
|
return grid_images, captions
|
||||||
|
|
||||||
@@ -202,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
|
|||||||
Loads the model with an appropriate method depending on the tracker
|
Loads the model with an appropriate method depending on the tracker
|
||||||
"""
|
"""
|
||||||
print(print_ribbon(f"Loading model from {recall_source}"))
|
print(print_ribbon(f"Loading model from {recall_source}"))
|
||||||
state_dict = tracker.recall_state_dict(recall_source, **load_config)
|
state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
|
||||||
trainer.load_state_dict(state_dict["trainer"])
|
trainer.load_state_dict(state_dict["trainer"])
|
||||||
print("Model loaded")
|
print("Model loaded")
|
||||||
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
|
||||||
@@ -322,7 +331,7 @@ def train(
|
|||||||
sample = 0
|
sample = 0
|
||||||
average_loss = 0
|
average_loss = 0
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
for i, (img, emb, *_) in enumerate(dataloaders["val"]):
|
||||||
sample += img.shape[0]
|
sample += img.shape[0]
|
||||||
img, emb = send_to_device((img, emb))
|
img, emb = send_to_device((img, emb))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user