mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
31 lines
858 B
Python
31 lines
858 B
Python
from torch.optim import AdamW, Adam
|
|
|
|
def separate_weight_decayable_params(params):
    """Split ``params`` into weight-decayable and non-decayable sets.

    Tensors with ``ndim < 2`` go into the non-decayable set. Every other
    tensor goes into the decayable set.

    Args:
        params: iterable of tensors. It is materialized once, so one-shot
            iterators such as ``module.parameters()`` are handled correctly.

    Returns:
        ``(wd_params, no_wd_params)`` as two disjoint sets. Sets have no stable
        iteration order. PyTorch optimizers expect ordered collections, so
        re-order these before building param groups from them.
    """
    # Materialize once: ``params`` is traversed twice below, and a generator
    # would be exhausted by the first pass, leaving ``wd_params`` empty.
    params = list(params)
    no_wd_params = {param for param in params if param.ndim < 2}
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params
|
|
|
|
def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.999),
    eps = 1e-8,
    filter_by_requires_grad = False
):
    """Build an Adam/AdamW optimizer for ``params``.

    When ``wd == 0``, a plain ``Adam`` optimizer over ``params`` is returned.
    Otherwise the de-duplicated parameters are split into two ``AdamW`` param
    groups:

    - tensors with ``ndim >= 2`` use weight decay ``wd``;
    - tensors with ``ndim < 2`` use ``weight_decay = 0``.

    This is the same rule ``separate_weight_decayable_params`` applies.

    Args:
        params: iterable of tensors to optimize. A generator such as
            ``module.parameters()`` is accepted.
        lr: learning rate passed to the optimizer.
        wd: weight decay for the decayed group; ``0`` selects ``Adam``.
        betas: ``betas`` passed to the optimizer.
        eps: ``eps`` passed to the optimizer.
        filter_by_requires_grad: if true, drop tensors whose
            ``requires_grad`` is false before building the optimizer.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` instance.
    """
    if filter_by_requires_grad:
        params = [t for t in params if t.requires_grad]

    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    # De-duplicate while keeping first-seen order. Optimizer state_dicts are
    # matched to parameters by position, so the group order must be identical
    # on every run. Iterating a set of tensors (hashed by id) does not
    # guarantee that.
    params = list(dict.fromkeys(params))

    wd_params = [p for p in params if p.ndim >= 2]
    no_wd_params = [p for p in params if p.ndim < 2]

    param_groups = [
        {'params': wd_params},
        {'params': no_wd_params, 'weight_decay': 0},
    ]

    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|