Compare commits

...

6 Commits

Author | SHA1 | Message | Date
Romain Beaumont | 3a1dea7d97 | Fix decoder test by fixing the resizing output size | 2022-07-09 11:36:22 +02:00
Phil Wang | 097afda606 | 0.18.0 | 2022-07-08 18:18:38 -07:00
Aidan Dempster | 5c520db825 | Added deepspeed support (#195) | 2022-07-08 18:18:08 -07:00
Phil Wang | 3070610231 | just force it so researcher can never pass in an image that is less than the size that is required for CLIP or CoCa | 2022-07-08 18:17:29 -07:00
Aidan Dempster | 870aeeca62 | Fixed issue where evaluation would error when large image was loaded (#194) | 2022-07-08 17:11:34 -07:00
Romain Beaumont | f28dc6dc01 | setup simple ci (#193) | 2022-07-08 16:51:56 -07:00
19 changed files with 212 additions and 19 deletions

33  .github/workflows/ci.yml vendored Normal file

@@ -0,0 +1,33 @@
name: Continuous integration
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install
        run: |
          python3 -m venv .env
          source .env/bin/activate
          make install
      - name: Tests
        run: |
          source .env/bin/activate
          make test

2  .gitignore vendored

@@ -136,3 +136,5 @@ dmypy.json
 # Pyre type checker
 .pyre/
+.tracker_data
+*.pth

6  Makefile Normal file

@@ -0,0 +1,6 @@
install:
	pip install -U pip
	pip install -e .

test:
	CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json

configs/train_decoder_config.test.json Normal file

@@ -0,0 +1,102 @@
{
    "decoder": {
        "unets": [
            {
                "dim": 16,
                "image_embed_dim": 768,
                "cond_dim": 16,
                "channels": 3,
                "dim_mults": [1, 2, 4, 8],
                "attn_dim_head": 16,
                "attn_heads": 4,
                "self_attn": [false, true, true, true]
            }
        ],
        "clip": {
            "make": "openai",
            "model": "ViT-L/14"
        },
        "timesteps": 10,
        "image_sizes": [64],
        "channels": 3,
        "loss_type": "l2",
        "beta_schedule": ["cosine"],
        "learned_variance": true
    },
    "data": {
        "webdataset_base_url": "test_data/{}.tar",
        "num_workers": 4,
        "batch_size": 4,
        "start_shard": 0,
        "end_shard": 9,
        "shard_width": 1,
        "index_width": 1,
        "splits": {
            "train": 0.75,
            "val": 0.15,
            "test": 0.1
        },
        "shuffle_train": false,
        "resample_train": true,
        "preprocessing": {
            "RandomResizedCrop": {
                "size": [224, 224],
                "scale": [0.75, 1.0],
                "ratio": [1.0, 1.0]
            },
            "ToTensor": true
        }
    },
    "train": {
        "epochs": 1,
        "lr": 1e-16,
        "wd": 0.01,
        "max_grad_norm": 0.5,
        "save_every_n_samples": 100,
        "n_sample_images": 1,
        "device": "cpu",
        "epoch_samples": 50,
        "validation_samples": 5,
        "use_ema": true,
        "ema_beta": 0.99,
        "amp": false,
        "save_all": false,
        "save_latest": true,
        "save_best": true,
        "unet_training_mask": [true]
    },
    "evaluate": {
        "n_evaluation_samples": 2,
        "FID": {
            "feature": 64
        },
        "IS": {
            "feature": 64,
            "splits": 10
        },
        "KID": {
            "feature": 64,
            "subset_size": 2
        },
        "LPIPS": {
            "net_type": "vgg",
            "reduction": "mean"
        }
    },
    "tracker": {
        "overwrite_data_path": true,
        "log": {
            "log_type": "console"
        },
        "load": {
            "load_from": null
        },
        "save": [{
            "save_to": "local"
        }]
    }
}
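
The data block of this config is what drives the new `make test` smoke run: `webdataset_base_url`, `shard_width`, and the shard range select the ten `test_data/*.tar` shards added in this compare. Below is a minimal standalone sketch (not the repository's own loader, which may differ in details) of how those three fields expand into shard paths, assuming the usual zero-padded substitution:

# Standalone sketch: expand webdataset_base_url from the test config into shard paths.
# Assumes zero-padded substitution; the real dataloader may differ in details.
def expand_shards(base_url, shard_width, start_shard, end_shard):
    return [base_url.format(str(index).zfill(shard_width)) for index in range(start_shard, end_shard + 1)]

print(expand_shards("test_data/{}.tar", shard_width=1, start_shard=0, end_shard=9))
# ['test_data/0.tar', 'test_data/1.tar', ..., 'test_data/9.tar']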

dalle2_pytorch/dalle2_pytorch.py

@@ -169,6 +169,11 @@ class BaseClipAdapter(nn.Module):
         self.clip = clip
         self.overrides = kwargs
 
+    def validate_and_resize_image(self, image):
+        image_size = image.shape[-1]
+        assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
+        return resize_image_to(image, self.image_size)
+
     @property
     def dim_latent(self):
         raise NotImplementedError
@@ -219,7 +224,7 @@ class XClipAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_image(self, image):
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         encoder_output = self.clip.visual_transformer(image)
         image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
         image_embed = self.clip.to_visual_latent(image_cls)
@@ -254,7 +259,7 @@ class CoCaAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_image(self, image):
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         image_embed, image_encodings = self.clip.embed_image(image)
         return EmbeddedImage(image_embed, image_encodings)
@@ -315,7 +320,7 @@ class OpenAIClipAdapter(BaseClipAdapter):
     @torch.no_grad()
     def embed_image(self, image):
         assert not self.cleared
-        image = resize_image_to(image, self.image_size)
+        image = self.validate_and_resize_image(image)
         image = self.clip_normalize(image)
         image_embed = self.clip.encode_image(image)
         return EmbeddedImage(l2norm(image_embed.float()), None)
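
The new `validate_and_resize_image` guard only ever downsizes: images smaller than the CLIP input resolution now fail loudly instead of being silently upsampled. Here is a self-contained illustration of the same pattern on dummy tensors; `CLIP_IMAGE_SIZE` and the `F.interpolate` call are stand-ins for the adapter's `self.image_size` and `resize_image_to`, not the library's own code:

# Illustration only: CLIP_IMAGE_SIZE and the interpolate-based resize are stand-ins
# for the adapter's self.image_size and resize_image_to helpers.
import torch
import torch.nn.functional as F

CLIP_IMAGE_SIZE = 224  # e.g. ViT-L/14 expects 224x224 inputs

def validate_and_resize_image(image, target_size=CLIP_IMAGE_SIZE):
    image_size = image.shape[-1]
    assert image_size >= target_size, \
        f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {target_size}'
    return F.interpolate(image, size=target_size, mode='nearest')

print(validate_and_resize_image(torch.rand(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 224, 224])

try:
    validate_and_resize_image(torch.rand(1, 3, 64, 64))  # smaller than 224, so the guard rejects it
except AssertionError as err:
    print(err)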

dalle2_pytorch/dataloaders/decoder_loader.py

@@ -1,6 +1,7 @@
 import os
 import webdataset as wds
 import torch
+from torch.utils.data import DataLoader
 import numpy as np
 import fsspec
 import shutil
@@ -255,7 +256,7 @@ def create_image_embedding_dataloader(
     )
     if shuffle_num is not None and shuffle_num > 0:
         ds.shuffle(1000)
-    return wds.WebLoader(
+    return DataLoader(
         ds,
         num_workers=num_workers,
         batch_size=batch_size,
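
The return type changes from `wds.WebLoader` to the standard `torch.utils.data.DataLoader` wrapped around the same webdataset pipeline. A toy sketch of that wrapping pattern with a plain `IterableDataset` follows; `DummyEmbeddingStream` is made up for illustration and is not the repository's image/embedding dataset:

# Toy stand-in for the webdataset pipeline: DummyEmbeddingStream is invented for illustration.
import torch
from torch.utils.data import DataLoader, IterableDataset

class DummyEmbeddingStream(IterableDataset):
    def __init__(self, n_items=16, image_size=64, embed_dim=768):
        self.n_items, self.image_size, self.embed_dim = n_items, image_size, embed_dim

    def __iter__(self):
        for _ in range(self.n_items):
            yield torch.rand(3, self.image_size, self.image_size), torch.rand(self.embed_dim)

# DataLoader batches an iterable source just like a map-style dataset.
loader = DataLoader(DummyEmbeddingStream(), batch_size=4, num_workers=0)
for images, embeddings in loader:
    print(images.shape, embeddings.shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4, 768])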

dalle2_pytorch/trainer.py

@@ -21,7 +21,7 @@ import pytorch_warmup as warmup
 from ema_pytorch import EMA
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 import numpy as np
@@ -76,6 +76,7 @@ def cast_torch_tensor(fn):
     def inner(model, *args, **kwargs):
         device = kwargs.pop('_device', next(model.parameters()).device)
         cast_device = kwargs.pop('_cast_device', True)
+        cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
 
         kwargs_keys = kwargs.keys()
         all_args = (*args, *kwargs.values())
@@ -85,6 +86,21 @@ def cast_torch_tensor(fn):
         if cast_device:
             all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
 
+        if cast_deepspeed_precision:
+            try:
+                accelerator = model.accelerator
+                if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    cast_type_map = {
+                        "fp16": torch.half,
+                        "bf16": torch.bfloat16,
+                        "no": torch.float
+                    }
+                    precision_type = cast_type_map[accelerator.mixed_precision]
+                    all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+            except AttributeError:
+                # Then this model doesn't have an accelerator
+                pass
+
         args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
         kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
@@ -446,6 +462,7 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         accelerator = None,
+        dataloaders = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
@@ -508,8 +525,21 @@ class DecoderTrainer(nn.Module):
         self.register_buffer('steps', torch.tensor([0] * self.num_unets))
 
-        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
+            # Then we need to make sure clip is using the correct precision or else deepspeed will error
+            cast_type_map = {
+                "fp16": torch.half,
+                "bf16": torch.bfloat16,
+                "no": torch.float
+            }
+            precision_type = cast_type_map[accelerator.mixed_precision]
+            assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+            clip = decoder.clip
+            clip.to(precision_type)
+
+        decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
+
+        self.train_loader = train_loader
+        self.val_loader = val_loader
         self.decoder = decoder
 
         # store optimizers
@@ -675,6 +705,9 @@ class DecoderTrainer(nn.Module):
         total_loss = 0.
 
+        using_amp = self.accelerator.mixed_precision != 'no'
+
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
             with self.accelerator.autocast():
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
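
The DeepSpeed branches above share one idea: read `accelerator.mixed_precision` ("fp16", "bf16", or "no"), map it to a torch dtype, and cast tensor arguments (or CLIP's weights) to that dtype before DeepSpeed sees them. A minimal standalone sketch of that mapping, with a plain string standing in for `accelerator.mixed_precision` so it runs without Accelerate or DeepSpeed:

# Standalone sketch of the precision cast; the mixed_precision string stands in
# for accelerator.mixed_precision and no DeepSpeed engine is involved.
import torch

CAST_TYPE_MAP = {"fp16": torch.half, "bf16": torch.bfloat16, "no": torch.float}

def cast_args_to_precision(args, mixed_precision):
    precision_type = CAST_TYPE_MAP[mixed_precision]
    return tuple(a.to(precision_type) if isinstance(a, torch.Tensor) else a for a in args)

args = (torch.rand(2, 3), "a caption", torch.rand(4))
casted = cast_args_to_precision(args, "fp16")
print([a.dtype if isinstance(a, torch.Tensor) else a for a in casted])
# [torch.float16, 'a caption', torch.float16]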

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '0.17.0'
+__version__ = '0.18.0'

BIN  test_data/0.tar  Normal file  (Binary file not shown)
BIN  test_data/1.tar  Normal file  (Binary file not shown)
BIN  test_data/2.tar  Normal file  (Binary file not shown)
BIN  test_data/3.tar  Normal file  (Binary file not shown)
BIN  test_data/4.tar  Normal file  (Binary file not shown)
BIN  test_data/5.tar  Normal file  (Binary file not shown)
BIN  test_data/6.tar  Normal file  (Binary file not shown)
BIN  test_data/7.tar  Normal file  (Binary file not shown)
BIN  test_data/8.tar  Normal file  (Binary file not shown)
BIN  test_data/9.tar  Normal file  (Binary file not shown)

train_decoder.py

@@ -132,7 +132,7 @@ def get_example_data(dataloader, device, n=5):
             break
     return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
 
-def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend=""):
+def generate_samples(trainer, example_data, condition_on_text_encodings=False, text_prepend="", match_image_size=True):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
@@ -160,6 +160,9 @@ def generate_samples(trainer, example_data, condition_on_text_encodings=False, t
     samples = trainer.sample(**sample_params)
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
+    if match_image_size:
+        generated_image_size = generated_images[0].shape[-1]
+        real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
     return real_images, generated_images, captions
 
 def generate_grid_samples(trainer, examples, condition_on_text_encodings=False, text_prepend=""):
@@ -167,14 +170,6 @@ def generate_grid_samples(trainer, examples, condition_on_text_encodings=False,
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
     real_images, generated_images, captions = generate_samples(trainer, examples, condition_on_text_encodings, text_prepend)
-    real_image_size = real_images[0].shape[-1]
-    generated_image_size = generated_images[0].shape[-1]
-
-    # training images may be larger than the generated one
-    if real_image_size > generated_image_size:
-        real_images = [resize_image_to(image, generated_image_size) for image in real_images]
-
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
@@ -279,6 +274,7 @@ def train(
     trainer = DecoderTrainer(
         decoder=decoder,
         accelerator=accelerator,
+        dataloaders=dataloaders,
         **kwargs
     )
@@ -289,7 +285,6 @@ def train(
     sample = 0
     samples_seen = 0
     val_sample = 0
-    step = lambda: int(trainer.num_steps_taken(unet_number=1))
 
     if tracker.can_recall:
         start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
@@ -304,6 +299,8 @@ def train(
     if not exists(unet_training_mask):
         # Then the unet mask should be true for all unets in the decoder
         unet_training_mask = [True] * trainer.num_unets
+    first_training_unet = min(index for index, mask in enumerate(unet_training_mask) if mask)
+    step = lambda: int(trainer.num_steps_taken(unet_number=first_training_unet+1))
     assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
 
     accelerator.print(print_ribbon("Generating Example Data", repeat=40))
@@ -326,7 +323,7 @@ def train(
         last_snapshot = sample
 
         if next_task == 'train':
-            for i, (img, emb, txt) in enumerate(dataloaders["train"]):
+            for i, (img, emb, txt) in enumerate(trainer.train_loader):
                 # We want to count the total number of samples across all processes
                 sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
@@ -419,7 +416,7 @@ def train(
             timer = Timer()
             accelerator.wait_for_everyone()
             i = 0
-            for i, (img, emb, txt) in enumerate(dataloaders["val"]):
+            for i, (img, emb, txt) in enumerate(trainer.val_loader):  # Use the accelerate prepared loader
                 val_sample_length_tensor[0] = len(img)
                 all_samples = accelerator.gather(val_sample_length_tensor)
                 total_samples = all_samples.sum().item()
@@ -525,6 +522,20 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
     accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
 
+    if accelerator.num_processes > 1:
+        # We are using distributed training and want to immediately ensure all can connect
+        accelerator.print("Waiting for all processes to connect...")
+        accelerator.wait_for_everyone()
+        accelerator.print("All processes online and connected")
+
+    # If we are in deepspeed fp16 mode, we must ensure learned variance is off
+    if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
+        raise ValueError("DeepSpeed fp16 mode does not support learned variance")
+
+    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
+        # This is an invalid configuration until we figure out how to handle this
+        raise ValueError("DeepSpeed does not support multi-node distributed training")
+
     # Set up data
     all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
     world_size = accelerator.num_processes
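
Taken together with the first commit in the list ("Fix decoder test by fixing the resizing output size"), this change moves the real-vs-generated size matching into `generate_samples` and clamps the resized real images to [0, 1]. Below is a self-contained sketch of that resize-and-clamp step on dummy tensors; `resize_and_clamp` approximates `resize_image_to(..., clamp_range=(0, 1))` with `F.interpolate` and may not match the library helper exactly:

# Approximation of resize_image_to(..., clamp_range=(0, 1)) for illustration only.
import torch
import torch.nn.functional as F

def resize_and_clamp(image, target_size, clamp_range=(0, 1)):
    resized = F.interpolate(image.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0)
    return resized.clamp(*clamp_range)  # bilinear resampling can overshoot the valid range slightly

generated_images = [torch.rand(3, 64, 64)]   # decoder output at 64x64, per image_sizes in the test config
real_images = [torch.rand(3, 224, 224)]      # dataloader images after RandomResizedCrop to 224

generated_image_size = generated_images[0].shape[-1]
real_images = [resize_and_clamp(img, generated_image_size) for img in real_images]
print(real_images[0].shape)  # torch.Size([3, 64, 64])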