From b494ed81d4d414505041cf8e4486bc228be5a3f4 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 14 May 2022 15:49:24 -0700 Subject: [PATCH] take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block) --- README.md | 2 -- dalle2_pytorch/train.py | 8 ++++++-- setup.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b1a83c0..aae849e 100644 --- a/README.md +++ b/README.md @@ -775,7 +775,6 @@ decoder_trainer = DecoderTrainer( for unet_number in (1, 2): loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward - loss.backward() decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average @@ -839,7 +838,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer( ) loss = diffusion_prior_trainer(text, images) -loss.backward() diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior # after much of the above three lines in a loop diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index b341814..244682c 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module): ): with autocast(enabled = self.amp): loss = self.diffusion_prior(*args, **kwargs) - return self.scaler.scale(loss / divisor) + scaled_loss = self.scaler.scale(loss / divisor) + scaled_loss.backward() + return loss.item() # decoder trainer @@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module): ): with autocast(enabled = self.amp): loss = self.decoder(x, unet_number = unet_number, **kwargs) - return self.scale(loss / divisor, unet_number = unet_number) + scaled_loss = self.scale(loss / divisor, unet_number = unet_number) + scaled_loss.backward() + return loss.item() diff --git a/setup.py b/setup.py index 70e44b1..88c8d1f 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.23', + version = '0.2.24', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',