mirror of https://github.com/lucidrains/DALLE2-pytorch.git
allow one to turn off the grouping of weight-decayable parameters, to debug an optimizer state dict problem
@@ -12,6 +12,7 @@ def get_optimizer(
     betas = (0.9, 0.999),
     eps = 1e-8,
     filter_by_requires_grad = False,
+    group_wd_params = True,
     **kwargs
 ):
     if filter_by_requires_grad:
@@ -21,11 +22,13 @@ def get_optimizer(
         return Adam(params, lr = lr, betas = betas, eps = eps)
 
     params = set(params)
 
-    wd_params, no_wd_params = separate_weight_decayable_params(params)
+    if group_wd_params:
+        wd_params, no_wd_params = separate_weight_decayable_params(params)
 
-    param_groups = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
-    ]
-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
+        params = [
+            {'params': list(wd_params)},
+            {'params': list(no_wd_params), 'weight_decay': 0},
+        ]
+
+    return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
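For reference, separate_weight_decayable_params is the helper this code relies on to split parameters into a decayed and a non-decayed set. Below is a minimal sketch of such a helper, assuming the common convention that parameters with fewer than two dimensions (biases, norm gains) are exempt from weight decay; the repo's actual split criterion may differ.

from torch import nn

# Sketch of a weight-decay split helper, assuming the "ndim < 2 means
# no decay" convention; the repo's actual implementation may differ.
def separate_weight_decayable_params(params):
    params = set(params)
    no_wd_params = set(p for p in params if p.ndim < 2)  # biases, norm weights
    return params - no_wd_params, no_wd_params

net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
wd_params, no_wd_params = separate_weight_decayable_params(net.parameters())
assert all(p.ndim >= 2 for p in wd_params)    # Linear.weight
assert all(p.ndim < 2 for p in no_wd_params)  # biases and LayerNorm affine params

With group_wd_params = False, this split is skipped entirely and AdamW receives one flat parameter collection, so every parameter gets the same weight_decay.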
@@ -254,6 +254,7 @@ class DiffusionPriorTrainer(nn.Module):
         eps = 1e-6,
         max_grad_norm = None,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()
@@ -279,6 +280,7 @@ class DiffusionPriorTrainer(nn.Module):
             lr = lr,
             wd = wd,
             eps = eps,
+            group_wd_params = group_wd_params,
             **kwargs
         )
 
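A hypothetical usage sketch of the new flag; everything here except group_wd_params is an assumption (the DiffusionPrior construction is elided, and the hyperparameter values are illustrative). Passing False keeps every parameter in a single AdamW group:

from dalle2_pytorch import DiffusionPriorTrainer  # assuming the package-root export

# `diffusion_prior` is assumed to be an already-built DiffusionPrior instance
trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    group_wd_params = False  # one flat param group, e.g. while debugging state dicts
)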
@@ -410,6 +412,7 @@ class DecoderTrainer(nn.Module):
         eps = 1e-8,
         max_grad_norm = 0.5,
         amp = False,
+        group_wd_params = True,
         **kwargs
     ):
         super().__init__()
@@ -435,6 +438,7 @@ class DecoderTrainer(nn.Module):
             lr = unet_lr,
             wd = unet_wd,
             eps = unet_eps,
+            group_wd_params = group_wd_params,
             **kwargs
         )
 
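The connection to the state dict problem named in the commit message: grouping changes how many entries appear in optimizer.state_dict()['param_groups'], and PyTorch refuses to load a state dict whose group count differs from that of the optimizer it is loaded into. A self-contained demonstration in plain PyTorch:

import torch
from torch import nn
from torch.optim import AdamW

net = nn.Linear(4, 4)
params = list(net.parameters())  # [weight (2-D), bias (1-D)]

# grouped layout, as produced when group_wd_params = True
grouped = AdamW([
    {'params': [params[0]]},                     # weight, decayed
    {'params': [params[1]], 'weight_decay': 0},  # bias, not decayed
], lr = 1e-4, weight_decay = 1e-2)

# flat layout, as produced when group_wd_params = False
flat = AdamW(params, lr = 1e-4, weight_decay = 1e-2)

print(len(grouped.state_dict()['param_groups']))  # 2
print(len(flat.state_dict()['param_groups']))     # 1

# loading across layouts raises a ValueError about differing group counts
try:
    flat.load_state_dict(grouped.state_dict())
except ValueError as e:
    print(e)

Being able to toggle the grouping off thus lets one check whether a failing checkpoint load is caused by the param-group split rather than by the parameters themselves.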