fix issue with mixed precision and gradient clipping

This commit is contained in:
Phil Wang
2022-05-02 09:20:19 -07:00
parent f7df3caaf3
commit 1924c7cc3d
2 changed files with 5 additions and 4 deletions

View File

@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
index = unet_number - 1
unet = self.decoder.unets[index]
if exists(self.max_grad_norm):
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer = getattr(self, f'optim{index}')
scaler = getattr(self, f'scaler{index}')
if exists(self.max_grad_norm):
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()