mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
add samples-seen and ema decay (#166)
This commit is contained in:
@@ -268,6 +268,7 @@ def train(
|
|||||||
validation_losses = []
|
validation_losses = []
|
||||||
next_task = 'train'
|
next_task = 'train'
|
||||||
sample = 0
|
sample = 0
|
||||||
|
samples_seen = 0
|
||||||
val_sample = 0
|
val_sample = 0
|
||||||
step = lambda: int(trainer.step.item())
|
step = lambda: int(trainer.step.item())
|
||||||
|
|
||||||
@@ -312,6 +313,7 @@ def train(
|
|||||||
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
|
||||||
total_samples = all_samples.sum().item()
|
total_samples = all_samples.sum().item()
|
||||||
sample += total_samples
|
sample += total_samples
|
||||||
|
samples_seen += total_samples
|
||||||
img, emb = send_to_device((img, emb))
|
img, emb = send_to_device((img, emb))
|
||||||
|
|
||||||
trainer.train()
|
trainer.train()
|
||||||
@@ -334,14 +336,20 @@ def train(
|
|||||||
mask = unet_all_losses != 0
|
mask = unet_all_losses != 0
|
||||||
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
|
unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
|
||||||
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
|
loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if loss != 0 }
|
||||||
|
|
||||||
|
# gather decay rate on each UNet
|
||||||
|
ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets)}
|
||||||
|
|
||||||
log_data = {
|
log_data = {
|
||||||
"Epoch": epoch,
|
"Epoch": epoch,
|
||||||
"Sample": sample,
|
"Sample": sample,
|
||||||
"Step": i,
|
"Step": i,
|
||||||
"Samples per second": samples_per_sec,
|
"Samples per second": samples_per_sec,
|
||||||
|
"Samples Seen": samples_seen,
|
||||||
|
**ema_decay_list,
|
||||||
**loss_map
|
**loss_map
|
||||||
}
|
}
|
||||||
# print(f"I am rank {accelerator.state.process_index}. Example weight: {trainer.decoder.state_dict()['module.unets.0.init_conv.convs.0.weight'][0,0,0,0]}")
|
|
||||||
if is_master:
|
if is_master:
|
||||||
tracker.log(log_data, step=step(), verbose=True)
|
tracker.log(log_data, step=step(), verbose=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user