diff --git a/train_decoder.py b/train_decoder.py
index f179c01..ef26d38 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -16,6 +16,17 @@ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
 import webdataset as wds
 import click
 
+# constants
+
+TRAIN_CALC_LOSS_EVERY_ITERS = 10
+VALID_CALC_LOSS_EVERY_ITERS = 10
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+# main functions
 
 def create_dataloaders(
     available_shards,
@@ -79,18 +90,15 @@ def create_dataloaders(
 
 def create_decoder(device, decoder_config, unets_config):
     """Creates a sample decoder"""
-    unets = []
-    for i in range(0, len(unets_config)):
-        unets.append(Unet(
-            **unets_config[i]
-        ))
-    
+
+    unets = [Unet(**config) for config in unets_config]
+
     decoder = Decoder(
         unet=unets,
         **decoder_config
     )
 
-    decoder.to(device=device)
+    decoder.to(device=device)
     return decoder
 
 def get_dataset_keys(dataloader):
@@ -160,20 +168,20 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
     int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
     int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
-    if FID is not None:
+    if exists(FID):
         fid = FrechetInceptionDistance(**FID)
         fid.to(device=device)
         fid.update(int_real_images, real=True)
         fid.update(int_generated_images, real=False)
         metrics["FID"] = fid.compute().item()
-    if IS is not None:
+    if exists(IS):
         inception = InceptionScore(**IS)
         inception.to(device=device)
         inception.update(int_real_images)
         is_mean, is_std = inception.compute()
         metrics["IS_mean"] = is_mean.item()
         metrics["IS_std"] = is_std.item()
-    if KID is not None:
+    if exists(KID):
         kernel_inception = KernelInceptionDistance(**KID)
         kernel_inception.to(device=device)
         kernel_inception.update(int_real_images, real=True)
@@ -181,7 +189,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
         kid_mean, kid_std = kernel_inception.compute()
         metrics["KID_mean"] = kid_mean.item()
         metrics["KID_std"] = kid_std.item()
-    if LPIPS is not None:
+    if exists(LPIPS):
         # Convert from [0, 1] to [-1, 1]
         renorm_real_images = real_images.mul(2).sub(1)
         renorm_generated_images = generated_images.mul(2).sub(1)
@@ -245,11 +253,11 @@ def train(
     start_epoch = 0
     validation_losses = []
-    if load_config is not None and load_config["source"] is not None:
+    if exists(load_config) and exists(load_config["source"]):
         start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
 
     trainer.to(device=inference_device)
 
-    if unet_training_mask is None:
+    if not exists(unet_training_mask):
         # Then the unet mask should be true for all unets in the decoder
         unet_training_mask = [True] * trainer.num_unets
     assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
@@ -280,17 +288,19 @@ def train(
 
             for unet in range(1, trainer.num_unets+1):
                 # Check if this is a unet we are training
-                if unet_training_mask[unet-1]: # Unet index is the unet number - 1
-                    loss = trainer.forward(img, image_embed=emb, unet_number=unet)
-                    trainer.update(unet_number=unet)
-                    losses.append(loss)
+                if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
+                    continue
+
+                loss = trainer.forward(img, image_embed=emb, unet_number=unet)
+                trainer.update(unet_number=unet)
+                losses.append(loss)
 
             samples_per_sec = (sample - last_sample) / timer.elapsed()
             timer.reset()
             last_sample = sample
 
 
-            if i % 10 == 0:
+            if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
                 average_loss = sum(losses) / len(losses)
                 log_data = {
                     "Training loss": average_loss,
@@ -311,13 +321,13 @@ def train(
                 if save_all:
                     save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
                 save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
-                if n_sample_images is not None and n_sample_images > 0:
+                if exists(n_sample_images) and n_sample_images > 0:
                     trainer.eval()
                     train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
                     trainer.train()
                     tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
 
-            if epoch_samples is not None and sample >= epoch_samples:
+            if exists(epoch_samples) and sample >= epoch_samples:
                 break
 
         trainer.eval()
@@ -334,12 +344,12 @@ def train(
                 loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
                 average_loss += loss
 
-            if i % 10 == 0:
+            if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
                 print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
                 print(f"Loss: {average_loss / (i+1)}")
                 print("")
 
-            if validation_samples is not None and sample >= validation_samples:
+            if exists(validation_samples) and sample >= validation_samples:
                 break
 
         average_loss /= i+1
@@ -350,7 +360,7 @@ def train(
 
         # Compute evaluation metrics
         trainer.eval()
-        if evaluate_config is not None:
+        if exists(evaluate_config):
             print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
             evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
             tracker.log(evaluation, step=step, verbose=True)