Compare commits

...

11 Commits

Author SHA1 Message Date
Romain Beaumont
3a1dea7d97 Fix decoder test by fixing the resizing output size 2022-07-09 11:36:22 +02:00
Phil Wang
097afda606 0.18.0 2022-07-08 18:18:38 -07:00
Aidan Dempster
5c520db825 Added deepspeed support (#195) 2022-07-08 18:18:08 -07:00
Phil Wang
3070610231 just force it so researcher can never pass in an image that is less than the size that is required for CLIP or CoCa 2022-07-08 18:17:29 -07:00
Aidan Dempster
870aeeca62 Fixed issue where evaluation would error when large image was loaded (#194) 2022-07-08 17:11:34 -07:00
Romain Beaumont
f28dc6dc01 setup simple ci (#193) 2022-07-08 16:51:56 -07:00
Phil Wang
081d8d3484 0.17.0 2022-07-08 13:36:26 -07:00
Aidan Dempster
a71f693a26 Add the ability to auto restart the last run when started after a crash (#191)
* Added autoresume after crash functionality to the trackers

* Updated documentation

* Clarified what goes in the autorestart object

* Fixed style issues

Unraveled conditional block

Chnaged to using helper function to get step count
2022-07-08 13:35:40 -07:00
Phil Wang
d7bc5fbedd expose num_steps_taken helper method on trainer to retrieve number of training steps of each unet 2022-07-08 13:00:56 -07:00
Phil Wang
8c823affff allow for control over use of nearest interp method of downsampling low res conditioning, in addition to being able to turn it off 2022-07-08 11:44:43 -07:00
Phil Wang
ec7cab01d9 extra insurance that diffusion prior is on the correct device, when using trainer with accelerator or device was given 2022-07-07 10:08:33 -07:00
23 changed files with 347 additions and 46 deletions

33
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,33 @@
name: Continuous integration
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install
run: |
python3 -m venv .env
source .env/bin/activate
make install
- name: Tests
run: |
source .env/bin/activate
make test

2
.gitignore vendored
View File

@@ -136,3 +136,5 @@ dmypy.json
# Pyre type checker
.pyre/
.tracker_data
*.pth

6
Makefile Normal file
View File

@@ -0,0 +1,6 @@
install:
pip install -U pip
pip install -e .
test:
CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json

View File

@@ -30,6 +30,7 @@ Defines the configuration options for the decoder model. The unets defined above
| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
| `learned_variance` | No | `True` | Whether to learn the variance. |
| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
Any parameter from the `Decoder` constructor can also be given here.
@@ -39,7 +40,8 @@ Settings for creation of the dataloaders.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |
| `embeddings_url` | No | N/A | The url of the folder containing embeddings shards. Not required if embeddings are in webdataset. |
| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
| `num_workers` | No | `4` | The number of workers used in the dataloader. |
| `batch_size` | No | `64` | The batch size. |
| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |
@@ -106,6 +108,13 @@ Tracking is split up into three sections:
**Logging:**
All loggers have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `log_type` | Yes | N/A | The type of logger class to use. |
| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using maually input parameters. |
| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
If using `console` there is no further configuration than setting `log_type` to `console`.
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
@@ -119,10 +128,15 @@ If using `wandb`
| `wandb_project` | Yes | N/A | The wandb project save the run to. |
| `wandb_run_name` | No | `None` | The wandb run name. |
| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
| `wandb_resume` | No | `False` | Whether to resume an old run. |
**Loading:**
All loaders have the following keys:
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |
| `load_from` | Yes | N/A | The type of loader class to use. |
| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
If using `local`
| Option | Required | Default | Description |
| ------ | -------- | ------- | ----------- |

View File

@@ -20,7 +20,7 @@
},
"data": {
"webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
"embeddings_url": "s3://bucket/embeddings/path/",
"img_embeddings_url": "s3://bucket/img_embeddings/path/",
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,

View File

@@ -0,0 +1,102 @@
{
"decoder": {
"unets": [
{
"dim": 16,
"image_embed_dim": 768,
"cond_dim": 16,
"channels": 3,
"dim_mults": [1, 2, 4, 8],
"attn_dim_head": 16,
"attn_heads": 4,
"self_attn": [false, true, true, true]
}
],
"clip": {
"make": "openai",
"model": "ViT-L/14"
},
"timesteps": 10,
"image_sizes": [64],
"channels": 3,
"loss_type": "l2",
"beta_schedule": ["cosine"],
"learned_variance": true
},
"data": {
"webdataset_base_url": "test_data/{}.tar",
"num_workers": 4,
"batch_size": 4,
"start_shard": 0,
"end_shard": 9,
"shard_width": 1,
"index_width": 1,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": false,
"resample_train": true,
"preprocessing": {
"RandomResizedCrop": {
"size": [224, 224],
"scale": [0.75, 1.0],
"ratio": [1.0, 1.0]
},
"ToTensor": true
}
},
"train": {
"epochs": 1,
"lr": 1e-16,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100,
"n_sample_images": 1,
"device": "cpu",
"epoch_samples": 50,
"validation_samples": 5,
"use_ema": true,
"ema_beta": 0.99,
"amp": false,
"save_all": false,
"save_latest": true,
"save_best": true,
"unet_training_mask": [true]
},
"evaluate": {
"n_evaluation_samples": 2,
"FID": {
"feature": 64
},
"IS": {
"feature": 64,
"splits": 10
},
"KID": {
"feature": 64,
"subset_size": 2
},
"LPIPS": {
"net_type": "vgg",
"reduction": "mean"
}
},
"tracker": {
"overwrite_data_path": true,
"log": {
"log_type": "console"
},
"load": {
"load_from": null
},
"save": [{
"save_to": "local"
}]
}
}

View File

@@ -125,14 +125,23 @@ def log(t, eps = 1e-12):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(image, target_image_size, clamp_range = None):
def resize_image_to(
image,
target_image_size,
clamp_range = None,
nearest = False,
**kwargs
):
orig_image_size = image.shape[-1]
if orig_image_size == target_image_size:
return image
scale_factors = target_image_size / orig_image_size
out = resize(image, scale_factors = scale_factors)
if not nearest:
scale_factors = target_image_size / orig_image_size
out = resize(image, scale_factors = scale_factors, **kwargs)
else:
out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False)
if exists(clamp_range):
out = out.clamp(*clamp_range)
@@ -160,6 +169,11 @@ class BaseClipAdapter(nn.Module):
self.clip = clip
self.overrides = kwargs
def validate_and_resize_image(self, image):
image_size = image.shape[-1]
assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
return resize_image_to(image, self.image_size)
@property
def dim_latent(self):
raise NotImplementedError
@@ -210,7 +224,7 @@ class XClipAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
encoder_output = self.clip.visual_transformer(image)
image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -245,7 +259,7 @@ class CoCaAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
image_embed, image_encodings = self.clip.embed_image(image)
return EmbeddedImage(image_embed, image_encodings)
@@ -306,7 +320,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
@torch.no_grad()
def embed_image(self, image):
assert not self.cleared
image = resize_image_to(image, self.image_size)
image = self.validate_and_resize_image(image)
image = self.clip_normalize(image)
image_embed = self.clip.encode_image(image)
return EmbeddedImage(l2norm(image_embed.float()), None)
@@ -1781,12 +1795,15 @@ class LowresConditioner(nn.Module):
def __init__(
self,
downsample_first = True,
downsample_mode_nearest = False,
blur_sigma = 0.6,
blur_kernel_size = 3,
input_image_range = None
):
super().__init__()
self.downsample_first = downsample_first
self.downsample_mode_nearest = downsample_mode_nearest
self.input_image_range = input_image_range
self.blur_sigma = blur_sigma
@@ -1802,7 +1819,7 @@ class LowresConditioner(nn.Module):
blur_kernel_size = None
):
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest)
if self.training:
# when training, blur the low resolution conditional image
@@ -1845,6 +1862,7 @@ class Decoder(nn.Module):
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution
blur_sigma = 0.6, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
clip_denoised = True,
@@ -1987,6 +2005,7 @@ class Decoder(nn.Module):
self.to_lowres_cond = LowresConditioner(
downsample_first = lowres_downsample_first,
downsample_mode_nearest = lowres_downsample_mode_nearest,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,
input_image_range = self.input_image_range

View File

@@ -1,6 +1,7 @@
import os
import webdataset as wds
import torch
from torch.utils.data import DataLoader
import numpy as np
import fsspec
import shutil
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
)
if shuffle_num is not None and shuffle_num > 0:
ds.shuffle(1000)
return wds.WebLoader(
return DataLoader(
ds,
num_workers=num_workers,
batch_size=batch_size,

View File

@@ -1,5 +1,6 @@
import urllib.request
import os
import json
from pathlib import Path
import shutil
from itertools import zip_longest
@@ -37,14 +38,17 @@ class BaseLogger:
data_path (str): A file path for storing temporary data.
verbose (bool): Whether of not to always print logs to the console.
"""
def __init__(self, data_path: str, verbose: bool = False, **kwargs):
def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
self.data_path = Path(data_path)
self.resume = resume
self.auto_resume = auto_resume
self.verbose = verbose
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
"""
Initializes the logger.
Errors if the logger is invalid.
full_config is the config file dict while extra_config is anything else from the script that is not defined the config file.
"""
raise NotImplementedError
@@ -60,6 +64,14 @@ class BaseLogger:
def log_error(self, error_string, **kwargs) -> None:
raise NotImplementedError
def get_resume_data(self, **kwargs) -> dict:
"""
Sets tracker attributes that along with { "resume": True } will be used to resume training.
It is assumed that after init is called this data will be complete.
If the logger does not have any resume functionality, it should return an empty dict.
"""
raise NotImplementedError
class ConsoleLogger(BaseLogger):
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
print("Logging to console")
@@ -76,6 +88,9 @@ class ConsoleLogger(BaseLogger):
def log_error(self, error_string, **kwargs) -> None:
print(error_string)
def get_resume_data(self, **kwargs) -> dict:
return {}
class WandbLogger(BaseLogger):
"""
Logs to a wandb run.
@@ -85,7 +100,6 @@ class WandbLogger(BaseLogger):
wandb_project (str): The wandb project to log to.
wandb_run_id (str): The wandb run id to resume.
wandb_run_name (str): The wandb run name to use.
wandb_resume (bool): Whether to resume a wandb run.
"""
def __init__(self,
data_path: str,
@@ -93,7 +107,6 @@ class WandbLogger(BaseLogger):
wandb_project: str,
wandb_run_id: Optional[str] = None,
wandb_run_name: Optional[str] = None,
wandb_resume: bool = False,
**kwargs
):
super().__init__(data_path, **kwargs)
@@ -101,7 +114,6 @@ class WandbLogger(BaseLogger):
self.project = wandb_project
self.run_id = wandb_run_id
self.run_name = wandb_run_name
self.resume = wandb_resume
def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
assert self.entity is not None, "wandb_entity must be specified for wandb logger"
@@ -149,6 +161,14 @@ class WandbLogger(BaseLogger):
print(error_string)
self.wandb.log({"error": error_string, **kwargs}, step=step)
def get_resume_data(self, **kwargs) -> dict:
# In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
return {
"entity": self.entity,
"project": self.project,
"run_id": self.wandb.run.id
}
logger_type_map = {
'console': ConsoleLogger,
'wandb': WandbLogger,
@@ -168,8 +188,9 @@ class BaseLoader:
Parameters:
data_path (str): A file path for storing temporary data.
"""
def __init__(self, data_path: str, **kwargs):
def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
self.data_path = Path(data_path)
self.only_auto_resume = only_auto_resume
def init(self, logger: BaseLogger, **kwargs) -> None:
raise NotImplementedError
@@ -304,6 +325,10 @@ class LocalSaver(BaseSaver):
def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
# Copy the file to save_path
save_path_file_name = Path(save_path).name
# Make sure parent directory exists
save_path_parent = Path(save_path).parent
if not save_path_parent.exists():
save_path_parent.mkdir(parents=True)
print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
shutil.copy(local_path, save_path)
@@ -385,11 +410,7 @@ class Tracker:
def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
self.data_path = Path(data_path)
if not dummy_mode:
if overwrite_data_path:
if self.data_path.exists():
shutil.rmtree(self.data_path)
self.data_path.mkdir(parents=True)
else:
if not overwrite_data_path:
assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
if not self.data_path.exists():
self.data_path.mkdir(parents=True)
@@ -398,7 +419,46 @@ class Tracker:
self.savers: List[BaseSaver]= []
self.dummy_mode = dummy_mode
def _load_auto_resume(self) -> bool:
# If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
if not self.auto_resume_path.exists():
if self.logger.auto_resume:
print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
return False
# Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
if not self.logger.auto_resume:
print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
self.auto_resume_path.unlink()
return False
# Otherwise we read the json into a dictionary will will override parts of logger.__dict__
with open(self.auto_resume_path, 'r') as f:
auto_resume_dict = json.load(f)
# Check if the logger is of the same type as the autoresume save
if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
# Then we are ready to override the logger with the autoresume save
self.logger.__dict__["resume"] = True
print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
self.logger.__dict__.update(auto_resume_dict)
return True
def _save_auto_resume(self):
# Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
auto_resume_dict = self.logger.get_resume_data()
auto_resume_dict['logger_type'] = self.logger.__class__.__name__
with open(self.auto_resume_path, 'w') as f:
json.dump(auto_resume_dict, f)
def init(self, full_config: BaseModel, extra_config: dict):
self.auto_resume_path = self.data_path / 'auto_resume.json'
# Check for resuming the run
self.did_auto_resume = self._load_auto_resume()
if self.did_auto_resume:
print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
print(f"New logger config: {self.logger.__dict__}")
assert self.logger is not None, '`logger` must be set before `init` is called'
if self.dummy_mode:
# The only thing we need is a loader
@@ -406,12 +466,17 @@ class Tracker:
self.loader.init(self.logger)
return
assert len(self.savers) > 0, '`savers` must be set before `init` is called'
self.logger.init(full_config, extra_config)
if self.loader is not None:
self.loader.init(self.logger)
for saver in self.savers:
saver.init(self.logger)
if self.logger.auto_resume:
# Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
self._save_auto_resume()
def add_logger(self, logger: BaseLogger):
self.logger = logger
@@ -503,11 +568,16 @@ class Tracker:
self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
print(f'Error saving checkpoint: {e}')
@property
def can_recall(self):
# Defines whether a recall can be performed.
return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
def recall(self):
if self.loader is not None:
if self.can_recall:
return self.loader.recall()
else:
raise ValueError('No loader specified')
raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')

View File

@@ -47,6 +47,8 @@ class TrainSplitConfig(BaseModel):
class TrackerLogConfig(BaseModel):
log_type: str = 'console'
resume: bool = False # For logs that are saved to unique locations, resume a previous run
auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
verbose: bool = False
class Config:
@@ -59,6 +61,7 @@ class TrackerLogConfig(BaseModel):
class TrackerLoadConfig(BaseModel):
load_from: Optional[str] = None
only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
class Config:
extra = "allow"

View File

@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
from ema_pytorch import EMA
from accelerate import Accelerator
from accelerate import Accelerator, DistributedType
import numpy as np
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
def inner(model, *args, **kwargs):
device = kwargs.pop('_device', next(model.parameters()).device)
cast_device = kwargs.pop('_cast_device', True)
cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
kwargs_keys = kwargs.keys()
all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
if cast_device:
all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
if cast_deepspeed_precision:
try:
accelerator = model.accelerator
if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
except AttributeError:
# Then this model doesn't have an accelerator
pass
args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
@@ -192,6 +208,7 @@ class DiffusionPriorTrainer(nn.Module):
self.device = diffusion_prior_device
else:
self.device = accelerator.device if exists(accelerator) else device
diffusion_prior.to(self.device)
# save model
@@ -445,6 +462,7 @@ class DecoderTrainer(nn.Module):
self,
decoder,
accelerator = None,
dataloaders = None,
use_ema = True,
lr = 1e-4,
wd = 1e-2,
@@ -507,9 +525,21 @@ class DecoderTrainer(nn.Module):
self.register_buffer('steps', torch.tensor([0] * self.num_unets))
decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
schedulers = list(self.accelerator.prepare(*schedulers))
if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
# Then we need to make sure clip is using the correct precision or else deepspeed will error
cast_type_map = {
"fp16": torch.half,
"bf16": torch.bfloat16,
"no": torch.float
}
precision_type = cast_type_map[accelerator.mixed_precision]
assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
clip = decoder.clip
clip.to(precision_type)
decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
self.train_loader = train_loader
self.val_loader = val_loader
self.decoder = decoder
# store optimizers
@@ -526,6 +556,17 @@ class DecoderTrainer(nn.Module):
self.warmup_schedulers = warmup_schedulers
def validate_and_return_unet_number(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
return unet_number
def num_steps_taken(self, unet_number = None):
unet_number = self.validate_and_return_unet_number(unet_number)
return self.steps[unet_number - 1].item()
def save(self, path, overwrite = True, **kwargs):
path = Path(path)
assert not (path.exists() and not overwrite)
@@ -594,10 +635,7 @@ class DecoderTrainer(nn.Module):
self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
def update(self, unet_number = None):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
assert exists(unet_number) and 1 <= unet_number <= self.num_unets
unet_number = self.validate_and_return_unet_number(unet_number)
index = unet_number - 1
optimizer = getattr(self, f'optim{index}')
@@ -663,11 +701,13 @@ class DecoderTrainer(nn.Module):
max_batch_size = None,
**kwargs
):
if self.num_unets == 1:
unet_number = default(unet_number, 1)
unet_number = self.validate_and_return_unet_number(unet_number)
total_loss = 0.
using_amp = self.accelerator.mixed_precision != 'no'
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
with self.accelerator.autocast():
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)

View File

@@ -1 +1 @@
__version__ = '0.16.16'
__version__ = '0.18.0'

BIN
test_data/0.tar Normal file

Binary file not shown.

BIN
test_data/1.tar Normal file

Binary file not shown.

BIN
test_data/2.tar Normal file

Binary file not shown.

BIN
test_data/3.tar Normal file

Binary file not shown.

BIN
test_data/4.tar Normal file

Binary file not shown.

BIN
test_data/5.tar Normal file

Binary file not shown.

BIN
test_data/6.tar Normal file

Binary file not shown.

BIN
test_data/7.tar Normal file

Binary file not shown.

BIN
test_data/8.tar Normal file

Binary file not shown.

BIN
test_data/9.tar Normal file

Binary file not shown.

View File

@@ -132,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
break
return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
"""
Takes example data and generates images from the embeddings
Returns three lists: real images, generated images, and captions
@@ -160,6 +160,9 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
samples = trainer.sample(**sample_params)
generated_images = list(samples)
captions = [text_prepend + txt for txt in txts]
if match_image_size:
generated_image_size = generated_images[0].shape[-1]
real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
return real_images, generated_images, captions
def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
@@ -167,14 +170,6 @@ def generate_grid_samples(trainer, examples, condition_on_text_encodings=False,
Generates samples and uses torchvision to put them in a side by side grid for easy viewing
"""
real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
real_image_size = real_images[0].shape[-1]
generated_image_size = generated_images[0].shape[-1]
# training images may be larger than the generated one
if real_image_size > generated_image_size:
real_images = [resize_image_to(image, generated_image_size) for image in real_images]
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
@@ -279,6 +274,7 @@ def train(
trainer = DecoderTrainer(
decoder=decoder,
accelerator=accelerator,
dataloaders=dataloaders,
**kwargs
)
@@ -289,9 +285,8 @@ def train(
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.step.item())
if tracker.loader is not None:
if tracker.can_recall:
start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
if next_task == 'train':
sample = recalled_sample
@@ -304,6 +299,8 @@ def train(
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
accelerator.print(print_ribbon("Generating Example Data", repeat=40))
@@ -326,7 +323,7 @@ def train(
last_snapshot = sample
if next_task == 'train':
for i, (img, emb, txt) in enumerate(dataloaders["train"]):
for i, (img, emb, txt) in enumerate(trainer.train_loader):
# We want to count the total number of samples across all processes
sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -419,7 +416,7 @@ def train(
timer = Timer()
accelerator.wait_for_everyone()
i = 0
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
for i, (img, emb, txt) in enumerate(trainer.val_loader): # Use the accelerate prepared loader
val_sample_length_tensor[0] = len(img)
all_samples = accelerator.gather(val_sample_length_tensor)
total_samples = all_samples.sum().item()
@@ -524,6 +521,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
# Set up accelerator for configurable distributed training
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
if accelerator.num_processes > 1:
# We are using distributed training and want to immediately ensure all can connect
accelerator.print("Waiting for all processes to connect...")
accelerator.wait_for_everyone()
accelerator.print("All processes online and connected")
# If we are in deepspeed fp16 mode, we must ensure learned variance is off
if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
raise ValueError("DeepSpeed fp16 mode does not support learned variance")
if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
# This is an invalid configuration until we figure out how to handle this
raise ValueError("DeepSpeed does not support multi-node distributed training")
# Set up data
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))