mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
be able to customize adam eps
@@ -10,13 +10,14 @@ def get_optimizer(
     lr = 2e-5,
     wd = 1e-2,
     betas = (0.9, 0.999),
+    eps = 1e-8,
     filter_by_requires_grad = False
 ):
     if filter_by_requires_grad:
         params = list(filter(lambda t: t.requires_grad, params))
 
     if wd == 0:
-        return Adam(params, lr = lr, betas = betas)
+        return Adam(params, lr = lr, betas = betas, eps = eps)
 
     params = set(params)
     wd_params, no_wd_params = separate_weight_decayable_params(params)
@@ -26,4 +27,4 @@ def get_optimizer(
         {'params': list(no_wd_params), 'weight_decay': 0},
     ]
 
-    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
+    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
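With this change, the Adam/AdamW epsilon can be passed straight through get_optimizer. A minimal usage sketch; only the keyword names come from the diff above, while the import path and the toy model are assumptions:

from torch import nn
from dalle2_pytorch.optimizer import get_optimizer  # assumed import path

model = nn.Linear(16, 16)  # toy stand-in for a real network

# wd > 0 takes the AdamW branch, which now forwards eps as well
opt = get_optimizer(model.parameters(), lr = 2e-5, wd = 1e-2, eps = 1e-8)

# wd == 0 takes the plain Adam branch, also with the custom eps
opt_no_wd = get_optimizer(model.parameters(), lr = 2e-5, wd = 0., eps = 1e-6)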
@@ -147,6 +147,7 @@ class DiffusionPriorTrainer(nn.Module):
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
+        eps = 1e-6,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -173,6 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
             diffusion_prior.parameters(),
             lr = lr,
             wd = wd,
+            eps = eps,
             **kwargs
         )
 
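For the prior trainer the new keyword defaults to 1e-6 and is forwarded to get_optimizer. A hedged sketch, assuming DiffusionPriorTrainer is in scope and a diffusion_prior has already been constructed:

trainer = DiffusionPriorTrainer(
    diffusion_prior,    # a DiffusionPrior instance (construction not shown)
    lr = 3e-4,
    wd = 1e-2,
    eps = 1e-6          # previously torch's default (1e-8) was always used
)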
@@ -223,6 +225,7 @@ class DecoderTrainer(nn.Module):
         use_ema = True,
         lr = 2e-5,
         wd = 1e-2,
+        eps = 1e-8,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -247,13 +250,14 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
+        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
 
-        for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
+        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
                 wd = unet_wd,
+                eps = unet_eps,
                 **kwargs
             )
 
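Because eps goes through the same cast_tuple path as lr and wd, it can be given as a single value or as one value per unet. A hedged sketch, assuming DecoderTrainer is in scope and a two-unet decoder was built elsewhere:

trainer = DecoderTrainer(
    decoder,                 # a Decoder with two unets (construction not shown)
    lr = (1e-4, 2e-5),       # per-unet learning rates, as before
    wd = 1e-2,               # a scalar is broadcast to every unet
    eps = (1e-8, 1e-6)       # per-unet Adam epsilon, enabled by this commit
)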