diff --git a/train_decoder.py b/train_decoder.py
index 495b7d3..832391b 100644
--- a/train_decoder.py
+++ b/train_decoder.py
@@ -268,6 +268,7 @@ def train(
     validation_losses = []
     next_task = 'train'
     sample = 0
+    samples_seen = 0
     val_sample = 0
     step = lambda: int(trainer.step.item())
 
@@ -312,6 +313,7 @@ def train(
             all_samples = accelerator.gather(sample_length_tensor)  # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
             total_samples = all_samples.sum().item()
             sample += total_samples
+            samples_seen += total_samples
             img, emb = send_to_device((img, emb))
 
             trainer.train()
@@ -334,14 +336,20 @@ def train(
                 mask = unet_all_losses != 0
                 unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
                 loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
+
+                # gather decay rate on each UNet
+                ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
+
                 log_data = {
                     "Epoch": epoch,
                     "Sample": sample,
                     "Step": i,
                     "Samples per second": samples_per_sec,
+                    "Samples Seen": samples_seen,
+                    **ema_decay_list,
                     **loss_map
                 }
-                # print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
+
                 if is_master:
                     tracker.log(log_data, step=step(), verbose=True)
 
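
The new `ema_decay_list` entries log the decay rate each EMA-wrapped UNet is currently applying. The diff itself does not define `get_current_decay()`; the sketch below assumes it follows the warmup schedule used by lucidrains' ema-pytorch EMA wrapper, where decay ramps from 0 toward a ceiling `beta`. All parameter names and defaults here (`update_after_step`, `inv_gamma`, `power`, `min_value`, `beta`) are assumptions for illustration, not values taken from this diff:

```python
# Minimal sketch of the EMA decay schedule that get_current_decay() is
# assumed to expose (modeled on lucidrains' ema-pytorch; the defaults
# below are assumptions, not values taken from this diff).
def get_current_decay(step, update_after_step=100, inv_gamma=1.0,
                      power=2 / 3, min_value=0.0, beta=0.9999):
    # Steps elapsed since EMA updates began (warmup period excluded).
    epoch = max(step - update_after_step - 1, 0)
    if epoch <= 0:
        # EMA has not started yet: shadow weights are a straight copy.
        return 0.0
    # Decay ramps from 0 toward beta following 1 - (1 + epoch / inv_gamma) ** -power.
    value = 1 - (1 + epoch / inv_gamma) ** -power
    return min(max(value, min_value), beta)

# The logged value therefore starts at 0 and approaches beta over training:
for s in (0, 1_000, 10_000, 100_000):
    print(s, round(get_current_decay(s), 6))
```

Logging this alongside the per-UNet training losses makes it easy to check whether the EMA warmup has finished before weighting any comparison between EMA and live-weight samples.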