use pydantic to manage decoder training configs + defaults and refactor training script

This commit is contained in:
Phil Wang
2022-05-22 14:27:40 -07:00
parent d49eca62fa
commit a1ef023193
7 changed files with 145 additions and 264 deletions

View File

@@ -90,11 +90,11 @@ def create_dataloaders(
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = [Unet(**config) for config in unets_config]
unets = [Unet(**config.dict()) for config in unets_config]
decoder = Decoder(
unet=unets,
**decoder_config
**decoder_config.dict()
)
decoder.to(device=device)
@@ -154,13 +154,13 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
return grid_images, captions
def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
"""
Computes evaluation metrics for the decoder
"""
metrics = {}
# Prepare the data
examples = get_example_data(dataloader, device, n_evalation_samples)
examples = get_example_data(dataloader, device, n_evaluation_samples)
real_images, generated_images, captions = generate_samples(trainer, examples)
real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
@@ -252,8 +252,8 @@ def train(
start_epoch = 0
validation_losses = []
if exists(load_config) and exists(load_config["source"]):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
if exists(load_config) and exists(load_config.source):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
trainer.to(device=inference_device)
if not exists(unet_training_mask):
@@ -386,21 +386,25 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
"""
Creates a tracker of the specified type and initializes special features based on the full config
"""
tracker_config = config["tracker"]
tracker_config = config.tracker
init_config = {}
init_config["config"] = config.config
if exists(tracker_config.init_config):
init_config["config"] = tracker_config.init_config
if tracker_type == "console":
tracker = ConsoleTracker(**init_config)
elif tracker_type == "wandb":
# We need to initialize the resume state here
load_config = config["load"]
if load_config["source"] == "wandb" and load_config["resume"]:
load_config = config.load
if load_config.source == "wandb" and load_config.resume:
# Then we are resuming the run load_config["run_path"]
run_id = config["resume"]["wandb_run_path"].split("/")[-1]
run_id = load_config.run_path.split("/")[-1]
init_config["id"] = run_id
init_config["resume"] = "must"
init_config["entity"] = tracker_config["wandb_entity"]
init_config["project"] = tracker_config["wandb_project"]
init_config["entity"] = tracker_config.wandb_entity
init_config["project"] = tracker_config.wandb_project
tracker = WandbTracker(data_path)
tracker.init(**init_config)
else:
@@ -409,35 +413,35 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
def initialize_training(config):
# Create the save path
if "cuda" in config["train"]["device"]:
if "cuda" in config.train.device:
assert torch.cuda.is_available(), "CUDA is not available"
device = torch.device(config["train"]["device"])
device = torch.device(config.train.device)
torch.cuda.set_device(device)
all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
dataloaders = create_dataloaders (
available_shards=all_shards,
img_preproc = config.get_preprocessing(),
train_prop = config["data"]["splits"]["train"],
val_prop = config["data"]["splits"]["val"],
test_prop = config["data"]["splits"]["test"],
n_sample_images=config["train"]["n_sample_images"],
**config["data"]
img_preproc = config.img_preproc,
train_prop = config.data["splits"]["train"],
val_prop = config.data["splits"]["val"],
test_prop = config.data["splits"]["test"],
n_sample_images=config.train.n_sample_images,
**config.data.dict()
)
decoder = create_decoder(device, config["decoder"], config["unets"])
decoder = create_decoder(device, config.decoder, config.unets)
num_parameters = sum(p.numel() for p in decoder.parameters())
print(print_ribbon("Loaded Config", repeat=40))
print(f"Number of parameters: {num_parameters}")
tracker = create_tracker(config, **config["tracker"])
tracker = create_tracker(config, **config.tracker.dict())
train(dataloaders, decoder,
tracker=tracker,
inference_device=device,
load_config=config["load"],
evaluate_config=config["evaluate"],
**config["train"],
load_config=config.load,
evaluate_config=config.evaluate,
**config.train.dict(),
)
# Create a simple click command line interface to load the config and start the training
@@ -447,7 +451,7 @@ def main(config_file):
print("Recalling config from {}".format(config_file))
with open(config_file) as f:
config = json.load(f)
config = TrainDecoderConfig(config)
config = TrainDecoderConfig(**config)
initialize_training(config)