small cleanup of decoder train script

This commit is contained in:
Phil Wang
2022-05-21 10:17:07 -07:00
parent b895f52843
commit 0064661729

View File

@@ -16,6 +16,17 @@ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import webdataset as wds
import click
# constants
TRAIN_CALC_LOSS_EVERY_ITERS = 10
VALID_CALC_LOSS_EVERY_ITERS = 10
# helpers functions
def exists(val):
return val is not None
# main functions
def create_dataloaders(
available_shards,
@@ -79,18 +90,15 @@ def create_dataloaders(
def create_decoder(device, decoder_config, unets_config):
"""Creates a sample decoder"""
unets = []
for i in range(0, len(unets_config)):
unets.append(Unet(
**unets_config[i]
))
unets = [Unet(**config) for config in unets_config]
decoder = Decoder(
unet=unets,
**decoder_config
)
decoder.to(device=device)
decoder.to(device=device)
return decoder
def get_dataset_keys(dataloader):
@@ -160,20 +168,20 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
if FID is not None:
if exists(FID):
fid = FrechetInceptionDistance(**FID)
fid.to(device=device)
fid.update(int_real_images, real=True)
fid.update(int_generated_images, real=False)
metrics["FID"] = fid.compute().item()
if IS is not None:
if exists(IS):
inception = InceptionScore(**IS)
inception.to(device=device)
inception.update(int_real_images)
is_mean, is_std = inception.compute()
metrics["IS_mean"] = is_mean.item()
metrics["IS_std"] = is_std.item()
if KID is not None:
if exists(KID):
kernel_inception = KernelInceptionDistance(**KID)
kernel_inception.to(device=device)
kernel_inception.update(int_real_images, real=True)
@@ -181,7 +189,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
kid_mean, kid_std = kernel_inception.compute()
metrics["KID_mean"] = kid_mean.item()
metrics["KID_std"] = kid_std.item()
if LPIPS is not None:
if exists(LPIPS):
# Convert from [0, 1] to [-1, 1]
renorm_real_images = real_images.mul(2).sub(1)
renorm_generated_images = generated_images.mul(2).sub(1)
@@ -245,11 +253,11 @@ def train(
start_epoch = 0
validation_losses = []
if load_config is not None and load_config["source"] is not None:
if exists(load_config) and exists(load_config["source"]):
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
trainer.to(device=inference_device)
if unet_training_mask is None:
if not exists(unet_training_mask):
# Then the unet mask should be true for all unets in the decoder
unet_training_mask = [True] * trainer.num_unets
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
@@ -280,7 +288,9 @@ def train(
for unet in range(1, trainer.num_unets+1):
# Check if this is a unet we are training
if unet_training_mask[unet-1]: # Unet index is the unet number - 1
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
continue
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
trainer.update(unet_number=unet)
losses.append(loss)
@@ -290,7 +300,7 @@ def train(
timer.reset()
last_sample = sample
if i % 10 == 0:
if i % CALC_LOSS_EVERY_ITERS == 0:
average_loss = sum(losses) / len(losses)
log_data = {
"Training loss": average_loss,
@@ -311,13 +321,13 @@ def train(
if save_all:
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
if n_sample_images is not None and n_sample_images > 0:
if exists(n_sample_images) and n_sample_images > 0:
trainer.eval()
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
trainer.train()
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
if epoch_samples is not None and sample >= epoch_samples:
if exists(epoch_samples) and sample >= epoch_samples:
break
trainer.eval()
@@ -334,12 +344,12 @@ def train(
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
average_loss += loss
if i % 10 == 0:
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}")
print("")
if validation_samples is not None and sample >= validation_samples:
if exists(validation_samples) and sample >= validation_samples:
break
average_loss /= i+1
@@ -350,7 +360,7 @@ def train(
# Compute evaluation metrics
trainer.eval()
if evaluate_config is not None:
if exists(evaluate_config):
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
tracker.log(evaluation, step=step, verbose=True)