mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
take care of backwards within trainer classes for diffusion prior and decoder, readying to take care of gradient accumulation as well (plus, unsure if loss should be backwards within autocast block)
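For context on the parenthetical above: the standard torch.cuda.amp recipe runs only the forward pass under autocast and calls backward() on the scaled loss outside the block, and gradient accumulation is usually handled by dividing the loss by the number of accumulation steps, which is the role the divisor argument plays in the diff below. A minimal, self-contained sketch of that recipe (the model, data, and hyperparameters are illustrative and not part of this repository; it needs a CUDA device, since torch.cuda.amp is the GPU autocast path):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

# illustrative model and optimizer, not taken from DALLE2-pytorch
model = nn.Linear(512, 512).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
scaler = GradScaler()

accum_steps = 4  # plays the same role as `divisor` in the trainer classes below

for step in range(1000):
    x = torch.randn(8, 512).cuda()

    with autocast():
        # only the forward pass and loss computation run in mixed precision
        loss = model(x).pow(2).mean()

    # scale, divide for accumulation, and backprop outside the autocast block
    scaler.scale(loss / accum_steps).backward()

    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)  # unscales gradients, skips the step on inf/nan
        scaler.update()
        optimizer.zero_grad()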
@@ -214,7 +214,9 @@ class DiffusionPriorTrainer(nn.Module):
     ):
         with autocast(enabled = self.amp):
             loss = self.diffusion_prior(*args, **kwargs)
-            return self.scaler.scale(loss / divisor)
+            scaled_loss = self.scaler.scale(loss / divisor)
+            scaled_loss.backward()
+            return loss.item()
 
 # decoder trainer
 
@@ -330,4 +332,6 @@ class DecoderTrainer(nn.Module):
     ):
         with autocast(enabled = self.amp):
             loss = self.decoder(x, unet_number = unet_number, **kwargs)
-            return self.scale(loss / divisor, unet_number = unet_number)
+            scaled_loss = self.scale(loss / divisor, unet_number = unet_number)
+            scaled_loss.backward()
+            return loss.item()
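With both hunks applied, each trainer's forward pass owns the backward call and hands the caller a plain float via loss.item() instead of a scaled tensor. The sketch below is a toy analogue of that trainer-owned pattern, not the repository's actual classes (which wrap more machinery; the diff itself shows DecoderTrainer scaling per unet_number); following the PyTorch docs' recommendation it calls backward() outside the autocast block, which is exactly the question the commit message leaves open.

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

class ToyTrainer(nn.Module):
    def __init__(self, net, lr = 3e-4, amp = False):
        super().__init__()
        self.net = net
        self.amp = amp
        self.scaler = GradScaler(enabled = amp)
        self.optimizer = torch.optim.Adam(net.parameters(), lr = lr)

    def update(self):
        # stepping through the scaler keeps the unscale / inf checks correct
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def forward(self, x, divisor = 1):
        with autocast(enabled = self.amp):
            loss = self.net(x).pow(2).mean()
            scaled_loss = self.scaler.scale(loss / divisor)

        # backward lives outside the autocast block here
        scaled_loss.backward()
        return loss.item()  # plain float, convenient for logging

net = nn.Linear(64, 64)
trainer = ToyTrainer(net)           # amp = True would additionally require CUDA
print(trainer(torch.randn(8, 64)))  # returns a float; call trainer.update() to step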