take care of mixed precision, and make gradient accumulation do-able externally

Phil Wang
2022-04-30 12:27:24 -07:00
parent 5fff22834e
commit a2ef69af66
3 changed files with 22 additions and 4 deletions


@@ -3,6 +3,7 @@ from functools import partial
 import torch
 from torch import nn
+from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder
 from dalle2_pytorch.optimizer import get_optimizer
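
The new import pulls in PyTorch's native mixed-precision utilities. For reference, autocast and GradScaler are normally paired as in the minimal sketch below (a generic illustration of the standard torch.cuda.amp recipe, not code from this repo; the model, optimizer and loop are placeholders):

    import torch
    from torch import nn
    from torch.cuda.amp import autocast, GradScaler

    model = nn.Linear(512, 512).cuda()          # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
    scaler = GradScaler(enabled = True)         # one scaler per optimizer

    for _ in range(10):
        x = torch.randn(8, 512).cuda()
        with autocast(enabled = True):          # forward pass runs in mixed precision
            loss = model(x).pow(2).mean()
        scaler.scale(loss).backward()           # scale the loss so fp16 grads don't underflow
        scaler.step(optimizer)                  # unscales grads, skips the step on inf/nan
        scaler.update()                         # adjusts the scale factor for the next step
        optimizer.zero_grad()

The hunks below thread this same pattern through DecoderTrainer, keeping one GradScaler per unet's optimizer.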
@@ -98,6 +99,7 @@ class DecoderTrainer(nn.Module):
         lr = 3e-4,
         wd = 1e-2,
         max_grad_norm = None,
+        amp = False,
         **kwargs
     ):
         super().__init__()
@@ -115,6 +117,8 @@ class DecoderTrainer(nn.Module):
         self.ema_unets = nn.ModuleList([])
 
+        self.amp = amp
+
         # be able to finely customize learning rate, weight decay
         # per unet
@@ -133,10 +137,19 @@ class DecoderTrainer(nn.Module):
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
 
+            scaler = GradScaler(enabled = amp)
+            setattr(self, f'scaler{ind}', scaler)
+
         # gradient clipping if needed
         self.max_grad_norm = max_grad_norm
 
+    def scale(self, loss, *, unet_number):
+        assert 1 <= unet_number <= self.num_unets
+        index = unet_number - 1
+        scaler = getattr(self, f'scaler{index}')
+        return scaler.scale(loss)
+
     def update(self, unet_number):
         assert 1 <= unet_number <= self.num_unets
         index = unet_number - 1
@@ -146,7 +159,10 @@ class DecoderTrainer(nn.Module):
             nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
 
         optimizer = getattr(self, f'optim{index}')
-        optimizer.step()
+        scaler = getattr(self, f'scaler{index}')
+
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()
 
         if self.use_ema:
@@ -154,4 +170,6 @@ class DecoderTrainer(nn.Module):
             ema_unet.update()
 
     def forward(self, x, *, unet_number, **kwargs):
-        return self.decoder(x, unet_number = unet_number, **kwargs)
+        with autocast(enabled = self.amp):
+            loss = self.decoder(x, unet_number = unet_number, **kwargs)
+        return self.scale(loss, unet_number = unet_number)
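
Since forward now returns the scaled loss and the optimizer step lives in update, gradient accumulation can be driven entirely by the caller. A rough usage sketch, assuming a decoder and dataloader already exist (trainer, decoder, dataloader and accum_steps are placeholder names, and real calls would also pass the decoder's other kwargs, e.g. image embeddings):

    trainer = DecoderTrainer(decoder, amp = True, max_grad_norm = 0.5)  # hypothetical settings

    accum_steps = 4  # micro-batches accumulated per optimizer step

    for step, images in enumerate(dataloader):
        loss = trainer(images, unet_number = 1)   # forward under autocast, returns the scaled loss
        (loss / accum_steps).backward()           # gradients accumulate across micro-batches

        if (step + 1) % accum_steps == 0:
            trainer.update(1)                     # clip, scaler.step, scaler.update, zero_grad, EMA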