mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
final cleanup to decoder script
This commit is contained in:
@@ -271,7 +271,6 @@ def train(
|
|||||||
|
|
||||||
for epoch in range(start_epoch, epochs):
|
for epoch in range(start_epoch, epochs):
|
||||||
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
||||||
trainer.train()
|
|
||||||
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
@@ -280,11 +279,13 @@ def train(
|
|||||||
last_snapshot = 0
|
last_snapshot = 0
|
||||||
|
|
||||||
losses = []
|
losses = []
|
||||||
|
|
||||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||||
step += 1
|
step += 1
|
||||||
sample += img.shape[0]
|
sample += img.shape[0]
|
||||||
img, emb = send_to_device((img, emb))
|
img, emb = send_to_device((img, emb))
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
for unet in range(1, trainer.num_unets+1):
|
for unet in range(1, trainer.num_unets+1):
|
||||||
# Check if this is a unet we are training
|
# Check if this is a unet we are training
|
||||||
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
|
||||||
@@ -319,11 +320,12 @@ def train(
|
|||||||
save_paths.append("latest.pth")
|
save_paths.append("latest.pth")
|
||||||
if save_all:
|
if save_all:
|
||||||
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
|
save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth")
|
||||||
|
|
||||||
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths)
|
||||||
|
|
||||||
if exists(n_sample_images) and n_sample_images > 0:
|
if exists(n_sample_images) and n_sample_images > 0:
|
||||||
trainer.eval()
|
trainer.eval()
|
||||||
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ")
|
||||||
trainer.train()
|
|
||||||
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step)
|
||||||
|
|
||||||
if exists(epoch_samples) and sample >= epoch_samples:
|
if exists(epoch_samples) and sample >= epoch_samples:
|
||||||
@@ -358,7 +360,6 @@ def train(
|
|||||||
tracker.log(log_data, step=step, verbose=True)
|
tracker.log(log_data, step=step, verbose=True)
|
||||||
|
|
||||||
# Compute evaluation metrics
|
# Compute evaluation metrics
|
||||||
trainer.eval()
|
|
||||||
if exists(evaluate_config):
|
if exists(evaluate_config):
|
||||||
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
|
||||||
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user