mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
be able to finely customize learning parameters for each unet, take care of gradient clipping
This commit is contained in:
@@ -9,6 +9,12 @@ from dalle2_pytorch.optimizer import get_optimizer
|
|||||||
|
|
||||||
# helper functions
|
# helper functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def cast_tuple(val, length = 1):
|
||||||
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
def pick_and_pop(keys, d):
|
def pick_and_pop(keys, d):
|
||||||
values = list(map(lambda key: d.pop(key), keys))
|
values = list(map(lambda key: d.pop(key), keys))
|
||||||
return dict(zip(keys, values))
|
return dict(zip(keys, values))
|
||||||
@@ -89,6 +95,9 @@ class DecoderTrainer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
decoder,
|
decoder,
|
||||||
use_ema = True,
|
use_ema = True,
|
||||||
|
lr = 3e-4,
|
||||||
|
wd = 1e-2,
|
||||||
|
max_grad_norm = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -106,16 +115,35 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
self.ema_unets = nn.ModuleList([])
|
self.ema_unets = nn.ModuleList([])
|
||||||
|
|
||||||
for ind, unet in enumerate(self.decoder.unets):
|
# be able to finely customize learning rate, weight decay
|
||||||
optimizer = get_optimizer(unet.parameters(), **kwargs)
|
# per unet
|
||||||
|
|
||||||
|
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
|
||||||
|
|
||||||
|
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
|
||||||
|
optimizer = get_optimizer(
|
||||||
|
unet.parameters(),
|
||||||
|
lr = unet_lr,
|
||||||
|
wd = unet_wd,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
|
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
|
||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||||
|
|
||||||
|
# gradient clipping if needed
|
||||||
|
|
||||||
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def update(self, unet_number):
|
def update(self, unet_number):
|
||||||
assert 1 <= unet_number <= self.num_unets
|
assert 1 <= unet_number <= self.num_unets
|
||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
|
unet = self.decoder.unets[index]
|
||||||
|
|
||||||
|
if exists(self.max_grad_norm):
|
||||||
|
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
optimizer = getattr(self, f'optim{index}')
|
optimizer = getattr(self, f'optim{index}')
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|||||||
Reference in New Issue
Block a user