mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5340a96c0f | ||
|
|
0064661729 | ||
|
|
b895f52843 | ||
|
|
80497e9839 |
12
README.md
12
README.md
@@ -1034,6 +1034,18 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
|
||||
|
||||
## Appreciation
|
||||
|
||||
This library would not have gotten to this working state without the help of
|
||||
|
||||
- <a href="https://github.com/nousr">Zion</a> and <a href="https://github.com/krish240574">Kumar</a> for the diffusion training script
|
||||
- <a href="https://github.com/Veldrovive">Aidan</a> for the decoder training script and dataloaders
|
||||
- <a href="https://github.com/rom1504">Romain</a> for the pull request reviews and project management
|
||||
- <a href="https://github.com/Ciaohe">He Cao</a> and <a href="https://github.com/xiankgx">xiankgx</a> for the Q&A and for identifying of critical bugs
|
||||
- <a href="https://github.com/crowsonkb">Katherine</a> for her advice
|
||||
|
||||
... and many others. Thank you! 🙏
|
||||
|
||||
## Todo
|
||||
|
||||
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
|
||||
|
||||
@@ -59,6 +59,9 @@ def default(val, d):
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
def cast_tuple(val, length = 1):
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
|
||||
return val if isinstance(val, tuple) else ((val,) * length)
|
||||
|
||||
def module_device(module):
|
||||
|
||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.3.6',
|
||||
version = '0.3.8',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
|
||||
122
train_decoder.py
122
train_decoder.py
@@ -2,12 +2,11 @@ from dalle2_pytorch import Unet, Decoder
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer
|
||||
|
||||
from configs.decoder_defaults import default_config, ConfigField
|
||||
import json
|
||||
import torchvision
|
||||
from torchvision import transforms as T
|
||||
import torch
|
||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
from torchmetrics.image.inception import InceptionScore
|
||||
@@ -16,6 +15,17 @@ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
||||
import webdataset as wds
|
||||
import click
|
||||
|
||||
# constants
|
||||
|
||||
TRAIN_CALC_LOSS_EVERY_ITERS = 10
|
||||
VALID_CALC_LOSS_EVERY_ITERS = 10
|
||||
|
||||
# helpers functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
# main functions
|
||||
|
||||
def create_dataloaders(
|
||||
available_shards,
|
||||
@@ -79,18 +89,15 @@ def create_dataloaders(
|
||||
|
||||
def create_decoder(device, decoder_config, unets_config):
|
||||
"""Creates a sample decoder"""
|
||||
unets = []
|
||||
for i in range(0, len(unets_config)):
|
||||
unets.append(Unet(
|
||||
**unets_config[i]
|
||||
))
|
||||
|
||||
|
||||
unets = [Unet(**config) for config in unets_config]
|
||||
|
||||
decoder = Decoder(
|
||||
unet=tuple(unets), # Must be tuple because of cast_tuple
|
||||
unet=unets,
|
||||
**decoder_config
|
||||
)
|
||||
decoder.to(device=device)
|
||||
|
||||
decoder.to(device=device)
|
||||
return decoder
|
||||
|
||||
def get_dataset_keys(dataloader):
|
||||
@@ -160,20 +167,20 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
|
||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||
if FID is not None:
|
||||
if exists(FID):
|
||||
fid = FrechetInceptionDistance(**FID)
|
||||
fid.to(device=device)
|
||||
fid.update(int_real_images, real=True)
|
||||
fid.update(int_generated_images, real=False)
|
||||
metrics["FID"] = fid.compute().item()
|
||||
if IS is not None:
|
||||
if exists(IS):
|
||||
inception = InceptionScore(**IS)
|
||||
inception.to(device=device)
|
||||
inception.update(int_real_images)
|
||||
is_mean, is_std = inception.compute()
|
||||
metrics["IS_mean"] = is_mean.item()
|
||||
metrics["IS_std"] = is_std.item()
|
||||
if KID is not None:
|
||||
if exists(KID):
|
||||
kernel_inception = KernelInceptionDistance(**KID)
|
||||
kernel_inception.to(device=device)
|
||||
kernel_inception.update(int_real_images, real=True)
|
||||
@@ -181,7 +188,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
|
||||
kid_mean, kid_std = kernel_inception.compute()
|
||||
metrics["KID_mean"] = kid_mean.item()
|
||||
metrics["KID_std"] = kid_std.item()
|
||||
if LPIPS is not None:
|
||||
if exists(LPIPS):
|
||||
# Convert from [0, 1] to [-1, 1]
|
||||
renorm_real_images = real_images.mul(2).sub(1)
|
||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||
@@ -245,11 +252,11 @@ def train(
|
||||
start_epoch = 0
|
||||
validation_losses = []
|
||||
|
||||
if load_config is not None and load_config["source"] is not None:
|
||||
if exists(load_config) and exists(load_config["source"]):
|
||||
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
|
||||
trainer.to(device=inference_device)
|
||||
|
||||
if unet_training_mask is None:
|
||||
if not exists(unet_training_mask):
|
||||
# Then the unet mask should be true for all unets in the decoder
|
||||
unet_training_mask = [True] * trainer.num_unets
|
||||
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||
@@ -280,17 +287,19 @@ def train(
|
||||
|
||||
for unet in range(1, trainer.num_unets+1):
|
||||
# Check if this is a unet we are training
|
||||
if unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
losses.append(loss)
|
||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||
continue
|
||||
|
||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||
trainer.update(unet_number=unet)
|
||||
losses.append(loss)
|
||||
|
||||
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
||||
|
||||
timer.reset()
|
||||
last_sample = sample
|
||||
|
||||
if i % 10 == 0:
|
||||
if i % CALC_LOSS_EVERY_ITERS == 0:
|
||||
average_loss = sum(losses) / len(losses)
|
||||
log_data = {
|
||||
"Training loss": average_loss,
|
||||
@@ -311,13 +320,13 @@ def train(
|
||||
if save_all:
|
||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
|
||||
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||
if n_sample_images is not None and n_sample_images > 0:
|
||||
if exists(n_sample_images) and n_sample_images > 0:
|
||||
trainer.eval()
|
||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||
trainer.train()
|
||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||
|
||||
if epoch_samples is not None and sample >= epoch_samples:
|
||||
if exists(epoch_samples) and sample >= epoch_samples:
|
||||
break
|
||||
|
||||
trainer.eval()
|
||||
@@ -334,12 +343,12 @@ def train(
|
||||
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
||||
average_loss += loss
|
||||
|
||||
if i % 10 == 0:
|
||||
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
|
||||
print(f"Loss: {average_loss / (i+1)}")
|
||||
print("")
|
||||
|
||||
if validation_samples is not None and sample >= validation_samples:
|
||||
if exists(validation_samples) and sample >= validation_samples:
|
||||
break
|
||||
|
||||
average_loss /= i+1
|
||||
@@ -350,7 +359,7 @@ def train(
|
||||
|
||||
# Compute evaluation metrics
|
||||
trainer.eval()
|
||||
if evaluate_config is not None:
|
||||
if exists(evaluate_config):
|
||||
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
||||
tracker.log(evaluation, step=step, verbose=True)
|
||||
@@ -430,67 +439,6 @@ def initialize_training(config):
|
||||
**config["train"],
|
||||
)
|
||||
|
||||
|
||||
class TrainDecoderConfig:
|
||||
def __init__(self, config):
|
||||
self.config = self.map_config(config, default_config)
|
||||
|
||||
def map_config(self, config, defaults):
|
||||
"""
|
||||
Returns a dictionary containing all config options in the union of config and defaults.
|
||||
If the config value is an array, apply the default value to each element.
|
||||
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
|
||||
"""
|
||||
def _check_option(option, option_config, option_defaults):
|
||||
for key, value in option_defaults.items():
|
||||
if key not in option_config:
|
||||
if value == ConfigField.REQUIRED:
|
||||
raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
|
||||
option_config[key] = value
|
||||
|
||||
for key, value in defaults.items():
|
||||
if key not in config:
|
||||
# Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
|
||||
if value == ConfigField.REQUIRED:
|
||||
raise RuntimeError("Required config value '{}' not supplied".format(key))
|
||||
elif isinstance(value, dict):
|
||||
config[key] = {}
|
||||
elif isinstance(value, list):
|
||||
config[key] = [{}]
|
||||
# Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
|
||||
# If it is a list, then we need to check each element
|
||||
if isinstance(value, list):
|
||||
assert isinstance(config[key], list)
|
||||
for element in config[key]:
|
||||
_check_option(key, element, value[0])
|
||||
elif isinstance(value, dict):
|
||||
_check_option(key, config[key], value)
|
||||
# This object does not support checking
|
||||
return config
|
||||
|
||||
def get_preprocessing(self):
|
||||
"""
|
||||
Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
|
||||
"""
|
||||
def _get_transformation(transformation_name, **kwargs):
|
||||
if transformation_name == "RandomResizedCrop":
|
||||
return T.RandomResizedCrop(**kwargs)
|
||||
elif transformation_name == "RandomHorizontalFlip":
|
||||
return T.RandomHorizontalFlip()
|
||||
elif transformation_name == "ToTensor":
|
||||
return T.ToTensor()
|
||||
|
||||
transformations = []
|
||||
for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
|
||||
if isinstance(transformation_kwargs, dict):
|
||||
transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
|
||||
else:
|
||||
transformations.append(_get_transformation(transformation_name))
|
||||
return T.Compose(transformations)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.config[key]
|
||||
|
||||
# Create a simple click command line interface to load the config and start the training
|
||||
@click.command()
|
||||
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||
|
||||
Reference in New Issue
Block a user