use pydantic to manage decoder training configs + defaults and refactor training script
@@ -90,11 +90,11 @@ def create_dataloaders(
 def create_decoder(device, decoder_config, unets_config):
     """Creates a sample decoder"""
 
-    unets = [Unet(**config) for config in unets_config]
+    unets = [Unet(**config.dict()) for config in unets_config]
 
     decoder = Decoder(
         unet=unets,
-        **decoder_config
+        **decoder_config.dict()
     )
 
     decoder.to(device=device)
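The `.dict()` calls above are the crux of the refactor: each unet config is now a validated pydantic model rather than a raw dict, and `.dict()` converts it back into plain kwargs for the constructor. A minimal sketch of the pattern, with a hypothetical `UnetConfig` whose fields and defaults are illustrative only:

from pydantic import BaseModel

class UnetConfig(BaseModel):
    # Hypothetical fields for illustration; not the repo's actual schema.
    dim: int = 128
    image_embed_dim: int = 512
    cond_dim: int = 128

config = UnetConfig(dim=64)   # unspecified fields fall back to defaults
kwargs = config.dict()        # {'dim': 64, 'image_embed_dim': 512, 'cond_dim': 128}
# kwargs can now be splatted into a constructor exactly like the old raw dict.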
@@ -154,13 +154,13 @@ def generate_grid_samples(trainer, examples, text_prepend=""):
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions
 
-def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
     metrics = {}
     # Prepare the data
-    examples = get_example_data(dataloader, device, n_evalation_samples)
+    examples = get_example_data(dataloader, device, n_evaluation_samples)
     real_images, generated_images, captions = generate_samples(trainer, examples)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
@@ -252,8 +252,8 @@ def train(
     start_epoch = 0
     validation_losses = []
 
-    if exists(load_config) and exists(load_config["source"]):
-        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
+    if exists(load_config) and exists(load_config.source):
+        start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config.source, **load_config)
     trainer.to(device=inference_device)
 
     if not exists(unet_training_mask):
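Note the behavioral difference this hunk preserves: with a raw dict, `load_config["source"]` raises a KeyError when the key is absent, whereas a pydantic model with an optional field yields None, which is exactly what the `exists` helper tests for. A sketch under an assumed `LoadConfig` model (the field names are illustrative, not the repo's actual schema):

from typing import Optional
from pydantic import BaseModel

def exists(val):
    # Mirrors the helper used throughout the repo.
    return val is not None

class LoadConfig(BaseModel):
    source: Optional[str] = None   # e.g. "local" or "wandb"
    resume: bool = False

load_config = LoadConfig()              # no source configured
assert not exists(load_config.source)   # absent field is None, not a KeyError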
@@ -386,21 +386,25 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
     """
     Creates a tracker of the specified type and initializes special features based on the full config
     """
-    tracker_config = config["tracker"]
+    tracker_config = config.tracker
     init_config = {}
-    init_config["config"] = config.config
+
+    if exists(tracker_config.init_config):
+        init_config["config"] = tracker_config.init_config
+
     if tracker_type == "console":
         tracker = ConsoleTracker(**init_config)
     elif tracker_type == "wandb":
         # We need to initialize the resume state here
-        load_config = config["load"]
-        if load_config["source"] == "wandb" and load_config["resume"]:
+        load_config = config.load
+        if load_config.source == "wandb" and load_config.resume:
             # Then we are resuming the run load_config["run_path"]
-            run_id = config["resume"]["wandb_run_path"].split("/")[-1]
+            run_id = load_config.run_path.split("/")[-1]
             init_config["id"] = run_id
             init_config["resume"] = "must"
-        init_config["entity"] = tracker_config["wandb_entity"]
-        init_config["project"] = tracker_config["wandb_project"]
+
+        init_config["entity"] = tracker_config.wandb_entity
+        init_config["project"] = tracker_config.wandb_project
         tracker = WandbTracker(data_path)
         tracker.init(**init_config)
     else:
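In the wandb resume branch, the run id is recovered from the tail of the run path, following wandb's entity/project/run_id convention. A small worked example with hypothetical values:

run_path = "my-entity/dalle2-decoder/1a2b3c4d"   # hypothetical run path
run_id = run_path.split("/")[-1]                 # -> "1a2b3c4d"
# init_config then carries id="1a2b3c4d" and resume="must",
# so the tracker reattaches to the existing run instead of starting a new one.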
@@ -409,35 +413,35 @@ def create_tracker(config, tracker_type=None, data_path=None, **kwargs):
 
 def initialize_training(config):
     # Create the save path
-    if "cuda" in config["train"]["device"]:
+    if "cuda" in config.train.device:
         assert torch.cuda.is_available(), "CUDA is not available"
-    device = torch.device(config["train"]["device"])
+    device = torch.device(config.train.device)
     torch.cuda.set_device(device)
-    all_shards = list(range(config["data"]["start_shard"], config["data"]["end_shard"] + 1))
+    all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
 
     dataloaders = create_dataloaders (
         available_shards=all_shards,
-        img_preproc = config.get_preprocessing(),
-        train_prop = config["data"]["splits"]["train"],
-        val_prop = config["data"]["splits"]["val"],
-        test_prop = config["data"]["splits"]["test"],
-        n_sample_images=config["train"]["n_sample_images"],
-        **config["data"]
+        img_preproc = config.img_preproc,
+        train_prop = config.data["splits"]["train"],
+        val_prop = config.data["splits"]["val"],
+        test_prop = config.data["splits"]["test"],
+        n_sample_images=config.train.n_sample_images,
+        **config.data.dict()
     )
 
-    decoder = create_decoder(device, config["decoder"], config["unets"])
+    decoder = create_decoder(device, config.decoder, config.unets)
     num_parameters = sum(p.numel() for p in decoder.parameters())
     print(print_ribbon("Loaded Config", repeat=40))
     print(f"Number of parameters: {num_parameters}")
 
-    tracker = create_tracker(config, **config["tracker"])
+    tracker = create_tracker(config, **config.tracker.dict())
 
     train(dataloaders, decoder,
         tracker=tracker,
         inference_device=device,
-        load_config=config["load"],
-        evaluate_config=config["evaluate"],
-        **config["train"],
+        load_config=config.load,
+        evaluate_config=config.evaluate,
+        **config.train.dict(),
     )
 
 # Create a simple click command line interface to load the config and start the training
@@ -447,7 +451,7 @@ def main(config_file):
     print("Recalling config from {}".format(config_file))
     with open(config_file) as f:
         config = json.load(f)
-        config = TrainDecoderConfig(config)
+        config = TrainDecoderConfig(**config)
     initialize_training(config)
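Taken together, the entry point now parses the raw JSON once and hands every downstream function a typed, default-filled object. An end-to-end sketch of that flow under an assumed, heavily trimmed schema (the real TrainDecoderConfig has many more sections):

import json
from pydantic import BaseModel

class TrainConfig(BaseModel):
    device: str = "cuda:0"
    n_sample_images: int = 8

class TrainDecoderConfig(BaseModel):
    # Stand-in with a single section; illustrative only.
    train: TrainConfig = TrainConfig()

raw = json.loads('{"train": {"device": "cpu"}}')   # stands in for json.load(f)
config = TrainDecoderConfig(**raw)
print(config.train.device)            # "cpu" -- validated attribute access
print(config.train.n_sample_images)   # 8 -- default supplied by pydantic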