mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
fix issue with mixed precision and gradient clipping
This commit is contained in:
@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
|
|||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
unet = self.decoder.unets[index]
|
unet = self.decoder.unets[index]
|
||||||
|
|
||||||
if exists(self.max_grad_norm):
|
|
||||||
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
|
||||||
|
|
||||||
optimizer = getattr(self, f'optim{index}')
|
optimizer = getattr(self, f'optim{index}')
|
||||||
scaler = getattr(self, f'scaler{index}')
|
scaler = getattr(self, f'scaler{index}')
|
||||||
|
|
||||||
|
if exists(self.max_grad_norm):
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|||||||
Reference in New Issue
Block a user