diff --git a/README.md b/README.md
index b1a83c0..aae849e 100644
--- a/README.md
+++ b/README.md
@@ -775,7 +775,6 @@ decoder_trainer = DecoderTrainer(
 
 for unet_number in (1, 2):
     loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
-    loss.backward()
 
     decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
 
@@ -839,7 +838,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
 )
 
 loss = diffusion_prior_trainer(text, images)
-loss.backward()
 diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
 
 # after much of the above three lines in a loop
diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index b341814..244682c 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module):
     ):
         with autocast(enabled = self.amp):
             loss = self.diffusion_prior(*args, **kwargs)
-        return self.scaler.scale(loss / divisor)
+        scaled_loss = self.scaler.scale(loss / divisor)
+        scaled_loss.backward()
+        return loss.item()
 
 # decoder trainer
 
@@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module):
     ):
         with autocast(enabled = self.amp):
             loss = self.decoder(x, unet_number = unet_number, **kwargs)
-        return self.scale(loss / divisor, unet_number = unet_number)
+        scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
+        scaled_loss.backward()
+        return loss.item()
diff --git a/setup.py b/setup.py
index 70e44b1..88c8d1f 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.23',
+  version = '0.2.24',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
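
Net effect for downstream training loops, as a minimal sketch rather than code from the patch itself: after this change the trainers' forward pass scales the loss, calls `backward()` internally, and returns a plain Python float, so callers drop their explicit `loss.backward()`. The sketch assumes `diffusion_prior_trainer` and `decoder_trainer` have already been constructed as in the README sections these hunks modify; the random tensors and their shapes are illustrative stand-ins, not required sizes.

```python
# Sketch of the post-0.2.24 training loop implied by the README hunks above.
# Assumes the trainers are built as earlier in the README; tensors below are
# random placeholders for real batches, and .cuda() assumes a GPU is available.
import torch

images = torch.randn(4, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (4, 256)).cuda()

# prior: forward now backpropagates internally and returns loss.item() (a float)
loss = diffusion_prior_trainer(text, images)  # no loss.backward() needed anymore
diffusion_prior_trainer.update()              # steps the optimizer and the EMA prior

# decoder: same pattern, once per unet
for unet_number in (1, 2):
    loss = decoder_trainer(images, text = text, unet_number = unet_number)
    decoder_trainer.update(unet_number)       # steps that unet's optimizer and its EMA copy
```

Returning `loss.item()` hands back a detached scalar, which is safe now that backpropagation happens inside the trainer; the `divisor` argument pre-scales the loss (`loss / divisor`), presumably so several forward calls can accumulate gradients before a single `update()`.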