from torch.optim import AdamW, Adam

def separate_weight_decayable_params(params):
    # parameters with fewer than 2 dimensions (biases, norm scales) should not be weight decayed
    no_wd_params = set([param for param in params if param.ndim < 2])
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params

def get_optimizer(
    params,
    lr = 1e-4,
    wd = 1e-2,
    betas = (0.9, 0.999),
    eps = 1e-8,
    filter_by_requires_grad = False
):
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    # no weight decay requested: plain Adam, no need to split parameter groups
    if wd == 0:
        return Adam(params, lr = lr, betas = betas, eps = eps)

    params = set(params)
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    # two param groups: weight-decayed matrices, and biases / norm params with weight decay disabled
    param_groups = [
        {'params': list(wd_params)},
        {'params': list(no_wd_params), 'weight_decay': 0},
    ]

    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas, eps = eps)
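
# Minimal usage sketch (illustrative only, not part of the library API): builds a tiny
# model and checks that 1-d parameters (biases, LayerNorm scales) land in the
# zero-weight-decay group while weight matrices are decayed as usual.

if __name__ == '__main__':
    from torch import nn

    model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))
    optimizer = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)

    # AdamW with two param groups: one decayed, one with weight_decay = 0
    print(optimizer)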