mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block)
This commit is contained in:
@@ -775,7 +775,6 @@ decoder_trainer = DecoderTrainer(
|
|||||||
|
|
||||||
for unet_number in (1, 2):
|
for unet_number in (1, 2):
|
||||||
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||||
|
|
||||||
@@ -839,7 +838,6 @@ diffusion_prior_trainer = DiffusionPriorTrainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
loss = diffusion_prior_trainer(text, images)
|
loss = diffusion_prior_trainer(text, images)
|
||||||
loss.backward()
|
|
||||||
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
|
||||||
|
|
||||||
# after much of the above three lines in a loop
|
# after much of the above three lines in a loop
|
||||||
|
|||||||
@@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module):
|
|||||||
):
|
):
|
||||||
with autocast(enabled = self.amp):
|
with autocast(enabled = self.amp):
|
||||||
loss = self.diffusion_prior(*args, **kwargs)
|
loss = self.diffusion_prior(*args, **kwargs)
|
||||||
return self.scaler.scale(loss / divisor)
|
scaled_loss = self.scaler.scale(loss / divisor)
|
||||||
|
scaled_loss.backward()
|
||||||
|
return loss.item()
|
||||||
|
|
||||||
# decoder trainer
|
# decoder trainer
|
||||||
|
|
||||||
@@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module):
|
|||||||
):
|
):
|
||||||
with autocast(enabled = self.amp):
|
with autocast(enabled = self.amp):
|
||||||
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||||
return self.scale(loss / divisor, unet_number = unet_number)
|
scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
|
||||||
|
scaled_loss.backward()
|
||||||
|
return loss.item()
|
||||||
|
|||||||
Reference in New Issue
Block a user