Compare commits

...

1 Commits

3 changed files with 7 additions and 5 deletions

View File

@@ -775,7 +775,6 @@ decoder_trainer = DecoderTrainer(
for unet_number in (1, 2): for unet_number in (1, 2):
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
loss.backward()
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -839,7 +838,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
) )
loss = diffusion_prior_trainer(text, images) loss = diffusion_prior_trainer(text, images)
loss.backward()
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
# after much of the above three lines in a loop # after much of the above three lines in a loop

View File

@@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module):
): ):
with autocast(enabled = self.amp): with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs) loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor) scaled_loss = self.scaler.scale(loss / divisor)
scaled_loss.backward()
return loss.item()
# decoder trainer # decoder trainer
@@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module):
): ):
with autocast(enabled = self.amp): with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs) loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss / divisor, unet_number = unet_number) scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
scaled_loss.backward()
return loss.item()

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.2.23', version = '0.2.24',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',