From 1924c7cc3da2c2309f57571a45ff0cfdff021490 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 2 May 2022 09:20:19 -0700
Subject: [PATCH] fix issue with mixed precision and gradient clipping

---
 dalle2_pytorch/train.py | 7 ++++---
 setup.py                | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index 0868182..ddb0732 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -159,12 +159,13 @@ class DecoderTrainer(nn.Module):
         index = unet_number - 1
         unet = self.decoder.unets[index]
 
-        if exists(self.max_grad_norm):
-            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
-
         optimizer = getattr(self, f'optim{index}')
         scaler = getattr(self, f'scaler{index}')
 
+        if exists(self.max_grad_norm):
+            scaler.unscale_(optimizer)
+            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
+
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad()
diff --git a/setup.py b/setup.py
index 2309d4c..4b09fac 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.89',
+  version = '0.0.90',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',