add sampels-seen and ema decay (#166)

This commit is contained in:
zion
2022-06-24 17:12:09 -05:00
committed by GitHub
parent a5b9fd6ca8
commit 98f0c17759

View File

@@ -268,6 +268,7 @@ def train(
validation_losses = []
next_task = 'train'
sample = 0
samples_seen = 0
val_sample = 0
step = lambda: int(trainer.step.item())
@@ -312,6 +313,7 @@ def train(
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
total_samples = all_samples.sum().item()
sample += total_samples
samples_seen += total_samples
img, emb = send_to_device((img, emb))
trainer.train()
@@ -334,14 +336,20 @@ def train(
mask = unet_all_losses != 0
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
# gather decay rate on each UNet
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
log_data = {
"Epoch": epoch,
"Sample": sample,
"Step": i,
"Samples per second": samples_per_sec,
"Samples Seen": samples_seen,
**ema_decay_list,
**loss_map
}
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
if is_master:
tracker.log(log_data, step=step(), verbose=True)