mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
move config parsing logic to own file, consider whether to find an off-the-shelf solution at future date
This commit is contained in:
@@ -2,12 +2,11 @@ from dalle2_pytorch import Unet, Decoder
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||
from dalle2_pytorch.utils import Timer
|
||||
|
||||
from configs.decoder_defaults import default_config, ConfigField
|
||||
import json
|
||||
import torchvision
|
||||
from torchvision import transforms as T
|
||||
import torch
|
||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||
from torchmetrics.image.inception import InceptionScore
|
||||
@@ -300,7 +299,7 @@ def train(
|
||||
timer.reset()
|
||||
last_sample = sample
|
||||
|
||||
if i % CALC_LOSS_EVERY_ITERS == 0:
|
||||
if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
|
||||
average_loss = sum(losses) / len(losses)
|
||||
log_data = {
|
||||
"Training loss": average_loss,
|
||||
@@ -440,67 +439,6 @@ def initialize_training(config):
|
||||
**config["train"],
|
||||
)
|
||||
|
||||
|
||||
class TrainDecoderConfig:
|
||||
def __init__(self, config):
|
||||
self.config = self.map_config(config, default_config)
|
||||
|
||||
def map_config(self, config, defaults):
|
||||
"""
|
||||
Returns a dictionary containing all config options in the union of config and defaults.
|
||||
If the config value is an array, apply the default value to each element.
|
||||
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
|
||||
"""
|
||||
def _check_option(option, option_config, option_defaults):
|
||||
for key, value in option_defaults.items():
|
||||
if key not in option_config:
|
||||
if value == ConfigField.REQUIRED:
|
||||
raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
|
||||
option_config[key] = value
|
||||
|
||||
for key, value in defaults.items():
|
||||
if key not in config:
|
||||
# Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
|
||||
if value == ConfigField.REQUIRED:
|
||||
raise RuntimeError("Required config value '{}' not supplied".format(key))
|
||||
elif isinstance(value, dict):
|
||||
config[key] = {}
|
||||
elif isinstance(value, list):
|
||||
config[key] = [{}]
|
||||
# Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
|
||||
# If it is a list, then we need to check each element
|
||||
if isinstance(value, list):
|
||||
assert isinstance(config[key], list)
|
||||
for element in config[key]:
|
||||
_check_option(key, element, value[0])
|
||||
elif isinstance(value, dict):
|
||||
_check_option(key, config[key], value)
|
||||
# This object does not support checking
|
||||
return config
|
||||
|
||||
def get_preprocessing(self):
|
||||
"""
|
||||
Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
|
||||
"""
|
||||
def _get_transformation(transformation_name, **kwargs):
|
||||
if transformation_name == "RandomResizedCrop":
|
||||
return T.RandomResizedCrop(**kwargs)
|
||||
elif transformation_name == "RandomHorizontalFlip":
|
||||
return T.RandomHorizontalFlip()
|
||||
elif transformation_name == "ToTensor":
|
||||
return T.ToTensor()
|
||||
|
||||
transformations = []
|
||||
for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
|
||||
if isinstance(transformation_kwargs, dict):
|
||||
transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
|
||||
else:
|
||||
transformations.append(_get_transformation(transformation_name))
|
||||
return T.Compose(transformations)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.config[key]
|
||||
|
||||
# Create a simple click command line interface to load the config and start the training
|
||||
@click.command()
|
||||
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||
|
||||
Reference in New Issue
Block a user