be able to finely customize learning parameters for each unet, take care of gradient clipping

This commit is contained in:
Phil Wang
2022-04-30 11:56:05 -07:00
parent a9421f49ec
commit 5fff22834e
2 changed files with 31 additions and 3 deletions

View File

@@ -9,6 +9,12 @@ from dalle2_pytorch.optimizer import get_optimizer
# helper functions # helper functions
def exists(val):
return val is not None
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def pick_and_pop(keys, d): def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys)) values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values)) return dict(zip(keys, values))
@@ -89,6 +95,9 @@ class DecoderTrainer(nn.Module):
self, self,
decoder, decoder,
use_ema = True, use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
@@ -106,16 +115,35 @@ class DecoderTrainer(nn.Module):
self.ema_unets = nn.ModuleList([]) self.ema_unets = nn.ModuleList([])
for ind, unet in enumerate(self.decoder.unets): # be able to finely customize learning rate, weight decay
optimizer = get_optimizer(unet.parameters(), **kwargs) # per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
if self.use_ema: if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs)) self.ema_unets.append(EMA(unet, **ema_kwargs))
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
def update(self, unet_number): def update(self, unet_number):
assert 1 <= unet_number <= self.num_unets assert 1 <= unet_number <= self.num_unets
index = unet_number - 1 index = unet_number - 1
unet = self.decoder.unets[index]
if exists(self.max_grad_norm):
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer = getattr(self, f'optim{index}') optimizer = getattr(self, f'optim{index}')
optimizer.step() optimizer.step()

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.77', version = '0.0.78',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',