backwards pass is not recommended under the autocast context, per pytorch docs

This commit is contained in:
Phil Wang
2022-05-14 18:26:19 -07:00
parent aee92dba4a
commit 4ec6d0ba81
2 changed files with 5 additions and 5 deletions

View File

@@ -266,8 +266,8 @@ class DiffusionPriorTrainer(nn.Module):
loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scaler.scale(loss).backward()
total_loss += loss.item()
self.scaler.scale(loss).backward()
return total_loss
@@ -390,7 +390,7 @@ class DecoderTrainer(nn.Module):
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
loss = loss * chunk_size_frac
total_loss += loss.item()
self.scale(loss, unet_number = unet_number).backward()
total_loss += loss.item()
self.scale(loss, unet_number = unet_number).backward()
return total_loss