mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
be able to customize adam eps
@@ -147,6 +147,7 @@ class DiffusionPriorTrainer(nn.Module):
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
+        eps = 1e-6,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -173,6 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
             diffusion_prior.parameters(),
             lr = lr,
             wd = wd,
+            eps = eps,
             **kwargs
         )
 
@@ -223,6 +225,7 @@ class DecoderTrainer(nn.Module):
         use_ema = True,
         lr = 2e-5,
         wd = 1e-2,
+        eps = 1e-8,
         max_grad_norm = None,
         amp = False,
         **kwargs
@@ -247,13 +250,14 @@ class DecoderTrainer(nn.Module):
         # be able to finely customize learning rate, weight decay
         # per unet
 
-        lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
+        lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
 
-        for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
+        for ind, (unet, unet_lr, unet_wd, unet_eps) in enumerate(zip(self.decoder.unets, lr, wd, eps)):
             optimizer = get_optimizer(
                 unet.parameters(),
                 lr = unet_lr,
                 wd = unet_wd,
+                eps = unet_eps,
                 **kwargs
             )
 
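For context, eps is the epsilon term in Adam's update denominator (step ~ grad / (sqrt(v) + eps)); raising it trades a little adaptivity for numerical stability, which is why the trainers now expose it. Below is a minimal sketch of what an optimizer factory like get_optimizer might do with the new argument — the real helper in dalle2_pytorch may differ, and this signature is an assumption:

from torch.optim import Adam, AdamW

def get_optimizer(params, lr = 3e-4, wd = 1e-2, eps = 1e-8, **kwargs):
    # hypothetical sketch: use AdamW when weight decay is requested,
    # passing eps straight through to the underlying optimizer
    if wd > 0:
        return AdamW(params, lr = lr, weight_decay = wd, eps = eps, **kwargs)
    return Adam(params, lr = lr, eps = eps, **kwargs)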
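The DecoderTrainer hunk broadcasts lr, wd, and now eps to one value per unet via cast_tuple. A self-contained sketch of that broadcasting pattern, where cast_tuple is a stand-in written for illustration (assuming the real helper passes tuples through and repeats scalars):

from functools import partial

def cast_tuple(val, length = 1):
    # repeat a scalar `length` times; leave tuples as they are
    return val if isinstance(val, tuple) else ((val,) * length)

num_unets = 2

# a scalar applies to every unet; a tuple supplies one value per unet
lr, wd, eps = map(partial(cast_tuple, length = num_unets), (2e-5, 1e-2, (1e-8, 1e-7)))

print(lr)   # (2e-05, 2e-05)
print(wd)   # (0.01, 0.01)
print(eps)  # (1e-08, 1e-07)

With every hyperparameter normalized to a tuple, zip(self.decoder.unets, lr, wd, eps) then pairs each unet with its own optimizer settings.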