From b432df2f7b1c8264e50ee91b9bd38144f571d774 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 21 May 2022 10:42:16 -0700 Subject: [PATCH] final cleanup to decoder script --- train_decoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train_decoder.py b/train_decoder.py index fb993e6..1648e9f 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -271,7 +271,6 @@ def train( for epoch in range(start_epoch, epochs): print(print_ribbon(f"Starting epoch {epoch}", repeat=40)) - trainer.train() timer = Timer() @@ -280,11 +279,13 @@ def train( last_snapshot = 0 losses = [] + for i, (img, emb) in enumerate(dataloaders["train"]): step += 1 sample += img.shape[0] img, emb = send_to_device((img, emb)) + trainer.train() for unet in range(1, trainer.num_unets+1): # Check if this is a unet we are training if not unet_training_mask[unet-1]: # Unet index is the unet number - 1 @@ -319,11 +320,12 @@ def train( save_paths.append("latest.pth") if save_all: save_paths.append(f"checkpoints/epoch_{epoch}_step_{step}.pth") + save_trainer(tracker, trainer, epoch, step, validation_losses, save_paths) + if exists(n_sample_images) and n_sample_images > 0: trainer.eval() train_images, train_captions = generate_grid_samples(trainer, train_example_data, "Train: ") - trainer.train() tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step) if exists(epoch_samples) and sample >= epoch_samples: @@ -358,7 +360,6 @@ def train( tracker.log(log_data, step=step, verbose=True) # Compute evaluation metrics - trainer.eval() if exists(evaluate_config): print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40)) evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, **evaluate_config)