Compare commits


9 Commits
0.3.8 ... 0.4.2

7 changed files with 177 additions and 121 deletions

View File

@@ -1076,6 +1076,7 @@ This library would not have gotten to this working state without the help of
 - [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
 - [x] cross embed layers for downsampling, as an option
 - [x] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
+- [x] use pydantic for config drive training
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab

View File

@@ -4,7 +4,7 @@ For more complex configuration, we provide the option of using a configuration f
 ### Decoder Trainer
-The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.json.example).
+The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
 **<ins>Unets</ins>:**
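For reference, the configuration file discussed above is organized into seven top-level sections: unets, decoder, data, train, evaluate, tracker, and load. Below is a minimal sketch of that layout written as a Python dict; all values are illustrative placeholders, and the full set of options appears in the example configuration and in dalle2_pytorch/train_configs.py further down this compare.

# Minimal sketch of the decoder trainer config layout; values are illustrative placeholders only.
config = {
    "unets": [{"dim": 16, "dim_mults": [1, 2, 4], "image_embed_dim": 512}],
    "decoder": {"image_sizes": [64], "timesteps": 1000, "loss_type": "l2"},
    "data": {
        "webdataset_base_url": "path/to/webdataset/shards",  # placeholder
        "embeddings_url": "path/to/embeddings",              # placeholder
        "splits": {"train": 0.75, "val": 0.15, "test": 0.1},
    },
    "train": {"epochs": 20, "lr": 1e-4, "use_ema": True},
    "evaluate": {"n_evaluation_samples": 1000, "FID": {"feature": 64}},
    "tracker": {"tracker_type": "console", "data_path": "./models"},
    "load": {"source": None},
}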

View File

@@ -1,82 +0,0 @@
"""
Defines the default values for the decoder config
"""
from enum import Enum
class ConfigField(Enum):
REQUIRED = 0 # This had more options. It's a bit unnecessary now, but I can't think of a better way to do it.
default_config = {
"unets": ConfigField.REQUIRED,
"decoder": {
"image_sizes": ConfigField.REQUIRED, # The side lengths of the upsampled image at the end of each unet
"image_size": ConfigField.REQUIRED, # Usually the same as image_sizes[-1] I think
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"learned_variance": True
},
"data": {
"webdataset_base_url": ConfigField.REQUIRED, # Path to a webdataset with jpg images
"embeddings_url": ConfigField.REQUIRED, # Path to .npy files with embeddings
"num_workers": 4,
"batch_size": 64,
"start_shard": 0,
"end_shard": 9999999,
"shard_width": 6,
"index_width": 4,
"splits": {
"train": 0.75,
"val": 0.15,
"test": 0.1
},
"shuffle_train": True,
"resample_train": False,
"preprocessing": {
"ToTensor": True
}
},
"train": {
"epochs": 20,
"lr": 1e-4,
"wd": 0.01,
"max_grad_norm": 0.5,
"save_every_n_samples": 100000,
"n_sample_images": 6, # The number of example images to produce when sampling the train and test dataset
"device": "cuda:0",
"epoch_samples": None, # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
"validation_samples": None, # Same as above but for validation.
"use_ema": True,
"ema_beta": 0.99,
"amp": False,
"save_all": False, # Whether to preserve all checkpoints
"save_latest": True, # Whether to always save the latest checkpoint
"save_best": True, # Whether to save the best checkpoint
"unet_training_mask": None # If None, use all unets
},
"evaluate": {
"n_evalation_samples": 1000,
"FID": None,
"IS": None,
"KID": None,
"LPIPS": None
},
"tracker": {
"tracker_type": "console", # Decoder currently supports console and wandb
"data_path": "./models", # The path where files will be saved locally
"wandb_entity": "", # Only needs to be set if tracker_type is wandb
"wandb_project": "",
"verbose": False # Whether to print console logging for non-console trackers
},
"load": {
"source": None, # Supports file and wandb
"run_path": "", # Used only if source is wandb
"file_path": "", # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
"resume": False # If using wandb, whether to resume the run
}
}
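The ConfigField.REQUIRED sentinel in the file deleted above is what the pydantic-based config added below (dalle2_pytorch/train_configs.py) replaces: a required option becomes a pydantic field with no default, and validation fails when it is missing. A small hypothetical illustration of that pattern, not part of this diff:

from pydantic import BaseModel, ValidationError

class DataConfigSketch(BaseModel):   # hypothetical model, for illustration only
    webdataset_base_url: str         # no default, so the field is required (the old ConfigField.REQUIRED)
    batch_size: int = 64             # has a default, so the field is optional

try:
    DataConfigSketch(batch_size=32)  # the required field is missing
except ValidationError as err:
    print(err)                       # reports that webdataset_base_url is required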

View File

@@ -12,7 +12,6 @@
     ],
     "decoder": {
         "image_sizes": [64],
-        "image_size": [64],
         "channels": 3,
         "timesteps": 1000,
         "loss_type": "l2",
@@ -63,7 +62,7 @@
"unet_training_mask": [true] "unet_training_mask": [true]
}, },
"evaluate": { "evaluate": {
"n_evalation_samples": 1000, "n_evaluation_samples": 1000,
"FID": { "FID": {
"feature": 64 "feature": 64
}, },

View File

@@ -0,0 +1,135 @@
import json
from torchvision import transforms as T
from pydantic import BaseModel, validator, root_validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
class UnetConfig(BaseModel):
dim: int
dim_mults: List[int]
image_embed_dim: int = None
cond_dim: int = None
channels: int = 3
attn_dim_head: int = 32
attn_heads: int = 16
class Config:
extra = "allow"
class DecoderConfig(BaseModel):
image_size: int = None
image_sizes: Union[List[int], Tuple[int]] = None
channels: int = 3
timesteps: int = 1000
loss_type: str = 'l2'
beta_schedule: str = 'cosine'
learned_variance: bool = True
@validator('image_sizes')
def check_image_sizes(cls, image_sizes, values):
if exists(values.get('image_size')) ^ exists(image_sizes):
return image_sizes
raise ValueError('either image_size or image_sizes is required, but not both')
class Config:
extra = "allow"
class TrainSplitConfig(BaseModel):
train: float = 0.75
val: float = 0.15
test: float = 0.1
@root_validator
def validate_all(cls, fields):
if sum([*fields.values()]) != 1.:
raise ValueError(f'{fields.keys()} must sum to 1.0')
return fields
class DecoderDataConfig(BaseModel):
webdataset_base_url: str # path to a webdataset with jpg images
embeddings_url: str # path to .npy files with embeddings
num_workers: int = 4
batch_size: int = 64
start_shard: int = 0
end_shard: int = 9999999
shard_width: int = 6
index_width: int = 4
splits: TrainSplitConfig
shuffle_train: bool = True
resample_train: bool = False
preprocessing: Dict[str, Any] = {'ToTensor': True}
class DecoderTrainConfig(BaseModel):
epochs: int = 20
lr: float = 1e-4
wd: float = 0.01
max_grad_norm: float = 0.5
save_every_n_samples: int = 100000
n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
device: str = 'cuda:0'
epoch_samples: int = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
validation_samples: int = None # Same as above but for validation.
use_ema: bool = True
ema_beta: float = 0.99
amp: bool = False
save_all: bool = False # Whether to preserve all checkpoints
save_latest: bool = True # Whether to always save the latest checkpoint
save_best: bool = True # Whether to save the best checkpoint
unet_training_mask: List[bool] = None # If None, use all unets
class DecoderEvaluateConfig(BaseModel):
n_evaluation_samples: int = 1000
FID: Dict[str, Any] = None
IS: Dict[str, Any] = None
KID: Dict[str, Any] = None
LPIPS: Dict[str, Any] = None
class TrackerConfig(BaseModel):
tracker_type: str = 'console' # Decoder currently supports console and wandb
data_path: str = './models' # The path where files will be saved locally
init_config: Dict[str, Any] = None
wandb_entity: str = '' # Only needs to be set if tracker_type is wandb
wandb_project: str = ''
verbose: bool = False # Whether to print console logging for non-console trackers
class DecoderLoadConfig(BaseModel):
source: str = None # Supports file and wandb
run_path: str = '' # Used only if source is wandb
file_path: str = '' # The local filepath if source is file. If source is wandb, the relative path to the model file in wandb.
resume: bool = False # If using wandb, whether to resume the run
class TrainDecoderConfig(BaseModel):
unets: List[UnetConfig]
decoder: DecoderConfig
data: DecoderDataConfig
train: DecoderTrainConfig
evaluate: DecoderEvaluateConfig
tracker: TrackerConfig
load: DecoderLoadConfig
@classmethod
def from_json_path(cls, json_path):
with open(json_path) as f:
config = json.load(f)
return cls(**config)
@property
def img_preproc(self):
def _get_transformation(transformation_name, **kwargs):
if transformation_name == "RandomResizedCrop":
return T.RandomResizedCrop(**kwargs)
elif transformation_name == "RandomHorizontalFlip":
return T.RandomHorizontalFlip()
elif transformation_name == "ToTensor":
return T.ToTensor()
transforms = []
for transform_name, transform_kwargs_or_bool in self.data.preprocessing.items():
transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
transforms.append(_get_transformation(transform_name, **transform_kwargs))
return T.Compose(transforms)
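A short usage sketch of the new config module, assuming a JSON file shaped like the example configuration (the path below is just the default used by the training CLI):

from dalle2_pytorch.train_configs import TrainDecoderConfig

# Load and validate the JSON config in one step (replaces the old json.load plus dict access).
config = TrainDecoderConfig.from_json_path("./train_decoder_config.json")

# Each section is now an attribute-style pydantic model rather than a raw dict.
print(config.train.device, config.decoder.timesteps)

# The "preprocessing" mapping is converted into a torchvision transform pipeline.
preproc = config.img_preproc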

View File

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.3.8',
+  version = '0.4.2',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -32,6 +32,7 @@ setup(
     'kornia>=0.5.4',
     'numpy',
     'pillow',
+    'pydantic',
     'resize-right>=0.0.2',
     'rotary-embedding-torch',
     'torch>=1.10',

View File

@@ -5,7 +5,6 @@ from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
 from dalle2_pytorch.train_configs import TrainDecoderConfig
 from dalle2_pytorch.utils import Timer
-import json
 import torchvision
 import torch
 from torchmetrics.image.fid import FrechetInceptionDistance
@@ -90,11 +89,11 @@ def create_dataloaders(
 def create_decoder(device, decoder_config, unets_config):
     """Creates a sample decoder"""
-    unets = [Unet(**config) for config in unets_config]
+    unets = [Unet(**config.dict()) for config in unets_config]
     decoder = Decoder(
         unet=unets,
-        **decoder_config
+        **decoder_config.dict()
     )
     decoder.to(device=device)
@@ -154,13 +153,13 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
-def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
     metrics = {}
     # Prepare the data
-    examples = get_example_data(dataloader, device, n_evalation_samples)
+    examples = get_example_data(dataloader, device, n_evaluation_samples)
     real_images, generated_images, captions = generate_samples(trainer, examples)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
@@ -252,8 +251,8 @@ def train(
     start_epoch = 0
     validation_losses = []
-    if exists(load_config) and exists(load_config["source"]):
-        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
+    if exists(load_config) and exists(load_config.source):
+        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
     trainer.to(device=inference_device)
     if not exists(unet_training_mask):
@@ -271,7 +270,6 @@ def train(
     for epoch in range(start_epoch, epochs):
         print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
-        trainer.train()
         timer = Timer()
@@ -280,11 +278,13 @@ def train(
         last_snapshot = 0
         losses = []
         for i, (img, emb) in enumerate(dataloaders["train"]):
             step += 1
             sample += img.shape[0]
             img, emb = send_to_device((img, emb))
+            trainer.train()
             for unet in range(1, trainer.num_unets+1):
                 # Check if this is a unet we are training
                 if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
@@ -299,7 +299,7 @@ def train(
             timer.reset()
             last_sample = sample
-            if i % CALC_LOSS_EVERY_ITERS == 0:
+            if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
                 average_loss = sum(losses) / len(losses)
                 log_data = {
                     "Training loss": average_loss,
@@ -319,11 +319,12 @@ def train(
save_paths.append("latest.pth") save_paths.append("latest.pth")
if save_all: if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth") save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths) save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if exists(n_sample_images) and n_sample_images > 0: if exists(n_sample_images) and n_sample_images > 0:
trainer.eval() trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
trainer.train()
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step) tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if exists(epoch_samples) and sample >= epoch_samples: if exists(epoch_samples) and sample >= epoch_samples:
@@ -358,7 +359,6 @@ def train(
             tracker.log(log_data, step=step, verbose=True)
         # Compute evaluation metrics
-        trainer.eval()
         if exists(evaluate_config):
             print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
             evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
@@ -385,21 +385,25 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
""" """
Creates a tracker of the specified type and initializes special features based on the full config Creates a tracker of the specified type and initializes special features based on the full config
""" """
tracker_config = config["tracker"] tracker_config = config.tracker
init_config = {} init_config = {}
init_config["config"] = config.config
if exists(tracker_config.init_config):
init_config["config"] = tracker_config.init_config
if tracker_type == "console": if tracker_type == "console":
tracker = ConsoleTracker(**init_config) tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb": elif tracker_type == "wandb":
# We need to initialize the resume state here # We need to initialize the resume state here
load_config = config["load"] load_config = config.load
if load_config["source"] == "wandb" and load_config["resume"]: if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"] # Then we are resuming the run load_config["run_path"]
run_id = config["resume"]["wandb_run_path"].split("/")[-1] run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id init_config["id"] = run_id
init_config["resume"] = "must" init_config["resume"] = "must"
init_config["entity"] = tracker_config["wandb_entity"]
init_config["project"] = tracker_config["wandb_project"] init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path) tracker = WandbTracker(data_path)
tracker.init(**init_config) tracker.init(**init_config)
else: else:
@@ -408,35 +412,35 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
 def initialize_training(config):
     # Create the save path
-    if "cuda" in config["train"]["device"]:
+    if "cuda" in config.train.device:
         assert torch.cuda.is_available(), "CUDA is not available"
-    device = torch.device(config["train"]["device"])
+    device = torch.device(config.train.device)
     torch.cuda.set_device(device)
-    all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
+    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
     dataloaders = create_dataloaders (
         available_shards=all_shards,
-        img_preproc = config.get_preprocessing(),
-        train_prop = config["data"]["splits"]["train"],
-        val_prop = config["data"]["splits"]["val"],
-        test_prop = config["data"]["splits"]["test"],
-        n_sample_images=config["train"]["n_sample_images"],
-        **config["data"]
+        img_preproc = config.img_preproc,
+        train_prop = config.data.splits.train,
+        val_prop = config.data.splits.val,
+        test_prop = config.data.splits.test,
+        n_sample_images=config.train.n_sample_images,
+        **config.data.dict()
     )
-    decoder = create_decoder(device, config["decoder"], config["unets"])
+    decoder = create_decoder(device, config.decoder, config.unets)
     num_parameters = sum(p.numel() for p in decoder.parameters())
     print(print_ribbon("Loaded Config", repeat=40))
     print(f"Number of parameters: {num_parameters}")
-    tracker = create_tracker(config, **config["tracker"])
+    tracker = create_tracker(config, **config.tracker.dict())
     train(dataloaders, decoder,
         tracker=tracker,
         inference_device=device,
-        load_config=config["load"],
-        evaluate_config=config["evaluate"],
-        **config["train"],
+        load_config=config.load,
+        evaluate_config=config.evaluate,
+        **config.train.dict(),
     )
 # Create a simple click command line interface to load the config and start the training
@@ -444,9 +448,7 @@ def initialize_training(config):
 @click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
 def main(config_file):
     print("Recalling config from {}".format(config_file))
-    with open(config_file) as f:
-        config = json.load(f)
-    config = TrainDecoderConfig(config)
+    config = TrainDecoderConfig.from_json_path(config_file)
     initialize_training(config)