mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
small cleanup of decoder train script
This commit is contained in:
@@ -16,6 +16,17 @@ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
|||||||
import webdataset as wds
|
import webdataset as wds
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
|
TRAIN_CALC_LOSS_EVERY_ITERS = 10
|
||||||
|
VALID_CALC_LOSS_EVERY_ITERS = 10
|
||||||
|
|
||||||
|
# helpers functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
# main functions
|
||||||
|
|
||||||
def create_dataloaders(
|
def create_dataloaders(
|
||||||
available_shards,
|
available_shards,
|
||||||
@@ -79,18 +90,15 @@ def create_dataloaders(
|
|||||||
|
|
||||||
def create_decoder(device, decoder_config, unets_config):
|
def create_decoder(device, decoder_config, unets_config):
|
||||||
"""Creates a sample decoder"""
|
"""Creates a sample decoder"""
|
||||||
unets = []
|
|
||||||
for i in range(0, len(unets_config)):
|
unets = [Unet(**config) for config in unets_config]
|
||||||
unets.append(Unet(
|
|
||||||
**unets_config[i]
|
|
||||||
))
|
|
||||||
|
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
unet=unets,
|
unet=unets,
|
||||||
**decoder_config
|
**decoder_config
|
||||||
)
|
)
|
||||||
decoder.to(device=device)
|
|
||||||
|
|
||||||
|
decoder.to(device=device)
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
def get_dataset_keys(dataloader):
|
def get_dataset_keys(dataloader):
|
||||||
@@ -160,20 +168,20 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
|
|||||||
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
# Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
|
||||||
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||||
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
|
||||||
if FID is not None:
|
if exists(FID):
|
||||||
fid = FrechetInceptionDistance(**FID)
|
fid = FrechetInceptionDistance(**FID)
|
||||||
fid.to(device=device)
|
fid.to(device=device)
|
||||||
fid.update(int_real_images, real=True)
|
fid.update(int_real_images, real=True)
|
||||||
fid.update(int_generated_images, real=False)
|
fid.update(int_generated_images, real=False)
|
||||||
metrics["FID"] = fid.compute().item()
|
metrics["FID"] = fid.compute().item()
|
||||||
if IS is not None:
|
if exists(IS):
|
||||||
inception = InceptionScore(**IS)
|
inception = InceptionScore(**IS)
|
||||||
inception.to(device=device)
|
inception.to(device=device)
|
||||||
inception.update(int_real_images)
|
inception.update(int_real_images)
|
||||||
is_mean, is_std = inception.compute()
|
is_mean, is_std = inception.compute()
|
||||||
metrics["IS_mean"] = is_mean.item()
|
metrics["IS_mean"] = is_mean.item()
|
||||||
metrics["IS_std"] = is_std.item()
|
metrics["IS_std"] = is_std.item()
|
||||||
if KID is not None:
|
if exists(KID):
|
||||||
kernel_inception = KernelInceptionDistance(**KID)
|
kernel_inception = KernelInceptionDistance(**KID)
|
||||||
kernel_inception.to(device=device)
|
kernel_inception.to(device=device)
|
||||||
kernel_inception.update(int_real_images, real=True)
|
kernel_inception.update(int_real_images, real=True)
|
||||||
@@ -181,7 +189,7 @@ def evaluate_trainer(trainer, dataloader, device, n_evalation_samples=1000, FID=
|
|||||||
kid_mean, kid_std = kernel_inception.compute()
|
kid_mean, kid_std = kernel_inception.compute()
|
||||||
metrics["KID_mean"] = kid_mean.item()
|
metrics["KID_mean"] = kid_mean.item()
|
||||||
metrics["KID_std"] = kid_std.item()
|
metrics["KID_std"] = kid_std.item()
|
||||||
if LPIPS is not None:
|
if exists(LPIPS):
|
||||||
# Convert from [0, 1] to [-1, 1]
|
# Convert from [0, 1] to [-1, 1]
|
||||||
renorm_real_images = real_images.mul(2).sub(1)
|
renorm_real_images = real_images.mul(2).sub(1)
|
||||||
renorm_generated_images = generated_images.mul(2).sub(1)
|
renorm_generated_images = generated_images.mul(2).sub(1)
|
||||||
@@ -245,11 +253,11 @@ def train(
|
|||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
validation_losses = []
|
validation_losses = []
|
||||||
|
|
||||||
if load_config is not None and load_config["source"] is not None:
|
if exists(load_config) and exists(load_config["source"]):
|
||||||
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
|
start_epoch, start_step, validation_losses = recall_trainer(tracker, trainer, recall_source=load_config["source"], **load_config)
|
||||||
trainer.to(device=inference_device)
|
trainer.to(device=inference_device)
|
||||||
|
|
||||||
if unet_training_mask is None:
|
if not exists(unet_training_mask):
|
||||||
# Then the unet mask should be true for all unets in the decoder
|
# Then the unet mask should be true for all unets in the decoder
|
||||||
unet_training_mask = [True] * trainer.num_unets
|
unet_training_mask = [True] * trainer.num_unets
|
||||||
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
assert len(unet_training_mask) == trainer.num_unets, f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}"
|
||||||
@@ -280,17 +288,19 @@ def train(
|
|||||||
|
|
||||||
for unet in range(1, trainer.num_unets+1):
|
for unet in range(1, trainer.num_unets+1):
|
||||||
# Check if this is a unet we are training
|
# Check if this is a unet we are training
|
||||||
if unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||||
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
continue
|
||||||
trainer.update(unet_number=unet)
|
|
||||||
losses.append(loss)
|
loss = trainer.forward(img, image_embed=emb, unet_number=unet)
|
||||||
|
trainer.update(unet_number=unet)
|
||||||
|
losses.append(loss)
|
||||||
|
|
||||||
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
||||||
|
|
||||||
timer.reset()
|
timer.reset()
|
||||||
last_sample = sample
|
last_sample = sample
|
||||||
|
|
||||||
if i % 10 == 0:
|
if i % CALC_LOSS_EVERY_ITERS == 0:
|
||||||
average_loss = sum(losses) / len(losses)
|
average_loss = sum(losses) / len(losses)
|
||||||
log_data = {
|
log_data = {
|
||||||
"Training loss": average_loss,
|
"Training loss": average_loss,
|
||||||
@@ -311,13 +321,13 @@ def train(
|
|||||||
if save_all:
|
if save_all:
|
||||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
|
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
|
||||||
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||||
if n_sample_images is not None and n_sample_images > 0:
|
if exists(n_sample_images) and n_sample_images > 0:
|
||||||
trainer.eval()
|
trainer.eval()
|
||||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||||
trainer.train()
|
trainer.train()
|
||||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||||
|
|
||||||
if epoch_samples is not None and sample >= epoch_samples:
|
if exists(epoch_samples) and sample >= epoch_samples:
|
||||||
break
|
break
|
||||||
|
|
||||||
trainer.eval()
|
trainer.eval()
|
||||||
@@ -334,12 +344,12 @@ def train(
|
|||||||
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
loss = trainer.forward(img.float(), image_embed=emb.float(), unet_number=unet)
|
||||||
average_loss += loss
|
average_loss += loss
|
||||||
|
|
||||||
if i % 10 == 0:
|
if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
|
||||||
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
|
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
|
||||||
print(f"Loss: {average_loss / (i+1)}")
|
print(f"Loss: {average_loss / (i+1)}")
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
if validation_samples is not None and sample >= validation_samples:
|
if exists(validation_samples) and sample >= validation_samples:
|
||||||
break
|
break
|
||||||
|
|
||||||
average_loss /= i+1
|
average_loss /= i+1
|
||||||
@@ -350,7 +360,7 @@ def train(
|
|||||||
|
|
||||||
# Compute evaluation metrics
|
# Compute evaluation metrics
|
||||||
trainer.eval()
|
trainer.eval()
|
||||||
if evaluate_config is not None:
|
if exists(evaluate_config):
|
||||||
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
||||||
tracker.log(evaluation, step=step, verbose=True)
|
tracker.log(evaluation, step=step, verbose=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user