diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index a8c628c..7f3a657 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -169,7 +169,14 @@ class DecoderTrainer(nn.Module): ema_unet = self.ema_unets[index] ema_unet.update() - def forward(self, x, *, unet_number, **kwargs): + def forward( + self, + x, + *, + unet_number, + divisor = 1, + **kwargs + ): with autocast(enabled = self.amp): loss = self.decoder(x, unet_number = unet_number, **kwargs) - return self.scale(loss, unet_number = unet_number) + return self.scale(loss / divisor, unet_number = unet_number) diff --git a/setup.py b/setup.py index fe6a849..0f56c1f 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.79', + version = '0.0.80', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',