from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.train_configs import TrainDecoderConfig
from dalle2_pytorch.utils import Timer

import torchvision
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import webdataset as wds
import click

# constants

TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10

# helper functions

def exists(val):
    return val is not None

# main functions

def create_dataloaders(
    available_shards,
    webdataset_base_url,
    embeddings_url,
    shard_width=6,
    num_workers=4,
    batch_size=32,
    n_sample_images=6,
    shuffle_train=True,
    resample_train=False,
    img_preproc=None,
    index_width=4,
    train_prop=0.75,
    val_prop=0.15,
    test_prop=0.10,
    **kwargs
):
    """
    Randomly splits the available shards into train, val, and test sets and returns a dataloader for each.
    """
    # Compare with a tolerance since the proportions are floats
    assert abs(train_prop + test_prop + val_prop - 1) < 1e-6
    num_train = round(train_prop * len(available_shards))
    num_test = round(test_prop * len(available_shards))
    num_val = len(available_shards) - num_train - num_test
    assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
    train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(0))

    # The shard number in the webdataset filenames has a fixed width. We zero-pad the shard numbers so they correspond to a filename.
    train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
    test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
    val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]

    create_dataloader = lambda tar_urls, shuffle=False, resample=False, with_text=False, for_sampling=False: create_image_embedding_dataloader(
        tar_url=tar_urls,
        num_workers=num_workers,
        batch_size=batch_size if not for_sampling else n_sample_images,
        embeddings_url=embeddings_url,
        index_width=index_width,
        shuffle_num=None,
        extra_keys=["txt"] if with_text else [],
        shuffle_shards=shuffle,
        resample_shards=resample,
        img_preproc=img_preproc,
        handler=wds.handlers.warn_and_continue
    )

    train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
    train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
    val_dataloader = create_dataloader(val_urls, shuffle=False, with_text=True)
    test_dataloader = create_dataloader(test_urls, shuffle=False, with_text=True)
    test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
    return {
        "train": train_dataloader,
        "train_sampling": train_sampling_dataloader,
        "val": val_dataloader,
        "test": test_dataloader,
        "test_sampling": test_sampling_dataloader
    }
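
# A minimal usage sketch (all values hypothetical). The "{}" in
# webdataset_base_url is filled with the zero-padded shard number,
# as shown above:
#
# dataloaders = create_dataloaders(
#     available_shards=list(range(100)),
#     webdataset_base_url="data/images/{}.tar",
#     embeddings_url="data/embeddings/",
#     shard_width=6,
#     batch_size=32,
# )
# for img, emb in dataloaders["train"]:
#     ...  # training batches of images and their embeddings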

def create_decoder(device, decoder_config, unets_config):
    """Creates the decoder and its unets from their configs and moves it to the given device"""
    unets = [Unet(**config.dict()) for config in unets_config]

    decoder = Decoder(
        unet=unets,
        **decoder_config.dict()
    )

    decoder.to(device=device)
    return decoder

def get_dataset_keys(dataloader):
    """
    It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to recover it.
    """
    # If the dataloader is actually a WebLoader, we need to extract the real dataloader
    if isinstance(dataloader, wds.WebLoader):
        dataloader = dataloader.pipeline[0]
    return dataloader.dataset.key_map

def get_example_data(dataloader, device, n=5):
    """
    Samples the dataloader and returns a zipped list of examples
    """
    images = []
    embeddings = []
    captions = []
    dataset_keys = get_dataset_keys(dataloader)
    has_caption = "txt" in dataset_keys
    for data in dataloader:
        if has_caption:
            img, emb, txt = data
        else:
            img, emb = data
            txt = [""] * emb.shape[0]
        img = img.to(device=device, dtype=torch.float)
        emb = emb.to(device=device, dtype=torch.float)
        images.extend(list(img))
        embeddings.extend(list(emb))
        captions.extend(list(txt))
        if len(images) >= n:
            break
    print(f"Generated {len(images)} examples")
    return list(zip(images[:n], embeddings[:n], captions[:n]))
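
# Each returned element pairs one image with its embedding and caption.
# Sketch of consuming the output (shapes depend on the dataset):
#
# for img, emb, txt in get_example_data(dataloaders["test"], device, n=5):
#     ...  # img: float image tensor on `device`, emb: float embedding, txt: str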

def generate_samples(trainer, example_data, text_prepend=""):
    """
    Takes example data and generates images from the embeddings.
    Returns three lists: real images, generated images, and captions.
    """
    real_images, embeddings, txts = zip(*example_data)
    embeddings_tensor = torch.stack(embeddings)
    samples = trainer.sample(embeddings_tensor)
    generated_images = list(samples)
    captions = [text_prepend + txt for txt in txts]
    return real_images, generated_images, captions

def generate_grid_samples(trainer, examples, text_prepend=""):
    """
    Generates samples and uses torchvision to put them in a side-by-side grid for easy viewing
    """
    real_images, generated_images, captions = generate_samples(trainer, examples, text_prepend)
    grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
    return grid_images, captions

def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
    """
    Computes evaluation metrics for the decoder
    """
    metrics = {}
    # Prepare the data
    examples = get_example_data(dataloader, device, n_evaluation_samples)
    real_images, generated_images, captions = generate_samples(trainer, examples)
    real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
    generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
    # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
    int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
    int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
    if exists(FID):
        fid = FrechetInceptionDistance(**FID)
        fid.to(device=device)
        fid.update(int_real_images, real=True)
        fid.update(int_generated_images, real=False)
        metrics["FID"] = fid.compute().item()
    if exists(IS):
        inception = InceptionScore(**IS)
        inception.to(device=device)
        # The inception score measures the quality of the generated samples
        inception.update(int_generated_images)
        is_mean, is_std = inception.compute()
        metrics["IS_mean"] = is_mean.item()
        metrics["IS_std"] = is_std.item()
    if exists(KID):
        kernel_inception = KernelInceptionDistance(**KID)
        kernel_inception.to(device=device)
        kernel_inception.update(int_real_images, real=True)
        kernel_inception.update(int_generated_images, real=False)
        kid_mean, kid_std = kernel_inception.compute()
        metrics["KID_mean"] = kid_mean.item()
        metrics["KID_std"] = kid_std.item()
    if exists(LPIPS):
        # Convert from [0, 1] to [-1, 1]
        renorm_real_images = real_images.mul(2).sub(1)
        renorm_generated_images = generated_images.mul(2).sub(1)
        lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS)
        lpips.to(device=device)
        lpips.update(renorm_real_images, renorm_generated_images)
        metrics["LPIPS"] = lpips.compute().item()
    return metrics
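
# A minimal call sketch. FID/IS/KID/LPIPS are kwarg dicts forwarded to the
# corresponding torchmetrics constructors, so any of their constructor
# arguments are accepted (the values below are only illustrative):
#
# metrics = evaluate_trainer(
#     trainer, dataloaders["val"], device,
#     n_evaluation_samples=256,
#     FID={"feature": 64},
#     KID={"subset_size": 50},
#     LPIPS={"net_type": "vgg"},
# )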

def save_trainer(tracker, trainer, epoch, step, validation_losses, relative_paths):
    """
    Saves the trainer state with an appropriate method depending on the tracker
    """
    if isinstance(relative_paths, str):
        relative_paths = [relative_paths]
    trainer_state_dict = {
        "trainer": trainer.state_dict(),
        "epoch": epoch,
        "step": step,
        "validation_losses": validation_losses
    }
    for relative_path in relative_paths:
        tracker.save_state_dict(trainer_state_dict, relative_path)
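
# The checkpoint is a plain dict, so it can also be reloaded by hand
# (sketch, assuming the tracker persists it with torch.save):
#
# state = torch.load("latest.pth")
# trainer.load_state_dict(state["trainer"])
# epoch, step = state["epoch"], state["step"]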

def recall_trainer(tracker, trainer, recall_source=None, **load_config):
    """
    Loads the model with an appropriate method depending on the tracker
    """
    print(print_ribbon(f"Loading model from {recall_source}"))
    state_dict = tracker.recall_state_dict(recall_source, **load_config)
    trainer.load_state_dict(state_dict["trainer"])
    print("Model loaded")
    return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]

def train(
    dataloaders,
    decoder,
    tracker,
    inference_device,
    load_config=None,
    evaluate_config=None,
    epoch_samples=None,  # If the training dataset is resampling, we have to manually stop an epoch
    validation_samples=None,
    epochs=20,
    n_sample_images=5,
    save_every_n_samples=100000,
    save_all=False,
    save_latest=True,
    save_best=True,
    unet_training_mask=None,
    **kwargs
):
    """
    Trains a decoder on a dataset.
    """
    trainer = DecoderTrainer(  # TODO: Change the get_optimizer function so that it can take arbitrary named args so we can just put **kwargs as an argument here
        decoder,
        **kwargs
    )
    # Set up starting model and parameters based on a recalled state dict
    start_step = 0
    start_epoch = 0
    validation_losses = []

    if exists(load_config) and exists(load_config.source):
        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config.dict())
    trainer.to(device=inference_device)

    if not exists(unet_training_mask):
        # Then the unet mask should be true for all unets in the decoder
        unet_training_mask = [True] * trainer.num_unets
    assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"

    print(print_ribbon("Generating Example Data", repeat=40))
    print("This can take a while to load the shard lists...")
    train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
    test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)

    send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
    step = start_step

    for epoch in range(start_epoch, epochs):
        print(print_ribbon(f"Starting epoch {epoch}", repeat=40))

        timer = Timer()

        sample = 0
        last_sample = 0
        last_snapshot = 0

        losses = []

        for i, (img, emb) in enumerate(dataloaders["train"]):
            step += 1
            sample += img.shape[0]
            img, emb = send_to_device((img, emb))

            trainer.train()
            for unet in range(1, trainer.num_unets + 1):
                # Check if this is a unet we are training
                if not unet_training_mask[unet - 1]:  # Unet index is the unet number - 1
                    continue

                loss = trainer.forward(img, image_embed=emb, unet_number=unet)
                trainer.update(unet_number=unet)
                losses.append(loss)

            samples_per_sec = (sample - last_sample) / timer.elapsed()

            timer.reset()
            last_sample = sample

            if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
                average_loss = sum(losses) / len(losses)
                log_data = {
                    "Training loss": average_loss,
                    "Epoch": epoch,
                    "Sample": sample,
                    "Step": i,
                    "Samples per second": samples_per_sec
                }
                tracker.log(log_data, step=step, verbose=True)
                losses = []

            if last_snapshot + save_every_n_samples < sample:  # This will miss by some amount every time, but it's not a big deal... I hope
                last_snapshot = sample
                # We need to know where the model should be saved
                save_paths = []
                if save_latest:
                    save_paths.append("latest.pth")
                if save_all:
                    save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")

                save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)

                if exists(n_sample_images) and n_sample_images > 0:
                    trainer.eval()
                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
                    tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)

            if exists(epoch_samples) and sample >= epoch_samples:
                break

        trainer.eval()
        print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
        with torch.no_grad():
            sample = 0
            average_loss = 0
            timer = Timer()
            for i, (img, emb, txt) in enumerate(dataloaders["val"]):
                sample += img.shape[0]
                img, emb = send_to_device((img, emb))

                for unet in range(1, len(decoder.unets) + 1):
                    loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
                    average_loss += loss

                if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
                    print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
                    print(f"Loss: {average_loss / (i+1)}")
                    print("")

                if exists(validation_samples) and sample >= validation_samples:
                    break

            average_loss /= i + 1
            log_data = {
                "Validation loss": average_loss
            }
            tracker.log(log_data, step=step, verbose=True)

        # Compute evaluation metrics
        if exists(evaluate_config):
            print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
            evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config.dict())
            tracker.log(evaluation, step=step, verbose=True)

        # Generate sample images
        print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
        test_images, test_captions = generate_grid_samples(trainer, test_example_data, "Test: ")
        train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
        tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step)
        tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)

        print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
        # Get the save paths
        save_paths = []
        if save_latest:
            save_paths.append("latest.pth")
        if save_best and (len(validation_losses) == 0 or average_loss < min(validation_losses)):
            save_paths.append("best.pth")
        validation_losses.append(average_loss)
        save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)

def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
    """
    Creates a tracker of the specified type and initializes special features based on the full config
    """
    tracker_config = config.tracker
    init_config = {}

    if exists(tracker_config.init_config):
        init_config["config"] = tracker_config.init_config

    if tracker_type == "console":
        tracker = ConsoleTracker(**init_config)
    elif tracker_type == "wandb":
        # We need to initialize the resume state here
        load_config = config.load
        if load_config.source == "wandb" and load_config.resume:
            # Then we are resuming the run load_config.run_path
            run_id = load_config.run_path.split("/")[-1]
            init_config["id"] = run_id
            init_config["resume"] = "must"

        init_config["entity"] = tracker_config.wandb_entity
        init_config["project"] = tracker_config.wandb_project
        tracker = WandbTracker(data_path)
        tracker.init(**init_config)
    else:
        raise ValueError(f"Tracker type {tracker_type} not supported by decoder trainer")
    return tracker
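
# Sketch of the tracker section of the config that feeds this function.
# Field names are inferred from the attribute accesses above; the exact
# schema lives in dalle2_pytorch.train_configs:
#
# "tracker": {
#     "tracker_type": "wandb",
#     "data_path": "./.tracker-data",
#     "wandb_entity": "my-entity",
#     "wandb_project": "dalle2-decoder"
# }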

def initialize_training(config):
    # Set up the device
    if "cuda" in config.train.device:
        assert torch.cuda.is_available(), "CUDA is not available"
    device = torch.device(config.train.device)
    if device.type == "cuda":
        torch.cuda.set_device(device)
    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))

    dataloaders = create_dataloaders(
        available_shards=all_shards,
        img_preproc=config.img_preproc,
        train_prop=config.data.splits.train,
        val_prop=config.data.splits.val,
        test_prop=config.data.splits.test,
        n_sample_images=config.train.n_sample_images,
        **config.data.dict()
    )

    decoder = create_decoder(device, config.decoder, config.unets)
    num_parameters = sum(p.numel() for p in decoder.parameters())
    print(print_ribbon("Loaded Config", repeat=40))
    print(f"Number of parameters: {num_parameters}")

    tracker = create_tracker(config, **config.tracker.dict())

    train(dataloaders, decoder,
        tracker=tracker,
        inference_device=device,
        load_config=config.load,
        evaluate_config=config.evaluate,
        **config.train.dict(),
    )

# Create a simple click command line interface to load the config and start the training
@click.command()
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
def main(config_file):
    print(f"Recalling config from {config_file}")
    config = TrainDecoderConfig.from_json_path(config_file)
    initialize_training(config)


if __name__ == "__main__":
    main()
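
# Command-line usage sketch (assuming this module is saved as train_decoder.py):
#
#   python train_decoder.py --config_file ./train_decoder_config.json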