diff --git a/dalle2_pytorch/optimizer.py b/dalle2_pytorch/optimizer.py
index 33192a5..e4a0208 100644
--- a/dalle2_pytorch/optimizer.py
+++ b/dalle2_pytorch/optimizer.py
@@ -1,8 +1,10 @@
 from torch.optim import AdamW, Adam
 
 def separate_weight_decayable_params(params):
-    no_wd_params = set([param for param in params if param.ndim < 2])
-    wd_params = set(params) - no_wd_params
+    wd_params, no_wd_params = [], []
+    for param in params:
+        param_list = no_wd_params if param.ndim < 2 else wd_params
+        param_list.append(param)
     return wd_params, no_wd_params
 
 def get_optimizer(
@@ -25,8 +27,8 @@ def get_optimizer(
     wd_params, no_wd_params = separate_weight_decayable_params(params)
 
     params = [
-        {'params': list(wd_params)},
-        {'params': list(no_wd_params), 'weight_decay': 0},
+        {'params': wd_params},
+        {'params': no_wd_params, 'weight_decay': 0},
     ]
 
     return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)