Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
allow for division of loss prior to scaling, for gradient accumulation purposes
@@ -169,7 +169,14 @@ class DecoderTrainer(nn.Module):
         ema_unet = self.ema_unets[index]
         ema_unet.update()
 
-    def forward(self, x, *, unet_number, **kwargs):
+    def forward(
+        self,
+        x,
+        *,
+        unet_number,
+        divisor = 1,
+        **kwargs
+    ):
         with autocast(enabled = self.amp):
             loss = self.decoder(x, unet_number = unet_number, **kwargs)
-            return self.scale(loss, unet_number = unet_number)
+            return self.scale(loss / divisor, unet_number = unet_number)
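For context, a minimal sketch of how the new divisor argument could be used for gradient accumulation. The commit itself only adds divisor and applies it to the loss before scaling; the decoder_trainer instance, the dataloader, and the update call for stepping the trainer below are assumptions for illustration, not part of this diff.

    accum_steps = 4  # micro-batches accumulated per optimizer step

    for step, images in enumerate(dataloader):          # hypothetical dataloader of image batches
        # dividing each micro-batch loss by accum_steps before scaling keeps the
        # accumulated gradient equal to the gradient of the mean loss over the
        # accumulated micro-batches
        loss = decoder_trainer(images, unet_number = 1, divisor = accum_steps)
        loss.backward()                                  # assumes the returned scaled loss is a backprop-able tensor

        if (step + 1) % accum_steps == 0:
            decoder_trainer.update(unet_number = 1)      # assumed trainer method for the optimizer / EMA step

Without the divisor, the caller would have to divide the loss after scaling, which interacts awkwardly with AMP loss scaling; passing divisor lets the trainer apply the division before self.scale is called.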