Compare commits

..

17 Commits

Author SHA1 Message Date
Phil Wang
ce4b0107c1 0.6.13 2022-06-04 13:26:57 -07:00
zion
64c2f9c4eb implement ema warmup from @crowsonkb (#140) 2022-06-04 13:26:34 -07:00
Phil Wang
22cc613278 ema fix from @nousr 2022-06-03 19:44:36 -07:00
zion
83517849e5 ema module fixes (#139) 2022-06-03 19:43:51 -07:00
Phil Wang
708809ed6c lower beta2 for adam down to 0.99, based on https://openreview.net/forum?id=2LdBqxc1Yv 2022-06-03 10:26:28 -07:00
Phil Wang
9cc475f6e7 fix update_every within EMA 2022-06-03 10:21:05 -07:00
Phil Wang
ffd342e9d0 allow for an option to constrain the variance interpolation fraction coming out from the unet for learned variance, if it is turned on 2022-06-03 09:34:57 -07:00
Phil Wang
f8bfd3493a make destructuring datum length agnostic when validating in training decoder script, for @YUHANG-Ma 2022-06-02 13:54:57 -07:00
Phil Wang
9025345e29 take a stab at fixing generate_grid_samples when real images have a greater image size than generated 2022-06-02 11:33:15 -07:00
Phil Wang
8cc278447e just cast to right types for blur sigma and kernel size augs 2022-06-02 11:21:58 -07:00
Phil Wang
38cd62010c allow for random blur sigma and kernel size augmentations on low res conditioning (need to reread paper to see if the augmentation value needs to be fed into the unet for conditioning as well) 2022-06-02 11:11:25 -07:00
Ryan Russell
1cc288af39 Improve Readability (#133)
Signed-off-by: Ryan Russell <git@ryanrussell.org>
2022-06-01 13:28:02 -07:00
Phil Wang
a851168633 make youtokentome optional package, due to reported installation difficulties 2022-06-01 09:25:35 -07:00
Phil Wang
1ffeecd0ca lower default ema beta value 2022-05-31 11:55:21 -07:00
Phil Wang
3df899f7a4 patch 2022-05-31 09:03:43 -07:00
Aidan Dempster
09534119a1 Fixed non deterministic optimizer creation (#130) 2022-05-31 09:03:20 -07:00
Phil Wang
6f8b90d4d7 add packaging package 2022-05-30 11:45:00 -07:00
12 changed files with 110 additions and 41 deletions

View File

@@ -943,7 +943,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
# Create a dataloader directly. # Create a dataloader directly.
dataloader = create_image_embedding_dataloader( dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4, num_workers=4,
batch_size=32, batch_size=32,
@@ -1097,7 +1097,7 @@ This library would not have gotten to this working state without the help of
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] bring in skip-layer excitatons (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training - [ ] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training
- [ ] decoder needs one day worth of refactor for tech debt - [ ] decoder needs one day worth of refactor for tech debt
- [ ] allow for unet to be able to condition non-cross attention style as well - [ ] allow for unet to be able to condition non-cross attention style as well
- [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89 - [ ] read the paper, figure it out, and build it https://github.com/lucidrains/DALLE2-pytorch/issues/89

View File

@@ -83,7 +83,7 @@ Defines which evaluation metrics will be used to test the model.
Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked. Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.
| Option | Required | Default | Description | | Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- | | ------ | -------- | ------- | ----------- |
| `n_evalation_samples` | No | `1000` | The number of samples to generate to test the model. | | `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. | `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric.
| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric. | `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.
| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. | | `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |

View File

@@ -1,4 +1,5 @@
import math import math
import random
from tqdm import tqdm from tqdm import tqdm
from inspect import isfunction from inspect import isfunction
from functools import partial, wraps from functools import partial, wraps
@@ -1676,7 +1677,7 @@ class LowresConditioner(nn.Module):
def __init__( def __init__(
self, self,
downsample_first = True, downsample_first = True,
blur_sigma = 0.1, blur_sigma = (0.1, 0.2),
blur_kernel_size = 3, blur_kernel_size = 3,
): ):
super().__init__() super().__init__()
@@ -1700,6 +1701,18 @@ class LowresConditioner(nn.Module):
# when training, blur the low resolution conditional image # when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.blur_sigma) blur_sigma = default(blur_sigma, self.blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size) blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
# allow for drawing a random sigma between lo and hi float values
if isinstance(blur_sigma, tuple):
blur_sigma = tuple(map(float, blur_sigma))
blur_sigma = random.uniform(*blur_sigma)
# allow for drawing a random kernel size between lo and hi int values
if isinstance(blur_kernel_size, tuple):
blur_kernel_size = tuple(map(int, blur_kernel_size))
kernel_size_lo, kernel_size_hi = blur_kernel_size
blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2)) cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size) cond_fmap = resize_image_to(cond_fmap, target_image_size)
@@ -1725,13 +1738,14 @@ class Decoder(BaseGaussianDiffusion):
image_sizes = None, # for cascading ddpm, image size at each stage image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma blur_sigma = (0.1, 0.2), # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size blur_kernel_size = 3, # cascading ddpm - blur kernel size
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
clip_denoised = True, clip_denoised = True,
clip_x_start = True, clip_x_start = True,
clip_adapter_overrides = dict(), clip_adapter_overrides = dict(),
learned_variance = True, learned_variance = True,
learned_variance_constrain_frac = False,
vb_loss_weight = 0.001, vb_loss_weight = 0.001,
unconditional = False, unconditional = False,
auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
@@ -1792,6 +1806,7 @@ class Decoder(BaseGaussianDiffusion):
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False) learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
self.learned_variance = learned_variance self.learned_variance = learned_variance
self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
self.vb_loss_weight = vb_loss_weight self.vb_loss_weight = vb_loss_weight
# construct unets and vaes # construct unets and vaes
@@ -1932,6 +1947,9 @@ class Decoder(BaseGaussianDiffusion):
max_log = extract(torch.log(self.betas), t, x.shape) max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized) var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
if self.learned_variance_constrain_frac:
var_interp_frac = var_interp_frac.sigmoid()
posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
posterior_variance = posterior_log_variance.exp() posterior_variance = posterior_log_variance.exp()

View File

@@ -15,7 +15,7 @@ from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embed
# Create a dataloader directly. # Create a dataloader directly.
dataloader = create_image_embedding_dataloader( dataloader = create_image_embedding_dataloader(
tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses braket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
num_workers=4, num_workers=4,
batch_size=32, batch_size=32,

View File

@@ -1,15 +1,17 @@
from torch.optim import AdamW, Adam from torch.optim import AdamW, Adam
def separate_weight_decayable_params(params): def separate_weight_decayable_params(params):
no_wd_params = set([param for param in params if param.ndim < 2]) wd_params, no_wd_params = [], []
wd_params = set(params) - no_wd_params for param in params:
param_list = no_wd_params if param.ndim < 2 else wd_params
param_list.append(param)
return wd_params, no_wd_params return wd_params, no_wd_params
def get_optimizer( def get_optimizer(
params, params,
lr = 1e-4, lr = 1e-4,
wd = 1e-2, wd = 1e-2,
betas = (0.9, 0.999), betas = (0.9, 0.99),
eps = 1e-8, eps = 1e-8,
filter_by_requires_grad = False, filter_by_requires_grad = False,
group_wd_params = True, group_wd_params = True,
@@ -25,8 +27,8 @@ def get_optimizer(
wd_params, no_wd_params = separate_weight_decayable_params(params) wd_params, no_wd_params = separate_weight_decayable_params(params)
params = [ params = [
{'params': list(wd_params)}, {'params': wd_params},
{'params': list(no_wd_params), 'weight_decay': 0}, {'params': no_wd_params, 'weight_decay': 0},
] ]
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -2,7 +2,6 @@
# to give users a quick easy start to training DALL-E without doing BPE # to give users a quick easy start to training DALL-E without doing BPE
import torch import torch
import youtokentome as yttm
import html import html
import os import os
@@ -11,6 +10,8 @@ import regex as re
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from dalle2_pytorch.utils import import_or_print_error
# OpenAI simple tokenizer # OpenAI simple tokenizer
@lru_cache() @lru_cache()
@@ -156,7 +157,9 @@ class YttmTokenizer:
bpe_path = Path(bpe_path) bpe_path = Path(bpe_path)
assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist' assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
tokenizer = yttm.BPE(model = str(bpe_path)) self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
tokenizer = self.yttm.BPE(model = str(bpe_path))
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.vocab_size = tokenizer.vocab_size() self.vocab_size = tokenizer.vocab_size()
@@ -167,7 +170,7 @@ class YttmTokenizer:
return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0})) return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
def encode(self, texts): def encode(self, texts):
encoded = self.tokenizer.encode(texts, output_type = yttm.OutputType.ID) encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
return list(map(torch.tensor, encoded)) return list(map(torch.tensor, encoded))
def tokenize(self, texts, context_length = 256, truncate_text = False): def tokenize(self, texts, context_length = 256, truncate_text = False):

View File

@@ -6,6 +6,8 @@ from itertools import zip_longest
import torch import torch
from torch import nn from torch import nn
from dalle2_pytorch.utils import import_or_print_error
# constants # constants
DEFAULT_DATA_PATH = './.tracker-data' DEFAULT_DATA_PATH = './.tracker-data'
@@ -15,14 +17,6 @@ DEFAULT_DATA_PATH = './.tracker-data'
def exists(val): def exists(val):
return val is not None return val is not None
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()
# load state dict functions # load state dict functions
def load_wandb_state_dict(run_path, file_path, **kwargs): def load_wandb_state_dict(run_path, file_path, **kwargs):

View File

@@ -175,12 +175,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
# exponential moving average wrapper # exponential moving average wrapper
class EMA(nn.Module): class EMA(nn.Module):
"""
Implements exponential moving average shadowing for your model.
Utilizes an inverse decay schedule to manage longer term training runs.
By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
def __init__( def __init__(
self, self,
model, model,
beta = 0.9999, beta = 0.9999,
update_after_step = 1000, update_after_step = 10000,
update_every = 10, update_every = 10,
inv_gamma = 1.0,
power = 2/3,
min_value = 0.0,
): ):
super().__init__() super().__init__()
self.beta = beta self.beta = beta
@@ -188,7 +210,11 @@ class EMA(nn.Module):
self.ema_model = copy.deepcopy(model) self.ema_model = copy.deepcopy(model)
self.update_every = update_every self.update_every = update_every
self.update_after_step = update_after_step // update_every # only start EMA after this step number, starting at 0 self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.register_buffer('initted', torch.Tensor([False])) self.register_buffer('initted', torch.Tensor([False]))
self.register_buffer('step', torch.tensor([0])) self.register_buffer('step', torch.tensor([0]))
@@ -198,37 +224,44 @@ class EMA(nn.Module):
self.ema_model.to(device) self.ema_model.to(device)
def copy_params_from_model_to_ema(self): def copy_params_from_model_to_ema(self):
self.ema_model.state_dict(self.online_model.state_dict()) for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
ma_param.data.copy_(current_param.data)
def get_current_decay(self):
epoch = max(0, self.step.item() - self.update_after_step - 1)
value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
return 0. if epoch < 0 else min(self.beta, max(self.min_value, value))
def update(self): def update(self):
step = self.step.item()
self.step += 1 self.step += 1
if (self.step % self.update_every) != 0: if (step % self.update_every) != 0:
return return
if self.step <= self.update_after_step: if step <= self.update_after_step:
self.copy_params_from_model_to_ema() self.copy_params_from_model_to_ema()
return return
if not self.initted: if not self.initted.item():
self.copy_params_from_model_to_ema() self.copy_params_from_model_to_ema()
self.initted.data.copy_(torch.Tensor([True])) self.initted.data.copy_(torch.Tensor([True]))
self.update_moving_average(self.ema_model, self.online_model) self.update_moving_average(self.ema_model, self.online_model)
@torch.no_grad()
def update_moving_average(self, ma_model, current_model): def update_moving_average(self, ma_model, current_model):
def calculate_ema(beta, old, new): current_decay = self.get_current_decay()
if not exists(old):
return new
return old * beta + (1 - beta) * new
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
old_weight, up_weight = ma_params.data, current_params.data difference = ma_params.data - current_params.data
ma_params.data = calculate_ema(self.beta, old_weight, up_weight) difference.mul_(1.0 - current_decay)
ma_params.sub_(difference)
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()): for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer) difference = ma_buffer - current_buffer
ma_buffer.copy_(new_buffer_value) difference.mul_(1.0 - current_decay)
ma_buffer.sub_(difference)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs) return self.ema_model(*args, **kwargs)

View File

@@ -17,3 +17,13 @@ class Timer:
def print_ribbon(s, symbol = '=', repeat = 40): def print_ribbon(s, symbol = '=', repeat = 40):
flank = symbol * repeat flank = symbol * repeat
return f'{flank} {s} {flank}' return f'{flank} {s} {flank}'
# import helpers
def import_or_print_error(pkg_name, err_str = None):
try:
return importlib.import_module(pkg_name)
except ModuleNotFoundError as e:
if exists(err_str):
print(err_str)
exit()

View File

@@ -1 +1 @@
__version__ = '0.6.2' __version__ = '0.6.13'

View File

@@ -32,6 +32,7 @@ setup(
'embedding-reader', 'embedding-reader',
'kornia>=0.5.4', 'kornia>=0.5.4',
'numpy', 'numpy',
'packaging',
'pillow', 'pillow',
'pydantic', 'pydantic',
'resize-right>=0.0.2', 'resize-right>=0.0.2',
@@ -41,7 +42,6 @@ setup(
'tqdm', 'tqdm',
'vector-quantize-pytorch', 'vector-quantize-pytorch',
'x-clip>=0.4.4', 'x-clip>=0.4.4',
'youtokentome',
'webdataset>=0.2.5', 'webdataset>=0.2.5',
'fsspec>=2022.1.0', 'fsspec>=2022.1.0',
'torchmetrics[image]>=0.8.0' 'torchmetrics[image]>=0.8.0'

View File

@@ -4,6 +4,7 @@ from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer, print_ribbon from dalle2_pytorch.utils import Timer, print_ribbon
from dalle2_pytorch.dalle2_pytorch import resize_image_to
import torchvision import torchvision
import torch import torch
@@ -136,6 +137,14 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
Generates samples and uses torchvision to put them in a side by side grid for easy viewing Generates samples and uses torchvision to put them in a side by side grid for easy viewing
""" """
real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend) real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)] grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions return grid_images, captions
@@ -322,7 +331,7 @@ def train(
sample = 0 sample = 0
average_loss = 0 average_loss = 0
timer = Timer() timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]): for i, (img, emb, *_) in enumerate(dataloaders["val"]):
sample += img.shape[0] sample += img.shape[0]
img, emb = send_to_device((img, emb)) img, emb = send_to_device((img, emb))