backwards pass is not recommended under the autocast context, per pytorch docs

This commit is contained in:
Phil Wang
2022-05-14 18:20:48 -07:00
parent aee92dba4a
commit 9549bd43b7
2 changed files with 7 additions and 7 deletions

View File

@@ -264,10 +264,10 @@ class DiffusionPriorTrainer(nn.Module):
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, *args, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scaler.scale(loss).backward()
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scaler.scale(loss).backward()
return total_loss
@@ -388,9 +388,9 @@ class DecoderTrainer(nn.Module):
for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(x, split_size = max_batch_size, **kwargs):
with autocast(enabled = self.amp):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scale(loss, unet_number = unet_number).backward()
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scale(loss, unet_number = unet_number).backward()
return total_loss

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.2.27',
version = '0.2.28',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',