from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from configs.decoder_defaults import default_config, ConfigField
import time
import json
import torchvision
from torchvision import transforms as T
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import webdataset as wds
import click
def create_dataloaders(
    available_shards,
    webdataset_base_url,
    embeddings_url,
    shard_width=6,
    num_workers=4,
    batch_size=32,
    n_sample_images=6,
    shuffle_train=True,
    resample_train=False,
    img_preproc=None,
    index_width=4,
    train_prop=0.75,
    val_prop=0.15,
    test_prop=0.10,
    **kwargs
):
    """
    Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
    """
    # Use a tolerance rather than exact float equality so proportions like 0.75 + 0.15 + 0.10 pass
    assert abs(train_prop + test_prop + val_prop - 1.0) < 1e-6, "train, val, and test proportions must sum to 1"
    num_train = round(train_prop * len(available_shards))
    num_test = round(test_prop * len(available_shards))
    num_val = len(available_shards) - num_train - num_test
    assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
    train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))

    # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
    train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
    test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
    val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]

    create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
        tar_url=tar_urls,
        num_workers=num_workers,
        batch_size=batch_size if not for_sampling else n_sample_images,
        embeddings_url=embeddings_url,
        index_width=index_width,
        shuffle_num=None,
        extra_keys=["txt"] if with_text else [],
        shuffle_shards=shuffle,
        resample_shards=resample,
        img_preproc=img_preproc,
        handler=wds.handlers.warn_and_continue
    )

    train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
    train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
    val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
    test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
    test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
    return {
        "train": train_dataloader,
        "train_sampling": train_sampling_dataloader,
        "val": val_dataloader,
        "test": test_dataloader,
        "test_sampling": test_sampling_dataloader
    }
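
# A minimal usage sketch (never called in this file). The shard range and URL
# templates below are illustrative placeholders, not values shipped with the repo.
def _example_create_dataloaders():
    dataloaders = create_dataloaders(
        available_shards=list(range(0, 100)),               # shards 000000 .. 000099
        webdataset_base_url="s3://my-bucket/images/{}.tar", # hypothetical bucket
        embeddings_url="s3://my-bucket/embeddings/",        # hypothetical bucket
        shard_width=6,
        batch_size=32,
        img_preproc=T.Compose([T.RandomResizedCrop(64), T.ToTensor()]),
    )
    return dataloaders["train"], dataloaders["val"], dataloaders["test"]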

def create_decoder(device, decoder_config, unets_config):
    """Builds a Decoder (and its Unets) from the given configs and moves it to the given device"""
    unets = [Unet(**unet_config) for unet_config in unets_config]

    decoder = Decoder(
        unet=tuple(unets), # Must be tuple because of cast_tuple
        **decoder_config
    )
    decoder.to(device=device)

    return decoder
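
# Hedged sketch of what the unet/decoder configs might look like (never called;
# the parameter values are assumptions taken from the DALLE2-pytorch README, not
# defaults of this repo, and whether they are sufficient depends on configs.decoder_defaults):
def _example_create_decoder():
    unets_config = [dict(dim=128, image_embed_dim=512, cond_dim=128, channels=3, dim_mults=(1, 2, 4, 8))]
    decoder_config = dict(image_sizes=(64,), timesteps=1000)
    return create_decoder("cuda:0", decoder_config, unets_config)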

def get_dataset_keys(dataloader):
    """
    It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to dig it out.
    """
    # If the dataloader is actually a WebLoader, we need to extract the real dataloader
    if isinstance(dataloader, wds.WebLoader):
        dataloader = dataloader.pipeline[0]
    return dataloader.dataset.key_map

def get_example_data(dataloader, device, n=5):
    """
    Samples the dataloader and returns a zipped list of examples
    """
    images = []
    embeddings = []
    captions = []
    dataset_keys = get_dataset_keys(dataloader)
    has_caption = "txt" in dataset_keys
    for data in dataloader:
        if has_caption:
            img, emb, txt = data
        else:
            img, emb = data
            txt = [""] * emb.shape[0]
        img = img.to(device=device, dtype=torch.float)
        emb = emb.to(device=device, dtype=torch.float)
        images.extend(list(img))
        embeddings.extend(list(emb))
        captions.extend(list(txt))
        if len(images) >= n:
            break
    print("Generated {} examples".format(len(images)))
    return list(zip(images[:n], embeddings[:n], captions[:n]))

def generate_samples(trainer, example_data, text_prepend=""):
    """
    Takes example data and generates images from the embeddings
    Returns three lists: real images, generated images, and captions
    """
    real_images, embeddings, txts = zip(*example_data)
    embeddings_tensor = torch.stack(embeddings)
    samples = trainer.sample(embeddings_tensor)
    generated_images = list(samples)
    captions = [text_prepend + txt for txt in txts]
    return real_images, generated_images, captions

def generate_grid_samples(trainer, examples, text_prepend=""):
    """
    Generates samples and uses torchvision to put each real/generated pair in a side-by-side grid for easy viewing
    """
    real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
    grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
    return grid_images, captions

def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
    """
    Computes evaluation metrics for the decoder. FID, IS, KID, and LPIPS are each either None (skip that metric) or a dict of kwargs for the corresponding torchmetrics constructor.
    """
    metrics = {}
    # Prepare the data
    examples = get_example_data(dataloader, device, n_evaluation_samples)
    real_images, generated_images, captions = generate_samples(trainer, examples)
    real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
    generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
    # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
    int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
    int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
    if FID is not None:
        fid = FrechetInceptionDistance(**FID)
        fid.to(device=device)
        fid.update(int_real_images, real=True)
        fid.update(int_generated_images, real=False)
        metrics["FID"] = fid.compute().item()
    if IS is not None:
        inception = InceptionScore(**IS)
        inception.to(device=device)
        # Inception Score is a property of the model samples, so score the generated images
        inception.update(int_generated_images)
        is_mean, is_std = inception.compute()
        metrics["IS_mean"] = is_mean.item()
        metrics["IS_std"] = is_std.item()
    if KID is not None:
        kernel_inception = KernelInceptionDistance(**KID)
        kernel_inception.to(device=device)
        kernel_inception.update(int_real_images, real=True)
        kernel_inception.update(int_generated_images, real=False)
        kid_mean, kid_std = kernel_inception.compute()
        metrics["KID_mean"] = kid_mean.item()
        metrics["KID_std"] = kid_std.item()
    if LPIPS is not None:
        # Convert from [0, 1] to [-1, 1]
        renorm_real_images = real_images.mul(2).sub(1)
        renorm_generated_images = generated_images.mul(2).sub(1)
        lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
        lpips.to(device=device)
        lpips.update(renorm_real_images, renorm_generated_images)
        metrics["LPIPS"] = lpips.compute().item()
    return metrics
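
# Hedged sketch of an "evaluate" config block for this function. The kwargs shown
# are standard torchmetrics constructor arguments; the exact values are assumptions:
# "evaluate": {
#     "n_evaluation_samples": 1000,
#     "FID": {"feature": 2048},
#     "IS": {"splits": 10},
#     "KID": {"subset_size": 50},
#     "LPIPS": {"net_type": "vgg"}
# }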

def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
    """
    Logs the model with an appropriate method depending on the tracker
    """
    if isinstance(relative_paths, str):
        relative_paths = [relative_paths]
    trainer_state_dict = {
        "trainer": trainer.state_dict(),
        "epoch": epoch,
        "step": step,
        "validation_losses": validation_losses
    }
    for relative_path in relative_paths:
        tracker.save_state_dict(trainer_state_dict, relative_path)

def recall_trainer(tracker, trainer, recall_source=None, **load_config):
    """
    Loads the model with an appropriate method depending on the tracker
    """
    print(print_ribbon(f"Loading model from {recall_source}"))
    state_dict = tracker.recall_state_dict(recall_source, **load_config)
    trainer.load_state_dict(state_dict["trainer"])
    print("Model loaded")
    return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]
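
# Hedged save/recall roundtrip sketch (never called; assumes a tracker whose
# save_state_dict and recall_state_dict resolve "latest.pth" to the same file):
def _example_checkpoint_roundtrip(tracker, trainer):
    save_trainer(tracker, trainer, epoch=0, step=0, validation_losses=[], relative_paths="latest.pth")
    return recall_trainer(tracker, trainer, recall_source="latest.pth")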

def train(
    dataloaders,
    decoder,
    tracker,
    inference_device,
    load_config=None,
    evaluate_config=None,
    epoch_samples=None, # If the training dataset is resampling, we have to manually stop an epoch
    validation_samples=None,
    epochs=20,
    n_sample_images=5,
    save_every_n_samples=100000,
    save_all=False,
    save_latest=True,
    save_best=True,
    unet_training_mask=None,
    **kwargs
):
    """
    Trains a decoder on a dataset.
    """
    trainer = DecoderTrainer( # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
        decoder,
        **kwargs
    )
    # Set up starting model and parameters based on a recalled state dict
    start_step = 0
    start_epoch = 0
    validation_losses = []

    if load_config is not None and load_config["source"] is not None:
        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
    trainer.to(device=inference_device)

    if unet_training_mask is None:
        # Then the unet mask should be true for all unets in the decoder
        unet_training_mask = [True] * trainer.num_unets
    assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"

    print(print_ribbon("Generating Example Data", repeat=40))
    print("This can take a while to load the shard lists...")
    train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
    test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)

    send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
    step = start_step
    for epoch in range(start_epoch, epochs):
        print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
        trainer.train()

        sample = 0
        last_sample = 0
        last_snapshot = 0
        last_time = time.time()
        losses = []
        for i, (img, emb) in enumerate(dataloaders["train"]):
            step += 1
            sample += img.shape[0]
            img, emb = send_to_device((img, emb))

            for unet in range(1, trainer.num_unets + 1):
                # Check if this is a unet we are training
                if unet_training_mask[unet - 1]: # Unet index is the unet number - 1
                    loss = trainer.forward(img, image_embed=emb, unet_number=unet)
                    trainer.update(unet_number=unet)
                    losses.append(loss)

            samples_per_sec = (sample - last_sample) / (time.time() - last_time)
            last_time = time.time()
            last_sample = sample

            if i % 10 == 0:
                average_loss = sum(losses) / len(losses)
                log_data = {
                    "Training loss": average_loss,
                    "Epoch": epoch,
                    "Sample": sample,
                    "Step": i,
                    "Samples per second": samples_per_sec
                }
                tracker.log(log_data, step=step, verbose=True)
                losses = []

            if last_snapshot + save_every_n_samples < sample: # This will overshoot the save interval by up to a batch each time, which is acceptable
                last_snapshot = sample
                # We need to know where the model should be saved
                save_paths = []
                if save_latest:
                    save_paths.append("latest.pth")
                if save_all:
                    save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
                save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
                if n_sample_images is not None and n_sample_images > 0:
                    trainer.eval()
                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
                    trainer.train()
                    tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)

            if epoch_samples is not None and sample >= epoch_samples:
                break

        trainer.eval()
        print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
        with torch.no_grad():
            sample = 0
            average_loss = 0
            start_time = time.time()
            for i, (img, emb, txt) in enumerate(dataloaders["val"]):
                sample += img.shape[0]
                img, emb = send_to_device((img, emb))

                for unet in range(1, len(decoder.unets) + 1):
                    loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
                    average_loss += loss

                if i % 10 == 0:
                    print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec")
                    print(f"Loss: {average_loss / (i + 1)}")
                    print("")

                if validation_samples is not None and sample >= validation_samples:
                    break
            average_loss /= i + 1
            log_data = {
                "Validation loss": average_loss
            }
            tracker.log(log_data, step=step, verbose=True)

        # Compute evaluation metrics
        trainer.eval()
        if evaluate_config is not None:
            print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
            evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
            tracker.log(evaluation, step=step, verbose=True)

        # Generate sample images
        print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
        test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
        train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
        tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)

        print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
        # Gather the save paths
        save_paths = []
        if save_latest:
            save_paths.append("latest.pth")
        if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
            save_paths.append("best.pth")
        validation_losses.append(average_loss)
        save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)

def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
    """
    Creates a tracker of the specified type and initializes special features based on the full config
    """
    tracker_config = config["tracker"]
    init_config = {}
    init_config["config"] = config.config
    if tracker_type == "console":
        tracker = ConsoleTracker(**init_config)
    elif tracker_type == "wandb":
        # We need to initialize the resume state here
        load_config = config["load"]
        if load_config["source"] == "wandb" and load_config["resume"]:
            # Then we are resuming the run load_config["run_path"]
            run_id = load_config["run_path"].split("/")[-1]
            init_config["id"] = run_id
            init_config["resume"] = "must"
        init_config["entity"] = tracker_config["wandb_entity"]
        init_config["project"] = tracker_config["wandb_project"]
        tracker = WandbTracker(data_path)
        tracker.init(**init_config)
    else:
        raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
    return tracker
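
# Hedged sketch of a "tracker" config block (keys inferred from the code above;
# the entity/project values are placeholders):
# "tracker": {
#     "tracker_type": "wandb",
#     "data_path": "./models",
#     "wandb_entity": "my-entity",
#     "wandb_project": "dalle2-decoder"
# }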

def initialize_training(config):
    # Set up the device and validate that it is usable
    if "cuda" in config["train"]["device"]:
        assert torch.cuda.is_available(), "CUDA is not available"
    device = torch.device(config["train"]["device"])
    if device.type == "cuda": # Guard so that a CPU device does not break set_device
        torch.cuda.set_device(device)
    all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))

    dataloaders = create_dataloaders(
        available_shards=all_shards,
        img_preproc=config.get_preprocessing(),
        train_prop=config["data"]["splits"]["train"],
        val_prop=config["data"]["splits"]["val"],
        test_prop=config["data"]["splits"]["test"],
        n_sample_images=config["train"]["n_sample_images"],
        **config["data"]
    )

    decoder = create_decoder(device, config["decoder"], config["unets"])
    num_parameters = sum(p.numel() for p in decoder.parameters())
    print(print_ribbon("Loaded Config", repeat=40))
    print(f"Number of parameters: {num_parameters}")

    tracker = create_tracker(config, **config["tracker"])

    train(dataloaders, decoder,
        tracker=tracker,
        inference_device=device,
        load_config=config["load"],
        evaluate_config=config["evaluate"],
        **config["train"],
    )


class TrainDecoderConfig:
    def __init__(self, config):
        self.config = self.map_config(config, default_config)

    def map_config(self, config, defaults):
        """
        Returns a dictionary containing all config options in the union of config and defaults.
        If the config value is an array, apply the default value to each element.
        If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
        """
        def _check_option(option, option_config, option_defaults):
            for key, value in option_defaults.items():
                if key not in option_config:
                    if value == ConfigField.REQUIRED:
                        raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
                    option_config[key] = value

        for key, value in defaults.items():
            if key not in config:
                # Then they did not pass in one of the main configs. If the default is an array or object, we can fill it in. If it is required, we must error.
                if value == ConfigField.REQUIRED:
                    raise RuntimeError("Required config value '{}' not supplied".format(key))
                elif isinstance(value, dict):
                    config[key] = {}
                elif isinstance(value, list):
                    config[key] = [{}]
            # config[key] is now either a dict, a list of dicts, or an object that cannot be checked.
            # If it is a list, then we need to check each element
            if isinstance(value, list):
                assert isinstance(config[key], list)
                for element in config[key]:
                    _check_option(key, element, value[0])
            elif isinstance(value, dict):
                _check_option(key, config[key], value)
            # Any other object does not support checking
        return config

    def get_preprocessing(self):
        """
        Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
        """
        def _get_transformation(transformation_name, **kwargs):
            if transformation_name == "RandomResizedCrop":
                return T.RandomResizedCrop(**kwargs)
            elif transformation_name == "RandomHorizontalFlip":
                return T.RandomHorizontalFlip()
            elif transformation_name == "ToTensor":
                return T.ToTensor()
            # Fail loudly instead of silently composing a None transform
            raise ValueError("Transformation '{}' is not supported".format(transformation_name))

        transformations = []
        for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
            if isinstance(transformation_kwargs, dict):
                transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
            else:
                transformations.append(_get_transformation(transformation_name))
        return T.Compose(transformations)

    def __getitem__(self, key):
        return self.config[key]
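
# Hedged sketch of the overall config shape this class expects, inferred from how
# the keys are read throughout this file. The values are illustrative placeholders;
# missing sections are filled from default_config, and any key whose default is
# ConfigField.REQUIRED (e.g. inside "decoder"/"unets") must be supplied or
# map_config raises a RuntimeError.
def _example_config():
    config = {
        "data": {
            "webdataset_base_url": "s3://my-bucket/images/{}.tar", # hypothetical
            "embeddings_url": "s3://my-bucket/embeddings/",        # hypothetical
            "start_shard": 0,
            "end_shard": 99,
            "splits": {"train": 0.75, "val": 0.15, "test": 0.10},
            "preprocessing": {"RandomResizedCrop": {"size": 64}, "ToTensor": True}
        },
        "decoder": {},  # required decoder fields from configs.decoder_defaults go here
        "unets": [{}],  # one dict of Unet kwargs per unet
        "train": {"device": "cuda:0", "epochs": 20, "n_sample_images": 5},
        "tracker": {"tracker_type": "console", "data_path": "./models"},
        "load": {"source": None}
    }
    return TrainDecoderConfig(config)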


# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
    print("Recalling config from {}".format(config_file))
    with open(config_file) as f:
        config = json.load(f)
    config = TrainDecoderConfig(config)
    initialize_training(config)


if __name__ == "__main__":
    main()
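
# Example invocation (assuming this script is saved as train_decoder.py, using
# the click default shown above):
#   python train_decoder.py --config_file ./train_decoder_config.json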