Compare commits

...

20 Commits

Author SHA1 Message Date
Phil Wang
7ea314e2f0 allow for final l2norm clamping of the sampled image embed 2022-07-10 09:44:38 -07:00
Phil Wang
4173e88121 more accurate readme 2022-07-09 20:57:26 -07:00
Phil Wang
3dae43fa0e fix misnamed variable, thanks to @nousr 2022-07-09 19:01:37 -07:00
Phil Wang
a598820012 do not noise for the last step in ddim 2022-07-09 18:38:40 -07:00
Phil Wang
4878762627 fix for small validation bug for sampling steps 2022-07-09 17:31:54 -07:00
Phil Wang
47ae17b36e more informative error for something that tripped me up 2022-07-09 17:28:14 -07:00
Phil Wang
b7e22f7da0 complete ddim integration of diffusion prior as well as decoder for each unet, feature complete for https://github.com/lucidrains/DALLE2-pytorch/issues/157 2022-07-09 17:25:34 -07:00
Romain Beaumont
68de937aac Fix decoder test by fixing the resizing output size (#197) 2022-07-09 07:48:07 -07:00
Phil Wang
097afda606 0.18.0 2022-07-08 18:18:38 -07:00
Aidan Dempster
5c520db825 Added deepspeed support (#195) 2022-07-08 18:18:08 -07:00
Phil Wang
3070610231 just force it so researcher can never pass in an image that is less than the size that is required for CLIP or CoCa 2022-07-08 18:17:29 -07:00
Aidan Dempster
870aeeca62 Fixed issue where evaluation would error when large image was loaded (#194) 2022-07-08 17:11:34 -07:00
Romain Beaumont
f28dc6dc01 setup simple ci (#193) 2022-07-08 16:51:56 -07:00
Phil Wang
081d8d3484 0.17.0 2022-07-08 13:36:26 -07:00
Aidan Dempster
a71f693a26 Add the ability to auto restart the last run when started after a crash (#191)
* Added autoresume after crash functionality to the trackers

* Updated documentation

* Clarified what goes in the autorestart object

* Fixed style issues

Unraveled conditional block

Chnaged to using helper function to get step count
2022-07-08 13:35:40 -07:00
Phil Wang
d7bc5fbedd expose num_steps_taken helper method on trainer to retrieve number of training steps of each unet 2022-07-08 13:00:56 -07:00
Phil Wang
8c823affff allow for control over use of nearest interp method of downsampling low res conditioning, in addition to being able to turn it off 2022-07-08 11:44:43 -07:00
Phil Wang
ec7cab01d9 extra insurance that diffusion prior is on the correct device, when using trainer with accelerator or device was given 2022-07-07 10:08:33 -07:00
Phil Wang
46be8c32d3 fix a potential issue in the low resolution conditioner, when downsampling and then upsampling using resize right, thanks to @marunine 2022-07-07 09:41:49 -07:00
Phil Wang
900f086a6d fix condition_on_text_encodings in dalle2 orchestrator class, fix readme 2022-07-07 07:43:41 -07:00
24 changed files with 577 additions and 88 deletions

33
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Continuous integration
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install
run: |
python3 -m venv .env
source .env/bin/activate
make install
- name: Tests
run: |
source .env/bin/activate
make test

2
.gitignore vendored
View File

@@ -136,3 +136,5 @@ dmypy.json
# Pyre type checker
.pyre/
.tracker_data
*.pth

6
Makefile Normal file
View File

@@ -0,0 +1,6 @@
install:
pip install -U pip
pip install -e .
test:
CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json

View File

@@ -44,6 +44,7 @@ This library would not have gotten to this working state without the help of
- <a href="https://github.com/krish240574">Kumar</a> for working on the initial diffusion training script
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
- <a href="https://github.com/marunine">Marunine</a> for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
- <a href="https://stability.ai/">Stability AI</a> for the generous sponsorship
- <a href="https://huggingface.co">🤗 Huggingface</a> and in particular <a href="https://github.com/sgugger">Sylvain</a> for the <a href="https://github.com/huggingface/accelerate">Accelerate</a> library
@@ -354,7 +355,8 @@ prior_network = DiffusionPriorNetwork(
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
timesteps = 1000,
sample_timesteps = 64,
cond_drop_prob = 0.2
).cuda()
@@ -581,7 +583,9 @@ unet1 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
dim_mults=(1, 2, 4, 8),
text_embed_dim = 512,
cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
).cuda()
unet2 = Unet(
@@ -596,14 +600,14 @@ decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
timesteps = 1000,
sample_timesteps = (250, 27),
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
text_cond_drop_prob = 0.5
).cuda()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps

View File

@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
Any parameter from the `Decoder` constructor can also be given here.
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -106,6 +108,13 @@ Tracking is split up into three sections:
**Logging:**
All loggers have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | The type of logger class to use. |
| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using maually input parameters. |
| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
If using `console` there is no further configuration than setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
@@ -119,10 +128,15 @@ If using `wandb`
| `wandb_project` | Yes | N/A | The wandb project save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
| `wandb_resume` | No | `False` | Whether to resume an old run. |
**Loading:**
All loaders have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | The type of loader class to use. |
| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |

View File

@@ -20,7 +20,7 @@
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"embeddings_url": "s3://bucket/embeddings/path/",
"img_embeddings_url": "s3://bucket/img_embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,

View File

@@ -0,0 +1,102 @@
{
"decoder": {
"unets": [
{
"dim": 16,
"image_embed_dim": 768,
"cond_dim": 16,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 16,
"attn_heads": 4,
"self_attn": [false, true, true, true]
}
],
"clip": {
"make": "openai",
"model": "ViT-L/14"
},
"timesteps": 10,
"image_sizes": [64],
"channels": 3,
"loss_type": "l2",
"beta_schedule": ["cosine"],
"learned_variance": true
},
"data": {
"webdataset_base_url": "test_data/{}.tar",
"num_workers": 4,
"batch_size": 4,
"start_shard": 0,
"end_shard": 9,
"shard_width": 1,
"index_width": 1,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": false,
"resample_train": true,
"preprocessing": {
"RandomResizedCrop": {
"size": [224, 224],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 1,
"lr": 1e-16,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100,
"n_sample_images": 1,
"device": "cpu",
"epoch_samples": 50,
"validation_samples": 5,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
"n_evaluation_samples": 2,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 2
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"overwrite_data_path": true,
"log": {
"log_type": "console"
},
"load": {
"load_from": null
},
"save": [{
"save_to": "local"
}]
}
}

View File

@@ -125,14 +125,28 @@ def log(t, eps = 1e-12):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(image, target_image_size):
def resize_image_to(
image,
target_image_size,
clamp_range = None,
nearest = False,
**kwargs
):
orig_image_size = image.shape[-1]
if orig_image_size == target_image_size:
return image
scale_factors = target_image_size / orig_image_size
return resize(image, scale_factors = scale_factors)
if not nearest:
scale_factors = target_image_size / orig_image_size
out = resize(image, scale_factors = scale_factors, **kwargs)
else:
out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)
if exists(clamp_range):
out = out.clamp(*clamp_range)
return out
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
@@ -155,6 +169,11 @@ class BaseClipAdapter(nn.Module):
self.clip = clip
self.overrides = kwargs
def validate_and_resize_image(self, image):
image_size = image.shape[-1]
assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
return resize_image_to(image, self.image_size)
@property
def dim_latent(self):
raise NotImplementedError
@@ -205,7 +224,7 @@ class XClipAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -240,7 +259,7 @@ class CoCaAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
image_embed, image_encodings = self.clip.embed_image(image)
return EmbeddedImage(image_embed, image_encodings)
@@ -301,7 +320,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -486,6 +505,12 @@ class NoiseScheduler(nn.Module):
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
return (
(x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def p2_reweigh_loss(self, loss, times):
if not self.has_p2_loss_reweighting:
return loss
@@ -892,19 +917,23 @@ class DiffusionPrior(nn.Module):
image_size = None,
image_channels = 3,
timesteps = 1000,
sample_timesteps = None,
cond_drop_prob = 0.,
loss_type = "l2",
predict_x_start = True,
beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False,
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
training_clamp_l2norm = False,
init_image_embed_l2norm = False,
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
clip_adapter_overrides = dict()
):
super().__init__()
self.sample_timesteps = sample_timesteps
self.noise_scheduler = NoiseScheduler(
beta_schedule = beta_schedule,
timesteps = timesteps,
@@ -935,23 +964,32 @@ class DiffusionPrior(nn.Module):
self.condition_on_text_encodings = condition_on_text_encodings
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
self.predict_x_start = predict_x_start
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
# whether to force an l2norm, similar to clipping denoised, when sampling
self.sampling_clamp_l2norm = sampling_clamp_l2norm
self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
self.training_clamp_l2norm = training_clamp_l2norm
self.init_image_embed_l2norm = init_image_embed_l2norm
# device tracker
self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
@property
def device(self):
return self._dummy.device
def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -959,8 +997,6 @@ class DiffusionPrior(nn.Module):
if self.predict_x_start:
x_recon = pred
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -983,21 +1019,81 @@ class DiffusionPrior(nn.Module):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
device = self.device
b = shape[0]
image_embed = torch.randn(shape, device=device)
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
batch, device = shape[0], self.device
image_embed = torch.randn(shape, device = device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((b,), i, device = device, dtype = torch.long)
times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
@torch.no_grad()
def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
image_embed = torch.randn(shape, device = device)
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
x_start = pred
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
else:
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
pred_noise = pred
if not self.predict_x_start:
x_start.clamp_(-1., 1.)
if self.predict_x_start and self.sampling_clamp_l2norm:
x_start = self.l2norm_clamp_embed(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(image_embed) if time_next > 0 else 0.
image_embed = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
if self.predict_x_start and self.sampling_final_clamp_l2norm:
image_embed = self.l2norm_clamp_embed(image_embed)
return image_embed
@torch.no_grad()
def p_sample_loop(self, *args, timesteps = None, **kwargs):
timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
assert timesteps <= self.noise_scheduler.num_timesteps
is_ddim = timesteps < self.noise_scheduler.num_timesteps
if not is_ddim:
return self.p_sample_loop_ddpm(*args, **kwargs)
return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
@@ -1011,7 +1107,7 @@ class DiffusionPrior(nn.Module):
)
if self.predict_x_start and self.training_clamp_l2norm:
pred = l2norm(pred) * self.image_embed_scale
pred = self.l2norm_clamp_embed(pred)
target = noise if not self.predict_x_start else image_embed
@@ -1032,7 +1128,15 @@ class DiffusionPrior(nn.Module):
@torch.no_grad()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
def sample(
self,
text,
num_samples_per_batch = 2,
cond_scale = 1.,
timesteps = None
):
timesteps = default(timesteps, self.sample_timesteps)
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1047,7 +1151,7 @@ class DiffusionPrior(nn.Module):
if self.condition_on_text_encodings:
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
# retrieve original unscaled image embed
@@ -1449,10 +1553,12 @@ class Unet(nn.Module):
# text encoding conditioning (optional)
self.text_to_cond = None
self.text_embed_dim = None
if cond_on_text_encodings:
assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
self.text_embed_dim = text_embed_dim
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
@@ -1681,6 +1787,8 @@ class Unet(nn.Module):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
text_tokens = self.text_to_cond(text_encodings)
text_tokens = text_tokens[:, :self.max_text_len]
@@ -1776,11 +1884,17 @@ class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
downsample_mode_nearest = False,
blur_sigma = 0.6,
blur_kernel_size = 3,
input_image_range = None
):
super().__init__()
self.downsample_first = downsample_first
self.downsample_mode_nearest = downsample_mode_nearest
self.input_image_range = input_image_range
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
@@ -1794,7 +1908,7 @@ class LowresConditioner(nn.Module):
blur_kernel_size = None
):
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
if self.training:
# when training, blur the low resolution conditional image
@@ -1814,7 +1928,7 @@ class LowresConditioner(nn.Module):
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size)
cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range)
return cond_fmap
@@ -1828,6 +1942,7 @@ class Decoder(nn.Module):
channels = 3,
vae = tuple(),
timesteps = 1000,
sample_timesteps = None,
image_cond_drop_prob = 0.1,
text_cond_drop_prob = 0.5,
loss_type = 'l2',
@@ -1837,6 +1952,7 @@ class Decoder(nn.Module):
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
blur_sigma = 0.6, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
clip_denoised = True,
@@ -1850,7 +1966,8 @@ class Decoder(nn.Module):
use_dynamic_thres = False, # from the Imagen paper
dynamic_thres_percentile = 0.9,
p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
p2_loss_weight_k = 1
p2_loss_weight_k = 1,
ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
):
super().__init__()
@@ -1930,9 +2047,10 @@ class Decoder(nn.Module):
self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# determine from unets whether conditioning on text encoding is needed
# sampling timesteps, defaults to non-ddim with full timesteps sampling
self.condition_on_text_encodings = any([unet.cond_on_text_encodings for unet in self.unets])
self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
self.ddim_sampling_eta = ddim_sampling_eta
# create noise schedulers per unet
@@ -1944,7 +2062,9 @@ class Decoder(nn.Module):
self.noise_schedulers = nn.ModuleList([])
for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
noise_scheduler = NoiseScheduler(
beta_schedule = unet_beta_schedule,
timesteps = timesteps,
@@ -1972,6 +2092,10 @@ class Decoder(nn.Module):
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
# input image range
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
@@ -1979,8 +2103,10 @@ class Decoder(nn.Module):
self.to_lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
downsample_mode_nearest = lowres_downsample_mode_nearest,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
input_image_range = self.input_image_range
)
# classifier free guidance
@@ -2012,6 +2138,10 @@ class Decoder(nn.Module):
def device(self):
return self._dummy.device
@property
def condition_on_text_encodings(self):
return any([unet.cond_on_text_encodings for unet in self.unets])
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
@@ -2035,6 +2165,26 @@ class Decoder(nn.Module):
for unet, device in zip(self.unets, devices):
unet.to(device)
def dynamic_threshold(self, x):
""" proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x = x.clamp(-s, s) / s
return x
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
@@ -2049,21 +2199,7 @@ class Decoder(nn.Module):
x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised:
# s is the threshold amount
# static thresholding would just be s = 1
s = 1.
if self.use_dynamic_thres:
s = torch.quantile(
rearrange(x_recon, 'b ... -> b (...)').abs(),
self.dynamic_thres_percentile,
dim = -1
)
s.clamp_(min = 1.)
s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
# clip by threshold, depending on whether static or dynamic
x_recon = x_recon.clamp(-s, s) / s
x_recon = self.dynamic_threshold(x_recon)
model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -2093,7 +2229,7 @@ class Decoder(nn.Module):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
device = self.device
b = shape[0]
@@ -2121,6 +2257,62 @@ class Decoder(nn.Module):
unnormalize_img = self.unnormalize_img(img)
return unnormalize_img
@torch.no_grad()
def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
img = torch.randn(shape, device = device)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
if predict_x_start:
x_start = pred
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
if clip_denoised:
x_start = self.dynamic_threshold(x_start)
c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0.
img = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
img = self.unnormalize_img(img)
return img
@torch.no_grad()
def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
num_timesteps = noise_scheduler.num_timesteps
timesteps = default(timesteps, num_timesteps)
assert timesteps <= num_timesteps
is_ddim = timesteps < num_timesteps
if not is_ddim:
return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -2221,7 +2413,7 @@ class Decoder(nn.Module):
img = None
is_cuda = next(self.parameters()).is_cuda
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):
context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
@@ -2250,7 +2442,8 @@ class Decoder(nn.Module):
clip_denoised = not is_latent_diffusion,
lowres_cond_img = lowres_cond_img,
is_latent_diffusion = is_latent_diffusion,
noise_scheduler = noise_scheduler
noise_scheduler = noise_scheduler,
timesteps = sample_timesteps
)
img = vae.decode(img)

View File

@@ -1,6 +1,7 @@
import os
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import numpy as np
import fsspec
import shutil
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
)
if shuffle_num is not None and shuffle_num > 0:
ds.shuffle(1000)
return wds.WebLoader(
return DataLoader(
ds,
num_workers=num_workers,
batch_size=batch_size,

View File

@@ -1,5 +1,6 @@
import urllib.request
import os
import json
from pathlib import Path
import shutil
from itertools import zip_longest
@@ -37,14 +38,17 @@ class BaseLogger:
data_path (str): A file path for storing temporary data.
verbose (bool): Whether of not to always print logs to the console.
"""
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
self.data_path = Path(data_path)
self.resume = resume
self.auto_resume = auto_resume
self.verbose = verbose
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
"""
Initializes the logger.
Errors if the logger is invalid.
full_config is the config file dict while extra_config is anything else from the script that is not defined the config file.
"""
raise NotImplementedError
@@ -60,6 +64,14 @@ class BaseLogger:
def log_error(self, error_string, **kwargs) -> None:
raise NotImplementedError
def get_resume_data(self, **kwargs) -> dict:
"""
Sets tracker attributes that along with { "resume": True } will be used to resume training.
It is assumed that after init is called this data will be complete.
If the logger does not have any resume functionality, it should return an empty dict.
"""
raise NotImplementedError
class ConsoleLogger(BaseLogger):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
print("Logging to console")
@@ -76,6 +88,9 @@ class ConsoleLogger(BaseLogger):
def log_error(self, error_string, **kwargs) -> None:
print(error_string)
def get_resume_data(self, **kwargs) -> dict:
return {}
class WandbLogger(BaseLogger):
"""
Logs to a wandb run.
@@ -85,7 +100,6 @@ class WandbLogger(BaseLogger):
wandb_project (str): The wandb project to log to.
wandb_run_id (str): The wandb run id to resume.
wandb_run_name (str): The wandb run name to use.
wandb_resume (bool): Whether to resume a wandb run.
"""
def __init__(self,
data_path: str,
@@ -93,7 +107,6 @@ class WandbLogger(BaseLogger):
wandb_project: str,
wandb_run_id: Optional[str] = None,
wandb_run_name: Optional[str] = None,
wandb_resume: bool = False,
**kwargs
):
super().__init__(data_path, **kwargs)
@@ -101,7 +114,6 @@ class WandbLogger(BaseLogger):
self.project = wandb_project
self.run_id = wandb_run_id
self.run_name = wandb_run_name
self.resume = wandb_resume
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
@@ -149,6 +161,14 @@ class WandbLogger(BaseLogger):
print(error_string)
self.wandb.log({"error": error_string, **kwargs}, step=step)
def get_resume_data(self, **kwargs) -> dict:
# In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
return {
"entity": self.entity,
"project": self.project,
"run_id": self.wandb.run.id
}
logger_type_map = {
'console': ConsoleLogger,
'wandb': WandbLogger,
@@ -168,8 +188,9 @@ class BaseLoader:
Parameters:
data_path (str): A file path for storing temporary data.
"""
def __init__(self, data_path: str, **kwargs):
def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
self.data_path = Path(data_path)
self.only_auto_resume = only_auto_resume
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
@@ -304,6 +325,10 @@ class LocalSaver(BaseSaver):
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
# Copy the file to save_path
save_path_file_name = Path(save_path).name
# Make sure parent directory exists
save_path_parent = Path(save_path).parent
if not save_path_parent.exists():
save_path_parent.mkdir(parents=True)
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
shutil.copy(local_path, save_path)
@@ -385,11 +410,7 @@ class Tracker:
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
self.data_path = Path(data_path)
if not dummy_mode:
if overwrite_data_path:
if self.data_path.exists():
shutil.rmtree(self.data_path)
self.data_path.mkdir(parents=True)
else:
if not overwrite_data_path:
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
@@ -398,7 +419,46 @@ class Tracker:
self.savers: List[BaseSaver]= []
self.dummy_mode = dummy_mode
def _load_auto_resume(self) -> bool:
# If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
if not self.auto_resume_path.exists():
if self.logger.auto_resume:
print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
return False
# Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
if not self.logger.auto_resume:
print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
self.auto_resume_path.unlink()
return False
# Otherwise we read the json into a dictionary will will override parts of logger.__dict__
with open(self.auto_resume_path, 'r') as f:
auto_resume_dict = json.load(f)
# Check if the logger is of the same type as the autoresume save
if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
# Then we are ready to override the logger with the autoresume save
self.logger.__dict__["resume"] = True
print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
self.logger.__dict__.update(auto_resume_dict)
return True
def _save_auto_resume(self):
# Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
auto_resume_dict = self.logger.get_resume_data()
auto_resume_dict['logger_type'] = self.logger.__class__.__name__
with open(self.auto_resume_path, 'w') as f:
json.dump(auto_resume_dict, f)
def init(self, full_config: BaseModel, extra_config: dict):
self.auto_resume_path = self.data_path / 'auto_resume.json'
# Check for resuming the run
self.did_auto_resume = self._load_auto_resume()
if self.did_auto_resume:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
@@ -406,12 +466,17 @@ class Tracker:
self.loader.init(self.logger)
return
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
self.logger.init(full_config, extra_config)
if self.loader is not None:
self.loader.init(self.logger)
for saver in self.savers:
saver.init(self.logger)
if self.logger.auto_resume:
# Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
self._save_auto_resume()
def add_logger(self, logger: BaseLogger):
self.logger = logger
@@ -503,11 +568,16 @@ class Tracker:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
@property
def can_recall(self):
# Defines whether a recall can be performed.
return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
def recall(self):
if self.loader is not None:
if self.can_recall:
return self.loader.recall()
else:
raise ValueError('No loader specified')
raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')

View File

@@ -47,6 +47,8 @@ class TrainSplitConfig(BaseModel):
class TrackerLogConfig(BaseModel):
log_type: str = 'console'
resume: bool = False # For logs that are saved to unique locations, resume a previous run
auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
verbose: bool = False
class Config:
@@ -59,6 +61,7 @@ class TrackerLogConfig(BaseModel):
class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
class Config:
extra = "allow"
@@ -151,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
image_size: int
image_channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[int] = None
cond_drop_prob: float = 0.
loss_type: str = 'l2'
predict_x_start: bool = True
@@ -230,6 +234,7 @@ class DecoderConfig(BaseModel):
clip: Optional[AdapterConfig] # The clip model to use if embeddings are not provided
channels: int = 3
timesteps: int = 1000
sample_timesteps: Optional[SingularOrIterable(int)] = None
loss_type: str = 'l2'
beta_schedule: ListOrTuple(str) = 'cosine'
learned_variance: bool = True

View File

@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
import numpy as np
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
if cast_deepspeed_precision:
try:
accelerator = model.accelerator
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
except AttributeError:
# Then this model doesn't have an accelerator
pass
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
@@ -192,6 +208,7 @@ class DiffusionPriorTrainer(nn.Module):
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
diffusion_prior.to(self.device)
# save model
@@ -445,6 +462,7 @@ class DecoderTrainer(nn.Module):
self,
decoder,
accelerator = None,
dataloaders = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
@@ -507,11 +525,31 @@ class DecoderTrainer(nn.Module):
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip
clip.to(precision_type)
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
schedulers = list(self.accelerator.prepare(*schedulers))
self.decoder = decoder
# prepare dataloaders
train_loader = val_loader = None
if exists(dataloaders):
train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
self.train_loader = train_loader
self.val_loader = val_loader
# store optimizers
for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
@@ -526,6 +564,17 @@ class DecoderTrainer(nn.Module):
self.warmup_schedulers = warmup_schedulers
def validate_and_return_unet_number(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
return unet_number
def num_steps_taken(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
return self.steps[unet_number - 1].item()
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
@@ -594,10 +643,7 @@ class DecoderTrainer(nn.Module):
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
unet_number = self.validate_and_return_unet_number(unet_number)
index = unet_number - 1
optimizer = getattr(self, f'optim{index}')
@@ -663,11 +709,13 @@ class DecoderTrainer(nn.Module):
max_batch_size = None,
**kwargs
):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
unet_number = self.validate_and_return_unet_number(unet_number)
total_loss = 0.
using_amp = self.accelerator.mixed_precision != 'no'
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)

View File

@@ -1 +1 @@
__version__ = '0.16.14'
__version__ = '0.19.6'

BIN
test_data/0.tar Normal file

Binary file not shown.

BIN
test_data/1.tar Normal file

Binary file not shown.

BIN
test_data/2.tar Normal file

Binary file not shown.

BIN
test_data/3.tar Normal file

Binary file not shown.

BIN
test_data/4.tar Normal file

Binary file not shown.

BIN
test_data/5.tar Normal file

Binary file not shown.

BIN
test_data/6.tar Normal file

Binary file not shown.

BIN
test_data/7.tar Normal file

Binary file not shown.

BIN
test_data/8.tar Normal file

Binary file not shown.

BIN
test_data/9.tar Normal file

Binary file not shown.

View File

@@ -132,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
@@ -160,6 +160,9 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
if match_image_size:
generated_image_size = generated_images[0].shape[-1]
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
@@ -167,14 +170,6 @@ def generate_grid_samples(trainer, examples, condition_on_text_encodings=False,
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
@@ -279,6 +274,7 @@ def train(
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
dataloaders=dataloaders,
**kwargs
)
@@ -289,9 +285,8 @@ def train(
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.step.item())
if tracker.loader is not None:
if tracker.can_recall:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
@@ -304,6 +299,8 @@ def train(
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
@@ -326,7 +323,7 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(trainer.train_loader):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -419,7 +416,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
@@ -524,6 +521,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
if accelerator.num_processes > 1:
# We are using distributed training and want to immediately ensure all can connect
accelerator.print("Waiting for all processes to connect...")
accelerator.wait_for_everyone()
accelerator.print("All processes online and connected")
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
# This is an invalid configuration until we figure out how to handle this
raise ValueError("DeepSpeed does not support multi-node distributed training")
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))