mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
Fixed non-deterministic optimizer creation (#130)
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
from torch.optim import AdamW, Adam
|
from torch.optim import AdamW, Adam
|
||||||
|
|
||||||
def separate_weight_decayable_params(params):
|
def separate_weight_decayable_params(params):
|
||||||
no_wd_params = set([param for param in params if param.ndim < 2])
|
wd_params, no_wd_params = [], []
|
||||||
wd_params = set(params) - no_wd_params
|
for param in params:
|
||||||
|
param_list = no_wd_params if param.ndim < 2 else wd_params
|
||||||
|
param_list.append(param)
|
||||||
return wd_params, no_wd_params
|
return wd_params, no_wd_params
|
||||||
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
@@ -25,8 +27,8 @@ def get_optimizer(
|
|||||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||||
|
|
||||||
params = [
|
params = [
|
||||||
{'params': list(wd_params)},
|
{'params': wd_params},
|
||||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
{'params': no_wd_params, 'weight_decay': 0},
|
||||||
]
|
]
|
||||||
|
|
||||||
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||||
|
|||||||
Reference in New Issue
Block a user