Compare commits

...

21 Commits

Author SHA1 Message Date
Phil Wang
ce4b0107c1 0.6.13 2022-06-04 13:26:57 -07:00
zion
64c2f9c4eb implement ema warmup from @crowsonkb (#140) 2022-06-04 13:26:34 -07:00
Phil Wang
22cc613278 ema fix from @nousr 2022-06-03 19:44:36 -07:00
zion
83517849e5 ema module fixes (#139) 2022-06-03 19:43:51 -07:00
Phil Wang
708809ed6c lower beta2 for adam down to 0.99, based on https://openreview.net/forum?id=2LdBqxc1Yv 2022-06-03 10:26:28 -07:00
Phil Wang
9cc475f6e7 fix update_every within EMA 2022-06-03 10:21:05 -07:00
Phil Wang
ffd342e9d0 allow for an option to constrain the variance interpolation fraction coming out from the unet for learned variance, if it is turned on 2022-06-03 09:34:57 -07:00
Phil Wang
f8bfd3493a make destructuring datum length agnostic when validating in training decoder script, for @YUHANG-Ma 2022-06-02 13:54:57 -07:00
Phil Wang
9025345e29 take a stab at fixing generate_grid_samples when real images have a greater image size than generated 2022-06-02 11:33:15 -07:00
Phil Wang
8cc278447e just cast to right types for blur sigma and kernel size augs 2022-06-02 11:21:58 -07:00
Phil Wang
38cd62010c allow for random blur sigma and kernel size augmentations on low res conditioning (need to reread paper to see if the augmentation value needs to be fed into the unet for conditioning as well) 2022-06-02 11:11:25 -07:00
Ryan Russell
1cc288af39 Improve Readability (#133)
Signed-off-by: Ryan Russell <git@ryanrussell.org>
2022-06-01 13:28:02 -07:00
Phil Wang
a851168633 make youtokentome optional package, due to reported installation difficulties 2022-06-01 09:25:35 -07:00
Phil Wang
1ffeecd0ca lower default ema beta value 2022-05-31 11:55:21 -07:00
Phil Wang
3df899f7a4 patch 2022-05-31 09:03:43 -07:00
Aidan Dempster
09534119a1 Fixed non deterministic optimizer creation (#130) 2022-05-31 09:03:20 -07:00
Phil Wang
6f8b90d4d7 add packaging package 2022-05-30 11:45:00 -07:00
Phil Wang
b588286288 fix version 2022-05-30 11:06:34 -07:00
Phil Wang
b693e0be03 default number of resnet blocks per layer in unet to 2 (in imagen it was 3 for base 64x64) 2022-05-30 10:06:48 -07:00
Phil Wang
a0bed30a84 additional conditioning on image embedding by summing to time embeddings (for FiLM like conditioning in subsequent layers), from passage found in paper by @mhh0318 2022-05-30 09:26:51 -07:00
zion
387c5bf774 quick patch for new prior loader (#123) 2022-05-29 16:25:53 -07:00
14 changed files with 177 additions and 71 deletions

View File

@@ -943,7 +943,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
# Create a dataloader directly. # Create a dataloader directly.
dataloader = create_image_embedding_dataloader( dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4, num_workers=4,
batch_size=32, batch_size=32,
@@ -1097,7 +1097,7 @@ This library would not have gotten to this working state without the help of
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training - [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt - [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well - [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89

View File

@@ -83,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked. Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
| Option | Required | Default | Description | | Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- | | ------ | -------- | ------- | ----------- |
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. | | `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. | `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric. | `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. | | `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |

View File

@@ -1,3 +1,4 @@
from dalle2_pytorch.version import __version__
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer

View File

@@ -1,4 +1,5 @@
import math import math
import random
from tqdm import tqdm from tqdm import tqdm
from inspect import isfunction from inspect import isfunction
from functools import partial, wraps from functools import partial, wraps
@@ -1343,10 +1344,11 @@ class Unet(nn.Module):
cond_on_text_encodings = False, cond_on_text_encodings = False,
max_text_len = 256, max_text_len = 256,
cond_on_image_embeds = False, cond_on_image_embeds = False,
add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
init_dim = None, init_dim = None,
init_conv_kernel_size = 7, init_conv_kernel_size = 7,
resnet_groups = 8, resnet_groups = 8,
num_resnet_blocks = 1, num_resnet_blocks = 2,
init_cross_embed_kernel_sizes = (3, 7, 15), init_cross_embed_kernel_sizes = (3, 7, 15),
cross_embed_downsample = False, cross_embed_downsample = False,
cross_embed_downsample_kernel_sizes = (2, 4), cross_embed_downsample_kernel_sizes = (2, 4),
@@ -1396,11 +1398,16 @@ class Unet(nn.Module):
nn.Linear(time_cond_dim, time_cond_dim) nn.Linear(time_cond_dim, time_cond_dim)
) )
self.image_to_cond = nn.Sequential( self.image_to_tokens = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens), nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens) Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity() ) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
self.to_image_hiddens = nn.Sequential(
nn.Linear(image_embed_dim, time_cond_dim),
nn.GELU()
) if cond_on_image_embeds and add_image_embeds_to_time else None
self.norm_cond = nn.LayerNorm(cond_dim) self.norm_cond = nn.LayerNorm(cond_dim)
self.norm_mid_cond = nn.LayerNorm(cond_dim) self.norm_mid_cond = nn.LayerNorm(cond_dim)
@@ -1558,6 +1565,13 @@ class Unet(nn.Module):
time_tokens = self.to_time_tokens(time_hiddens) time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens) t = self.to_time_cond(time_hiddens)
# image embedding to be summed to time embedding
# discovered by @mhh0318 in the paper
if exists(image_embed) and exists(self.to_image_hiddens):
image_hiddens = self.to_image_hiddens(image_embed)
t = t + image_hiddens
# conditional dropout # conditional dropout
image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device) image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
@@ -1571,7 +1585,7 @@ class Unet(nn.Module):
image_tokens = None image_tokens = None
if self.cond_on_image_embeds: if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed) image_tokens = self.image_to_tokens(image_embed)
null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
image_tokens = torch.where( image_tokens = torch.where(
@@ -1663,7 +1677,7 @@ class LowresConditioner(nn.Module):
def __init__( def __init__(
self, self,
downsample_first = True, downsample_first = True,
blur_sigma = 0.1, blur_sigma = (0.1, 0.2),
blur_kernel_size = 3, blur_kernel_size = 3,
): ):
super().__init__() super().__init__()
@@ -1687,6 +1701,18 @@ class LowresConditioner(nn.Module):
# when training, blur the low resolution conditional image # when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma) blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size) blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
# allow for drawing a random sigma between lo and hi float values
if isinstance(blur_sigma, tuple):
blur_sigma = tuple(map(float, blur_sigma))
blur_sigma = random.uniform(*blur_sigma)
# allow for drawing a random kernel size between lo and hi int values
if isinstance(blur_kernel_size, tuple):
blur_kernel_size = tuple(map(int, blur_kernel_size))
kernel_size_lo, kernel_size_hi = blur_kernel_size
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2)) cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size) cond_fmap = resize_image_to(cond_fmap, target_image_size)
@@ -1712,13 +1738,14 @@ class Decoder(BaseGaussianDiffusion):
image_sizes = None, # for cascading ddpm, image size at each stage image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True, clip_denoised = True,
clip_x_start = True, clip_x_start = True,
clip_adapter_overrides = dict(), clip_adapter_overrides = dict(),
learned_variance = True, learned_variance = True,
learned_variance_constrain_frac = False,
vb_loss_weight = 0.001, vb_loss_weight = 0.001,
unconditional = False, unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
@@ -1779,6 +1806,7 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False) learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance self.learned_variance = learned_variance
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
self.vb_loss_weight = vb_loss_weight self.vb_loss_weight = vb_loss_weight
# construct unets and vaes # construct unets and vaes
@@ -1919,6 +1947,9 @@ class Decoder(BaseGaussianDiffusion):
max_log = extract(torch.log(self.betas), t, x.shape) max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized) var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
var_interp_frac = var_interp_frac.sigmoid()
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp() posterior_variance = posterior_log_variance.exp()

