mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 19:34:21 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ffeecd0ca | ||
|
|
3df899f7a4 | ||
|
|
09534119a1 | ||
|
|
6f8b90d4d7 |
@@ -1,8 +1,10 @@
|
|||||||
from torch.optim import AdamW, Adam
|
from torch.optim import AdamW, Adam
|
||||||
|
|
||||||
def separate_weight_decayable_params(params):
|
def separate_weight_decayable_params(params):
|
||||||
no_wd_params = set([param for param in params if param.ndim < 2])
|
wd_params, no_wd_params = [], []
|
||||||
wd_params = set(params) - no_wd_params
|
for param in params:
|
||||||
|
param_list = no_wd_params if param.ndim < 2 else wd_params
|
||||||
|
param_list.append(param)
|
||||||
return wd_params, no_wd_params
|
return wd_params, no_wd_params
|
||||||
|
|
||||||
def get_optimizer(
|
def get_optimizer(
|
||||||
@@ -25,8 +27,8 @@ def get_optimizer(
|
|||||||
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
wd_params, no_wd_params = separate_weight_decayable_params(params)
|
||||||
|
|
||||||
params = [
|
params = [
|
||||||
{'params': list(wd_params)},
|
{'params': wd_params},
|
||||||
{'params': list(no_wd_params), 'weight_decay': 0},
|
{'params': no_wd_params, 'weight_decay': 0},
|
||||||
]
|
]
|
||||||
|
|
||||||
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class EMA(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.9999,
|
beta = 0.99,
|
||||||
update_after_step = 1000,
|
update_after_step = 1000,
|
||||||
update_every = 10,
|
update_every = 10,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = '0.6.2'
|
__version__ = '0.6.5'
|
||||||
|
|||||||
Reference in New Issue
Block a user