mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 19:44:26 +01:00
Compare commits (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | b494ed81d4 |  |
```diff
@@ -775,7 +775,6 @@ decoder_trainer = DecoderTrainer(
 
 for unet_number in (1, 2):
     loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
-    loss.backward()
 
     decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
 
```
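The change to the decoder example above just deletes the explicit `loss.backward()`: the backward pass now happens inside the trainer's forward (see the implementation hunks below). A minimal sketch of the updated loop, assuming `decoder_trainer`, `images` and `text` are already constructed as earlier in the README:

```python
for unet_number in (1, 2):
    # the forward call now backpropagates internally and returns
    # a plain Python float instead of a loss tensor
    loss = decoder_trainer(images, text = text, unet_number = unet_number)
    decoder_trainer.update(unet_number)  # optimizer step plus EMA update for this unet
```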
```diff
@@ -839,7 +838,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
 )
 
 loss = diffusion_prior_trainer(text, images)
-loss.backward()
 diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
 
 # after much of the above three lines in a loop
```
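The diffusion prior example changes the same way. A sketch of the loop the trailing comment refers to, again assuming `diffusion_prior_trainer`, `text` and `images` exist as in the README; `num_steps` is a hypothetical stand-in loop bound:

```python
for _ in range(num_steps):
    loss = diffusion_prior_trainer(text, images)  # forward + internal backward, returns a float
    diffusion_prior_trainer.update()  # optimizer step plus EMA update of the prior
```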
```diff
@@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module):
     ):
         with autocast(enabled = self.amp):
             loss = self.diffusion_prior(*args, **kwargs)
-        return self.scaler.scale(loss / divisor)
+        scaled_loss = self.scaler.scale(loss / divisor)
+        scaled_loss.backward()
+        return loss.item()
 
 # decoder trainer
 
```
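The implementation change follows the standard `torch.cuda.amp` recipe: `backward()` must be called on the scaler-scaled loss, while the unscaled `loss.item()` is handed back for logging. A self-contained sketch of that pattern with a stand-in model and data (not the repository's code; it needs a CUDA device):

```python
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(16, 1).cuda()              # stand-in model
opt = torch.optim.Adam(model.parameters())
scaler = GradScaler()

x, y = torch.randn(8, 16).cuda(), torch.randn(8, 1).cuda()

with autocast():
    loss = nn.functional.mse_loss(model(x), y)  # forward under mixed precision

scaler.scale(loss).backward()   # backward on the *scaled* loss, as in the patch
scaler.step(opt)                # unscales gradients, then steps the optimizer
scaler.update()                 # adjusts the scale factor for the next step
opt.zero_grad()

print(loss.item())              # the plain float the trainers now return
```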
```diff
@@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module):
     ):
         with autocast(enabled = self.amp):
             loss = self.decoder(x, unet_number = unet_number, **kwargs)
-        return self.scale(loss / divisor, unet_number = unet_number)
+        scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
+        scaled_loss.backward()
+        return loss.item()
```
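`DecoderTrainer` gets the identical treatment. One consequence worth noting: since the forward now returns `loss.item()`, callers can no longer run `.backward()` themselves, so gradient accumulation has to go through the `divisor` seen in both hunks. A hypothetical accumulation loop, assuming `divisor` is a keyword argument of the trainer's forward (suggested by `loss / divisor` in the diff) and that `decoder_trainer`, `images` and `text` exist as in the README:

```python
accum = 4
for img_chunk, txt_chunk in zip(images.chunk(accum), text.chunk(accum)):
    # each call backpropagates the scaled, divisor-averaged loss internally
    loss = decoder_trainer(img_chunk, text = txt_chunk, unet_number = 1, divisor = accum)
decoder_trainer.update(1)  # a single optimizer + EMA step for the accumulated gradients
```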