mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 10:14:19 +01:00
move config parsing logic to own file, consider whether to find an off-the-shelf solution at future date
This commit is contained in:
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.3.7',
|
version = '0.3.8',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -2,12 +2,11 @@ from dalle2_pytorch import Unet, Decoder
|
|||||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
||||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||||
|
from dalle2_pytorch.train_configs import TrainDecoderConfig
|
||||||
from dalle2_pytorch.utils import Timer
|
from dalle2_pytorch.utils import Timer
|
||||||
|
|
||||||
from configs.decoder_defaults import default_config, ConfigField
|
|
||||||
import json
|
import json
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision import transforms as T
|
|
||||||
import torch
|
import torch
|
||||||
from torchmetrics.image.fid import FrechetInceptionDistance
|
from torchmetrics.image.fid import FrechetInceptionDistance
|
||||||
from torchmetrics.image.inception import InceptionScore
|
from torchmetrics.image.inception import InceptionScore
|
||||||
@@ -440,67 +439,6 @@ def initialize_training(config):
|
|||||||
**config["train"],
|
**config["train"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TrainDecoderConfig:
|
|
||||||
def __init__(self, config):
|
|
||||||
self.config = self.map_config(config, default_config)
|
|
||||||
|
|
||||||
def map_config(self, config, defaults):
|
|
||||||
"""
|
|
||||||
Returns a dictionary containing all config options in the union of config and defaults.
|
|
||||||
If the config value is an array, apply the default value to each element.
|
|
||||||
If the default values dict has a value of ConfigField.REQUIRED for a key, it is required and a runtime error should be thrown if a value is not supplied from config
|
|
||||||
"""
|
|
||||||
def _check_option(option, option_config, option_defaults):
|
|
||||||
for key, value in option_defaults.items():
|
|
||||||
if key not in option_config:
|
|
||||||
if value == ConfigField.REQUIRED:
|
|
||||||
raise RuntimeError("Required config value '{}' of option '{}' not supplied".format(key, option))
|
|
||||||
option_config[key] = value
|
|
||||||
|
|
||||||
for key, value in defaults.items():
|
|
||||||
if key not in config:
|
|
||||||
# Then they did not pass in one of the main configs. If the default is an array or object, then we can fill it in. If is a required object, we must error
|
|
||||||
if value == ConfigField.REQUIRED:
|
|
||||||
raise RuntimeError("Required config value '{}' not supplied".format(key))
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
config[key] = {}
|
|
||||||
elif isinstance(value, list):
|
|
||||||
config[key] = [{}]
|
|
||||||
# Config[key] is now either a dict, list of dicts, or an object that cannot be checked.
|
|
||||||
# If it is a list, then we need to check each element
|
|
||||||
if isinstance(value, list):
|
|
||||||
assert isinstance(config[key], list)
|
|
||||||
for element in config[key]:
|
|
||||||
_check_option(key, element, value[0])
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
_check_option(key, config[key], value)
|
|
||||||
# This object does not support checking
|
|
||||||
return config
|
|
||||||
|
|
||||||
def get_preprocessing(self):
|
|
||||||
"""
|
|
||||||
Takes the preprocessing dictionary and converts it to a composition of torchvision transforms
|
|
||||||
"""
|
|
||||||
def _get_transformation(transformation_name, **kwargs):
|
|
||||||
if transformation_name == "RandomResizedCrop":
|
|
||||||
return T.RandomResizedCrop(**kwargs)
|
|
||||||
elif transformation_name == "RandomHorizontalFlip":
|
|
||||||
return T.RandomHorizontalFlip()
|
|
||||||
elif transformation_name == "ToTensor":
|
|
||||||
return T.ToTensor()
|
|
||||||
|
|
||||||
transformations = []
|
|
||||||
for transformation_name, transformation_kwargs in self.config["data"]["preprocessing"].items():
|
|
||||||
if isinstance(transformation_kwargs, dict):
|
|
||||||
transformations.append(_get_transformation(transformation_name, **transformation_kwargs))
|
|
||||||
else:
|
|
||||||
transformations.append(_get_transformation(transformation_name))
|
|
||||||
return T.Compose(transformations)
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return self.config[key]
|
|
||||||
|
|
||||||
# Create a simple click command line interface to load the config and start the training
|
# Create a simple click command line interface to load the config and start the training
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
|
||||||
|
|||||||
Reference in New Issue
Block a user