diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index ed0c16b..4316204 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -9,6 +9,12 @@ from dalle2_pytorch.optimizer import get_optimizer
 
 # helper functions
 
+def exists(val):
+    return val is not None
+
+def cast_tuple(val, length = 1):
+    return val if isinstance(val, tuple) else ((val,) * length)
+
 def pick_and_pop(keys, d):
     values = list(map(lambda key: d.pop(key), keys))
     return dict(zip(keys, values))
@@ -89,6 +95,9 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         use_ema = True,
+        lr = 3e-4,
+        wd = 1e-2,
+        max_grad_norm = None,
         **kwargs
     ):
         super().__init__()
@@ -106,16 +115,35 @@ class DecoderTrainer(nn.Module):
 
         self.ema_unets = nn.ModuleList([])
 
-        for ind, unet in enumerate(self.decoder.unets):
-            optimizer = get_optimizer(unet.parameters(), **kwargs)
+        # be able to finely customize learning rate, weight decay
+        # per unet
+
+        lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
+
+        for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
+            optimizer = get_optimizer(
+                unet.parameters(),
+                lr = unet_lr,
+                wd = unet_wd,
+                **kwargs
+            )
+
             setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
 
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
 
+        # gradient clipping if needed
+
+        self.max_grad_norm = max_grad_norm
+
     def update(self, unet_number):
         assert 1 <= unet_number <= self.num_unets
         index = unet_number - 1
+        unet = self.decoder.unets[index]
+
+        if exists(self.max_grad_norm):
+            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
 
         optimizer = getattr(self, f'optim{index}')
         optimizer.step()
diff --git a/setup.py b/setup.py
index 4d97c21..7a3bfbe 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.77',
+  version = '0.0.78',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
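
For context, a minimal usage sketch of the new DecoderTrainer options introduced above. It assumes `decoder` is an already-constructed dalle2_pytorch Decoder with two unets, and the lr/wd/max_grad_norm values are illustrative; the loss computation is elided because the Decoder forward signature is not part of this diff.

from dalle2_pytorch.train import DecoderTrainer

# hypothetical values; lr and wd may each be a single number or a per-unet tuple,
# which cast_tuple broadcasts to one entry per unet
trainer = DecoderTrainer(
    decoder,
    lr = (3e-4, 1e-4),     # per-unet learning rates
    wd = 1e-2,             # one weight decay shared by all unets
    max_grad_norm = 0.5    # clip each unet's gradients before its optimizer step
)

# ... compute a loss with the decoder for unet 1 and call loss.backward() ...

trainer.update(unet_number = 1)  # clips unet 1's gradients (if max_grad_norm is set) and steps its optimizer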