Compare commits


1 commit

3 changed files with 7 additions and 5 deletions


@@ -775,7 +775,6 @@ decoder_trainer = DecoderTrainer(
 for unet_number in (1, 2):
     loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
-    loss.backward()
     decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
@@ -839,7 +838,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
 )
 loss = diffusion_prior_trainer(text, images)
-loss.backward()
 diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
 # after much of the above three lines in a loop
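With this commit the explicit loss.backward() in both README examples goes away, since the trainers' forward now handles the backward pass and returns a plain float. For clarity, the resulting loops read roughly as below; this is a minimal sketch that assumes decoder_trainer, diffusion_prior_trainer, images, and text are already set up as earlier in the README:

    for unet_number in (1, 2):
        # forward scales the loss, calls backward internally, and returns loss.item()
        loss = decoder_trainer(images, text = text, unet_number = unet_number)
        decoder_trainer.update(unet_number) # update the specific unet as well as its EMA copy

    loss = diffusion_prior_trainer(text, images)
    diffusion_prior_trainer.update() # updates the optimizer as well as the EMA diffusion prior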


@@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module):
     ):
         with autocast(enabled = self.amp):
             loss = self.diffusion_prior(*args, **kwargs)
-        return self.scaler.scale(loss / divisor)
+        scaled_loss = self.scaler.scale(loss / divisor)
+        scaled_loss.backward()
+        return loss.item()
 # decoder trainer
@@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module):
     ):
         with autocast(enabled = self.amp):
             loss = self.decoder(x, unet_number = unet_number, **kwargs)
-        return self.scale(loss / divisor, unet_number = unet_number)
+        scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
+        scaled_loss.backward()
+        return loss.item()
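The trainer change above follows the usual torch.cuda.amp recipe: run the forward under autocast, call backward on the scaler-scaled loss, and return the unscaled loss as a float so callers no longer backprop through the return value. The sketch below shows that recipe outside the trainers; the model, optimizer, and batch are hypothetical stand-ins, and the assumption that update() steps the optimizer through the scaler is an inference from the README comment above:

    import torch
    from torch import nn
    from torch.cuda.amp import autocast, GradScaler

    # hypothetical stand-ins; the real trainers wrap DiffusionPrior / Decoder
    amp = torch.cuda.is_available()       # mirrors the trainers' self.amp flag
    device = 'cuda' if amp else 'cpu'
    model = nn.Linear(16, 1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
    scaler = GradScaler(enabled = amp)

    batch = torch.randn(8, 16, device = device)

    with autocast(enabled = amp):
        loss = model(batch).pow(2).mean() # any scalar loss

    scaler.scale(loss).backward()         # backward on the scaled loss, as in the new forward
    scaler.step(optimizer)                # optimizer step, presumably what update() performs
    scaler.update()
    optimizer.zero_grad()

    print(loss.item())                    # the float that forward now returns to the caller

Because forward now returns loss.item(), the value is detached from the graph; calling .backward() on it, as the old README examples did, would fail, which is why those lines are removed above.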


@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.23',
+  version = '0.2.24',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',