allow for division of loss prior to scaling, for gradient accumulation purposes

Phil Wang
2022-04-30 12:56:47 -07:00
parent a2ef69af66
commit 63195cc2cb
2 changed files with 10 additions and 3 deletions

@@ -169,7 +169,14 @@ class DecoderTrainer(nn.Module):
             ema_unet = self.ema_unets[index]
             ema_unet.update()

-    def forward(self, x, *, unet_number, **kwargs):
+    def forward(
+        self,
+        x,
+        *,
+        unet_number,
+        divisor = 1,
+        **kwargs
+    ):
         with autocast(enabled = self.amp):
             loss = self.decoder(x, unet_number = unet_number, **kwargs)
-            return self.scale(loss, unet_number = unet_number)
+            return self.scale(loss / divisor, unet_number = unet_number)
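
A minimal sketch of how the new divisor argument could be used for gradient accumulation. Only forward(..., divisor = ...), the returned scaled loss, and update(...) come from this commit; the trainer, dataloader, accum_steps, and unet_number values below are illustrative placeholders, not part of the repository.

    # hypothetical training loop: accumulate gradients over several micro-batches
    accum_steps = 4  # assumed number of micro-batches per optimizer step

    for step, batch in enumerate(dataloader):
        # dividing the loss by accum_steps before scaling keeps the accumulated
        # gradient equivalent, on average, to that of one full-sized batch
        loss = trainer(batch, unet_number = 1, divisor = accum_steps)
        loss.backward()

        if (step + 1) % accum_steps == 0:
            # optimizer step, gradient zeroing, and EMA update happen here
            trainer.update(unet_number = 1)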