View File

@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
# Create a dataloader directly. # Create a dataloader directly.
dataloader = create_image_embedding_dataloader( dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4, num_workers=4,
batch_size=32, batch_size=32,

View File

@@ -1,15 +1,17 @@
from torch.optim import AdamW, Adam from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params): def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2]) wd_params, no_wd_params = [], []
wd_params = set(params) - no_wd_params for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params return wd_params, no_wd_params
def get_optimizer( def get_optimizer(
params, params,
lr = 1e-4, lr = 1e-4,
wd = 1e-2, wd = 1e-2,
betas = (0.9, 0.999), betas = (0.9, 0.99),
eps = 1e-8, eps = 1e-8,
filter_by_requires_grad = False, filter_by_requires_grad = False,
group_wd_params = True, group_wd_params = True,
@@ -25,8 +27,8 @@ def get_optimizer(
wd_params, no_wd_params = separate_weight_decayable_params(params) wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [ params = [
{'params': list(wd_params)}, {'params': wd_params},
{'params': list(no_wd_params), 'weight_decay': 0}, {'params': no_wd_params, 'weight_decay': 0},
] ]
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -2,7 +2,6 @@
# to give users a quick easy start to training DALL-E without doing BPE # to give users a quick easy start to training DALL-E without doing BPE
import torch import torch
import youtokentome as yttm
import html import html
import os import os
@@ -11,6 +10,8 @@ import regex as re
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from dalle2_pytorch.utils import import_or_print_error
# OpenAI simple tokenizer # OpenAI simple tokenizer
@lru_cache() @lru_cache()
@@ -156,7 +157,9 @@ class YttmTokenizer:
bpe_path = Path(bpe_path) bpe_path = Path(bpe_path)
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist' assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
tokenizer = yttm.BPE(model = str(bpe_path)) self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
tokenizer = self.yttm.BPE(model = str(bpe_path))
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.vocab_size = tokenizer.vocab_size() self.vocab_size = tokenizer.vocab_size()
@@ -167,7 +170,7 @@ class YttmTokenizer:
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0})) return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
def encode(self, texts): def encode(self, texts):
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID) encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
return list(map(torch.tensor, encoded)) return list(map(torch.tensor, encoded))
def tokenize(self, texts, context_length = 256, truncate_text = False): def tokenize(self, texts, context_length = 256, truncate_text = False):

