From 63195cc2cb683db11c08ea19f1f76e7c2628127f Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 30 Apr 2022 12:56:47 -0700
Subject: [PATCH] allow for division of loss prior to scaling, for gradient accumulation purposes

---
 dalle2_pytorch/train.py | 11 +++++++++--
 setup.py                |  2 +-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index a8c628c..7f3a657 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -169,7 +169,14 @@ class DecoderTrainer(nn.Module):
         ema_unet = self.ema_unets[index]
         ema_unet.update()
 
-    def forward(self, x, *, unet_number, **kwargs):
+    def forward(
+        self,
+        x,
+        *,
+        unet_number,
+        divisor = 1,
+        **kwargs
+    ):
         with autocast(enabled = self.amp):
             loss = self.decoder(x, unet_number = unet_number, **kwargs)
-        return self.scale(loss, unet_number = unet_number)
+        return self.scale(loss / divisor, unet_number = unet_number)
diff --git a/setup.py b/setup.py
index fe6a849..0f56c1f 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.79',
+  version = '0.0.80',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
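
Usage note (not part of the patch): a minimal sketch of how the new `divisor` argument could drive gradient accumulation. The `decoder` instance, `dataloader`, `accum_steps`, and the `trainer.update(...)` call are illustrative assumptions about the surrounding training loop, not something this diff defines.

# sketch of a gradient-accumulation loop using the new `divisor` argument
# `decoder` and `dataloader` are assumed to be constructed elsewhere;
# `trainer.update(...)` is assumed to perform the optimizer / scaler / EMA step

from dalle2_pytorch.train import DecoderTrainer

accum_steps = 4                      # micro-batches accumulated per optimizer step (placeholder value)
trainer = DecoderTrainer(decoder)    # `decoder` is an already-built Decoder (assumption)

for step, images in enumerate(dataloader):
    # dividing the loss by `accum_steps` before scaling keeps the accumulated
    # gradient equal to the average over the micro-batches rather than their sum
    loss = trainer(images, unet_number = 1, divisor = accum_steps)
    loss.backward()

    if (step + 1) % accum_steps == 0:
        trainer.update(unet_number = 1)   # assumed API for stepping the optimizer

Because the division happens before `self.scale(...)`, the AMP loss scaling still sees a correctly averaged loss, so accumulated updates match what a single large batch would produce (up to batch-norm-style effects).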