Fixed non deterministic optimizer creation (#130)

This commit is contained in:
Aidan Dempster
2022-05-31 12:03:20 -04:00
committed by GitHub
parent 6f8b90d4d7
commit 09534119a1


@@ -1,8 +1,10 @@
 from torch.optim import AdamW, Adam
 
 def separate_weight_decayable_params(params):
-    no_wd_params = set([param for param in params if param.ndim < 2])
-    wd_params = set(params) - no_wd_params
+    wd_params, no_wd_params = [], []
+    for param in params:
+        param_list = no_wd_params if param.ndim < 2 else wd_params
+        param_list.append(param)
     return wd_params, no_wd_params
 
 def get_optimizer(
@@ -25,8 +27,8 @@ def get_optimizer(
     wd_params, no_wd_params = separate_weight_decayable_params(params)
 
     params = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
+        {'params': wd_params},
+        {'params': no_wd_params, 'weight_decay': 0},
     ]
 
     return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
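
Why this removes the non-determinism: the previous implementation collected parameters into Python sets, whose iteration order over tensors can vary between runs, so the parameter ordering handed to AdamW was not reproducible. The list-based version preserves insertion order. Below is a minimal, self-contained sketch of the fixed helper in use; the toy model, layer sizes, and hyperparameters are illustrative assumptions, not part of the commit.

```python
import torch
from torch import nn
from torch.optim import AdamW

# The fixed helper from the diff above: lists keep insertion order.
def separate_weight_decayable_params(params):
    wd_params, no_wd_params = [], []
    for param in params:
        # biases and norm scales (ndim < 2) are excluded from weight decay
        param_list = no_wd_params if param.ndim < 2 else wd_params
        param_list.append(param)
    return wd_params, no_wd_params

# Hypothetical usage: `model`, its layer sizes, and the hyperparameters
# below are illustrative only.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 1))
wd_params, no_wd_params = separate_weight_decayable_params(list(model.parameters()))

param_groups = [
    {'params': wd_params},
    {'params': no_wd_params, 'weight_decay': 0},
]

# Because the groups are ordered lists, the optimizer sees the same
# parameter ordering on every run, unlike the earlier set-based version.
opt = AdamW(param_groups, lr = 3e-4, weight_decay = 1e-2)
```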