From 857b9fbf1ee0f968b2a73a2cedbf22985dae17b9 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 24 May 2022 21:42:32 -0700 Subject: [PATCH] allow for one to stop grouping out weight decayable parameters, to debug optimizer state dict problem --- dalle2_pytorch/optimizer.py | 15 +++++++++------ dalle2_pytorch/trainer.py | 4 ++++ setup.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/dalle2_pytorch/optimizer.py b/dalle2_pytorch/optimizer.py index ee366d8..9657a6a 100644 --- a/dalle2_pytorch/optimizer.py +++ b/dalle2_pytorch/optimizer.py @@ -12,6 +12,7 @@ def get_optimizer( betas = (0.9, 0.999), eps = 1e-8, filter_by_requires_grad = False, + group_wd_params = True, **kwargs ): if filter_by_requires_grad: @@ -21,11 +22,13 @@ def get_optimizer( return Adam(params, lr = lr, betas = betas, eps = eps) params = set(params) - wd_params, no_wd_params = separate_weight_decayable_params(params) - param_groups = [ - {'params': list(wd_params)}, - {'params': list(no_wd_params), 'weight_decay': 0}, - ] + if group_wd_params: + wd_params, no_wd_params = separate_weight_decayable_params(params) - return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps) + params = [ + {'params': list(wd_params)}, + {'params': list(no_wd_params), 'weight_decay': 0}, + ] + + return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps) diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index acdde4b..4d1ab07 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -254,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module): eps = 1e-6, max_grad_norm = None, amp = False, + group_wd_params = True, **kwargs ): super().__init__() @@ -279,6 +280,7 @@ class DiffusionPriorTrainer(nn.Module): lr = lr, wd = wd, eps = eps, + group_wd_params = group_wd_params, **kwargs ) @@ -410,6 +412,7 @@ class DecoderTrainer(nn.Module): eps = 1e-8, max_grad_norm = 0.5, amp = False, + group_wd_params = True, **kwargs ): super().__init__() @@ -435,6 +438,7 @@ class DecoderTrainer(nn.Module): lr = unet_lr, wd = unet_wd, eps = unet_eps, + group_wd_params = group_wd_params, **kwargs ) diff --git a/setup.py b/setup.py index 86a12c3..e805eda 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.5.0', + version = '0.5.1', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',