From 9549bd43b741bda5e318239fb5e4397d528e8757 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 14 May 2022 18:20:48 -0700
Subject: [PATCH] backwards pass is not recommended under the autocast
 context, per pytorch docs

---
 dalle2_pytorch/train.py | 12 ++++++------
 setup.py                |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index e8f706a..b3899f1 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -264,10 +264,10 @@ class DiffusionPriorTrainer(nn.Module):
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
             with autocast(enabled = self.amp):
                 loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
-                loss = loss * chunk_size_frac
-                total_loss += loss.item()
-                self.scaler.scale(loss).backward()
+            loss = loss * chunk_size_frac
+            total_loss += loss.item()
+            self.scaler.scale(loss).backward()

         return total_loss

@@ -388,9 +388,9 @@ class DecoderTrainer(nn.Module):
         for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
             with autocast(enabled = self.amp):
                 loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
-                loss = loss * chunk_size_frac
-                total_loss += loss.item()
-                self.scale(loss, unet_number = unet_number).backward()
+            loss = loss * chunk_size_frac
+            total_loss += loss.item()
+            self.scale(loss, unet_number = unet_number).backward()

         return total_loss

diff --git a/setup.py b/setup.py
index 8f9d952..dd6ad53 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.27',
+  version = '0.2.28',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
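For context, a minimal sketch of the AMP pattern this patch moves to, per the PyTorch docs: only the forward pass runs under `autocast`, while the scaled backward pass happens after the context exits. The model, optimizer, and batch below are hypothetical placeholders for illustration, not code from the repo.

```python
import torch
from torch.cuda.amp import autocast, GradScaler

# Hypothetical model, optimizer, and batch (requires a CUDA device)
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

x = torch.randn(8, 16, device = 'cuda')
target = torch.randn(8, 1, device = 'cuda')

optimizer.zero_grad()

# forward pass (and loss computation) run in mixed precision
with autocast():
    loss = torch.nn.functional.mse_loss(model(x), target)

# backward pass runs outside the autocast block, on the scaled loss,
# mirroring the change to DiffusionPriorTrainer / DecoderTrainer above
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```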