View File

@@ -6,6 +6,8 @@ from itertools import zip_longest
import torch import torch
from torch import nn from torch import nn
from dalle2_pytorch.utils import import_or_print_error
# constants # constants
DEFAULT_DATA_PATH = './.tracker-data' DEFAULT_DATA_PATH = './.tracker-data'
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val): def exists(val):
return val is not None return val is not None
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()
# load state dict functions # load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs): def load_wandb_state_dict(run_path, file_path, **kwargs):

View File

@@ -11,6 +11,8 @@ from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer from dalle2_pytorch.optimizer import get_optimizer
from dalle2_pytorch.version import __version__
from packaging import version
import numpy as np import numpy as np
@@ -57,8 +59,7 @@ def num_to_groups(num, divisor):
return arr return arr
def get_pkg_version(): def get_pkg_version():
from pkg_resources import get_distribution return __version__
return get_distribution('dalle2_pytorch').version
# decorators # decorators
@@ -174,12 +175,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
# exponential moving average wrapper # exponential moving average wrapper
class EMA(nn.Module): class EMA(nn.Module):
"""
Implements exponential moving average shadowing for your model.
Utilizes an inverse decay schedule to manage longer term training runs.
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
def __init__( def __init__(
self, self,
model, model,
beta = 0.9999, beta = 0.9999,
update_after_step = 1000, update_after_step = 10000,
update_every = 10, update_every = 10,
inv_gamma = 1.0,
power = 2/3,
min_value = 0.0,
): ):
super().__init__() super().__init__()
self.beta = beta self.beta = beta
@@ -187,7 +210,11 @@ class EMA(nn.Module):
self.ema_model = copy.deepcopy(model) self.ema_model = copy.deepcopy(model)
self.update_every = update_every self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0 self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.register_buffer('initted', torch.Tensor([False])) self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0])) self.register_buffer('step', torch.tensor([0]))
@@ -197,37 +224,44 @@ class EMA(nn.Module):
self.ema_model.to(device) self.ema_model.to(device)
def copy_params_from_model_to_ema(self): def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict()) for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
ma_param.data.copy_(current_param.data)
def get_current_decay(self):
    """
    Return the EMA decay factor to apply at the current step.

    The warmup position is `step - update_after_step - 1`, floored at zero.
    The decay follows the inverse schedule
    `1 - (1 + epoch / inv_gamma) ** -power`, bounded below by `min_value`
    and above by `beta`.

    Returns:
        The decay factor for this update.
    """
    # epoch is floored at 0 here, so no separate negative-epoch case is needed
    epoch = max(0, self.step.item() - self.update_after_step - 1)
    value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
    return min(self.beta, max(self.min_value, value))
def update(self): def update(self):
step = self.step.item()
self.step += 1 self.step += 1
if (self.step % self.update_every) != 0: if (step % self.update_every) != 0:
return return
if self.step <= self.update_after_step: if step <= self.update_after_step:
self.copy_params_from_model_to_ema() self.copy_params_from_model_to_ema()
return return
if not self.initted: if not self.initted.item():
self.copy_params_from_model_to_ema() self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True])) self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model) self.update_moving_average(self.ema_model, self.online_model)
@torch.no_grad()
def update_moving_average(self, ma_model, current_model): def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new): current_decay = self.get_current_decay()
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
old_weight, up_weight = ma_params.data, current_params.data difference = ma_params.data - current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight) difference.mul_(1.0 - current_decay)
ma_params.sub_(difference)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()): for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer) difference = ma_buffer - current_buffer
ma_buffer.copy_(new_buffer_value) difference.mul_(1.0 - current_decay)
ma_buffer.sub_(difference)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs) return self.ema_model(*args, **kwargs)
@@ -299,7 +333,7 @@ class DiffusionPriorTrainer(nn.Module):
scaler = self.scaler.state_dict(), scaler = self.scaler.state_dict(),
optimizer = self.optimizer.state_dict(), optimizer = self.optimizer.state_dict(),
model = self.diffusion_prior.state_dict(), model = self.diffusion_prior.state_dict(),
version = get_pkg_version(), version = __version__,
step = self.step.item(), step = self.step.item(),
**kwargs **kwargs
) )
@@ -315,8 +349,8 @@ class DiffusionPriorTrainer(nn.Module):
loaded_obj = torch.load(str(path)) loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']: if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {get_pkg_version()}') print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict) self.diffusion_prior.load_state_dict(loaded_obj['model'], strict = strict)
self.step.copy_(torch.ones_like(self.step) * loaded_obj['step']) self.step.copy_(torch.ones_like(self.step) * loaded_obj['step'])
@@ -463,7 +497,7 @@ class DecoderTrainer(nn.Module):
save_obj = dict( save_obj = dict(
model = self.decoder.state_dict(), model = self.decoder.state_dict(),
version = get_pkg_version(), version = __version__,
step = self.step.item(), step = self.step.item(),
**kwargs **kwargs
) )
@@ -486,7 +520,7 @@ class DecoderTrainer(nn.Module):
loaded_obj = torch.load(str(path)) loaded_obj = torch.load(str(path))
if get_pkg_version() != loaded_obj['version']: if version.parse(__version__) != loaded_obj['version']:
print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}') print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {get_pkg_version()}')
self.decoder.load_state_dict(loaded_obj['model'], strict = strict) self.decoder.load_state_dict(loaded_obj['model'], strict = strict)

View File

@@ -17,3 +17,13 @@ class Timer:
def print_ribbon(s, symbol = '=', repeat = 40): def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat flank = symbol * repeat
return f'{flank} {s} {flank}' return f'{flank} {s} {flank}'
# import helpers
def import_or_print_error(pkg_name, err_str = None):
    """
    Import and return the module named `pkg_name`, exiting if it is missing.

    Args:
        pkg_name: module name passed to `importlib.import_module`.
        err_str: optional message printed before exiting when the module
            cannot be found.

    Returns:
        The imported module.

    Raises:
        SystemExit: with status 1 when the module is not installed.
    """
    # local import keeps the helper independent of module-level imports
    import importlib

    try:
        return importlib.import_module(pkg_name)
    except ModuleNotFoundError:
        if err_str is not None:
            print(err_str)
        # non-zero status so shells / CI treat the missing dependency as a failure
        raise SystemExit(1) from None

View File

@@ -0,0 +1 @@
__version__ = '0.6.13'

View File

@@ -1,4 +1,5 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
exec(open('dalle2_pytorch/version.py').read())
setup( setup(
name = 'dalle2-pytorch', name = 'dalle2-pytorch',
@@ -10,7 +11,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.5.7', version = __version__,
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',
@@ -31,6 +32,7 @@ setup(
'embedding-reader', 'embedding-reader',
'kornia>=0.5.4', 'kornia>=0.5.4',
'numpy', 'numpy',
'packaging',
'pillow', 'pillow',
'pydantic', 'pydantic',
'resize-right>=0.0.2', 'resize-right>=0.0.2',
@@ -40,7 +42,6 @@ setup(
'tqdm', 'tqdm',
'vector-quantize-pytorch', 'vector-quantize-pytorch',
'x-clip>=0.4.4', 'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5', 'webdataset>=0.2.5',
'fsspec>=2022.1.0', 'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0' 'torchmetrics[image]>=0.8.0'

View File

@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
import torchvision import torchvision
import torch import torch
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
Generates samples and uses torchvision to put them in a side by side grid for easy viewing Generates samples and uses torchvision to put them in a side by side grid for easy viewing
""" """
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend) real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)] grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions return grid_images, captions
@@ -322,7 +331,7 @@ def train(
sample = 0 sample = 0
average_loss = 0 average_loss = 0
timer = Timer() timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]): for i, (img, emb, *_) in enumerate(dataloaders["val"]):
sample += img.shape[0] sample += img.shape[0]
img, emb = send_to_device((img, emb)) img, emb = send_to_device((img, emb))

View File

@@ -7,15 +7,13 @@ import torch
import clip import clip
from torch import nn from torch import nn
from dalle2_pytorch.dataloaders import make_splits from dalle2_pytorch.dataloaders import make_splits, get_reader
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer, print_ribbon from dalle2_pytorch.utils import Timer, print_ribbon
from embedding_reader import EmbeddingReader
from tqdm import tqdm from tqdm import tqdm
# constants # constants
@@ -31,7 +29,7 @@ def exists(val):
# functions # functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"): def eval_model(model, dataloader, text_conditioned, loss_type, device, phase="Validation",):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@@ -39,6 +37,8 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
total_samples = 0. total_samples = 0.
for image_embeddings, text_data in tqdm(dataloader): for image_embeddings, text_data in tqdm(dataloader):
image_embeddings = image_embeddings.to(device)
text_data = text_data.to(device)
batches = image_embeddings.shape[0] batches = image_embeddings.shape[0]
@@ -57,12 +57,14 @@ def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation
tracker.log({f'{phase} {loss_type}': avg_loss}) tracker.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior, dataloader, text_conditioned): def report_cosine_sims(diffusion_prior, dataloader, text_conditioned, device):
diffusion_prior.eval() diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6) cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for test_image_embeddings, text_data in tqdm(dataloader): for test_image_embeddings, text_data in tqdm(dataloader):
test_image_embeddings = test_image_embeddings.to(device)
text_data = text_data.to(device)
# we are text conditioned, we produce an embedding from the tokenized text # we are text conditioned, we produce an embedding from the tokenized text
if text_conditioned: if text_conditioned:
@@ -240,7 +242,7 @@ def train(
# Training loop # Training loop
# diffusion prior network # diffusion prior network
prior_network = DiffusionPriorNetwork( prior_network = DiffusionPriorNetwork(
dim = image_embed_dim, dim = image_embed_dim,
depth = dpn_depth, depth = dpn_depth,
dim_head = dpn_dim_head, dim_head = dpn_dim_head,
@@ -249,16 +251,16 @@ def train(
ff_dropout = dropout, ff_dropout = dropout,
normformer = dp_normformer normformer = dp_normformer
) )
# Load clip model if text-conditioning # Load clip model if text-conditioning
if dp_condition_on_text_encodings: if dp_condition_on_text_encodings:
clip_adapter = OpenAIClipAdapter(clip) clip_adapter = OpenAIClipAdapter(clip)
else: else:
clip_adapter = None clip_adapter = None
# diffusion prior with text embeddings and image embeddings pre-computed # diffusion prior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior( diffusion_prior = DiffusionPrior(
net = prior_network, net = prior_network,
clip = clip_adapter, clip = clip_adapter,
image_embed_dim = image_embed_dim, image_embed_dim = image_embed_dim,
@@ -296,28 +298,46 @@ def train(
# Utilize wrapper to abstract away loader logic # Utilize wrapper to abstract away loader logic
print_ribbon("Downloading Embeddings") print_ribbon("Downloading Embeddings")
loader_args = dict(text_conditioned=dp_condition_on_text_encodings, batch_size=batch_size, num_data_points=num_data_points, reader_args = dict(text_conditioned=dp_condition_on_text_encodings, img_url=image_embed_url)
train_split=train_percent, eval_split=val_percent, device=device, img_url=image_embed_url)
if dp_condition_on_text_encodings: if dp_condition_on_text_encodings:
loader_args = dict(**loader_args, meta_url=meta_url) reader_args = dict(**reader_args, meta_url=meta_url)
img_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader
)
else: else:
loader_args = dict(**loader_args, txt_url=text_embed_url) reader_args = dict(**reader_args, txt_url=text_embed_url)
img_reader, txt_reader = get_reader(**reader_args)
train_loader, eval_loader, test_loader = make_splits(**loader_args) train_loader, eval_loader, test_loader = make_splits(
text_conditioned=dp_condition_on_text_encodings,
batch_size=batch_size,
num_data_points=num_data_points,
train_split=train_percent,
eval_split=val_percent,
image_reader=img_reader,
text_reader=txt_reader
)
### Training code ### ### Training code ###
step = 1 step = 1
timer = Timer() timer = Timer()
epochs = num_epochs epochs = num_epochs
for _ in range(epochs): for _ in range(epochs):
for image, text in tqdm(train_loader): for image, text in tqdm(train_loader):
diffusion_prior.train() diffusion_prior.train()
image = image.to(device)
text = text.to(device)
input_args = dict(image_embed=image) input_args = dict(image_embed=image)
if dp_condition_on_text_encodings: if dp_condition_on_text_encodings:
input_args = dict(**input_args, text = text) input_args = dict(**input_args, text = text)
@@ -350,9 +370,9 @@ def train(
# Use NUM_TEST_EMBEDDINGS samples from the test set each time # Use NUM_TEST_EMBEDDINGS samples from the test set each time
# Get embeddings from the most recently saved model # Get embeddings from the most recently saved model
if(step % REPORT_METRICS_EVERY) == 0: if(step % REPORT_METRICS_EVERY) == 0:
report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings) report_cosine_sims(diffusion_prior, eval_loader, dp_condition_on_text_encodings, device=device)
### Evaluate model(validation run) ### ### Evaluate model(validation run) ###
eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation") eval_model(diffusion_prior, eval_loader, dp_condition_on_text_encodings, dp_loss_type, phase="Validation", device=device)
step += 1 step += 1
trainer.update() trainer.update()