allow for one to stop grouping out weight decayable parameters, to debug optimizer state dict problem

This commit is contained in:
Phil Wang
2022-05-24 21:42:32 -07:00
parent 8864fd0aa7
commit 857b9fbf1e
3 changed files with 14 additions and 7 deletions

View File

@@ -12,6 +12,7 @@ def get_optimizer(
     betas = (0.9, 0.999),
     eps = 1e-8,
     filter_by_requires_grad = False,
+    group_wd_params = True,
     **kwargs
 ):
     if filter_by_requires_grad:
@@ -21,11 +22,13 @@ def get_optimizer(
         return Adam(params, lr = lr, betas = betas, eps = eps)

     params = set(params)

-    wd_params, no_wd_params = separate_weight_decayable_params(params)
-
-    param_groups = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
-    ]
-
-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
+    if group_wd_params:
+        wd_params, no_wd_params = separate_weight_decayable_params(params)
+
+        params = [
+            {'params': list(wd_params)},
+            {'params': list(no_wd_params), 'weight_decay': 0},
+        ]
+
+    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)

View File

@@ -254,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()
@@ -279,6 +280,7 @@ class DiffusionPriorTrainer(nn.Module):
             lr = lr,
             wd = wd,
             eps = eps,
+            group_wd_params = group_wd_params,
             **kwargs
         )
@@ -410,6 +412,7 @@ class DecoderTrainer(nn.Module):
         eps = 1e-8,
         max_grad_norm = 0.5,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()
@@ -435,6 +438,7 @@ class DecoderTrainer(nn.Module):
             lr = unet_lr,
             wd = unet_wd,
             eps = unet_eps,
+            group_wd_params = group_wd_params,
             **kwargs
         )

View File

@@ -10,7 +10,7 @@ setup(
         'dream = dalle2_pytorch.cli:dream'
       ],
     },
-    version = '0.5.0',
+    version = '0.5.1',
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',