take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block)

This commit is contained in:
Phil Wang
2022-05-14 15:49:24 -07:00
parent ff3474f05c
commit b494ed81d4
3 changed files with 7 additions and 5 deletions

View File

@@ -775,7 +775,6 @@ decoder_trainer = DecoderTrainer(
for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss.backward()
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -839,7 +838,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
)
loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop

View File

@@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module):
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor)
scaled_loss = self.scaler.scale(loss / divisor)
scaled_loss.backward()
return loss.item()
# decoder trainer
@@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module):
):
with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss / divisor, unet_number = unet_number)
scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
scaled_loss.backward()
return loss.item()

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.23',
version = '0.2.24',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',