diff --git a/dalle2_pytorch/optimizer.py b/dalle2_pytorch/optimizer.py index 514bed9..ae90ebc 100644 --- a/dalle2_pytorch/optimizer.py +++ b/dalle2_pytorch/optimizer.py @@ -10,13 +10,14 @@ def get_optimizer( lr = 2e-5, wd = 1e-2, betas = (0.9, 0.999), + eps = 1e-8, filter_by_requires_grad = False ): if filter_by_requires_grad: params = list(filter(lambda t: t.requires_grad, params)) if wd == 0: - return Adam(params, lr = lr, betas = betas) + return Adam(params, lr = lr, betas = betas, eps = eps) params = set(params) wd_params, no_wd_params = separate_weight_decayable_params(params) @@ -26,4 +27,4 @@ def get_optimizer( {'params': list(no_wd_params), 'weight_decay': 0}, ] - return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas) + return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps) diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py index 0fddb8a..8343462 100644 --- a/dalle2_pytorch/train.py +++ b/dalle2_pytorch/train.py @@ -147,6 +147,7 @@ class DiffusionPriorTrainer(nn.Module): use_ema = True, lr = 3e-4, wd = 1e-2, + eps = 1e-6, max_grad_norm = None, amp = False, **kwargs @@ -173,6 +174,7 @@ class DiffusionPriorTrainer(nn.Module): diffusion_prior.parameters(), lr = lr, wd = wd, + eps = eps, **kwargs ) @@ -223,6 +225,7 @@ class DecoderTrainer(nn.Module): use_ema = True, lr = 2e-5, wd = 1e-2, + eps = 1e-8, max_grad_norm = None, amp = False, **kwargs @@ -247,13 +250,14 @@ class DecoderTrainer(nn.Module): # be able to finely customize learning rate, weight decay # per unet - lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd)) + lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps)) - for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)): + for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)): optimizer = get_optimizer( unet.parameters(), lr = unet_lr, wd = unet_wd, + eps = unet_eps, **kwargs ) diff --git a/setup.py b/setup.py index b2327e6..bd10ad3 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.2.20', + version = '0.2.21', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',