just make sure decoder learning rate is reasonable and help out budding researchers

This commit is contained in:
Phil Wang
2022-06-23 11:29:21 -07:00
parent fddf66e91e
commit 4b994601ae
2 changed files with 3 additions and 1 deletions

View File

@@ -451,6 +451,8 @@ class DecoderTrainer(nn.Module):
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
optimizers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):