Compare commits

...

5 Commits

Author SHA1 Message Date
Phil Wang
4b994601ae just make sure decoder learning rate is reasonable and help out budding researchers 2022-06-23 11:29:28 -07:00
zion
fddf66e91e fix params in decoder (#162) 2022-06-22 14:45:01 -07:00
Phil Wang
c8422ffd5d fix EMA updating buffers with non-float tensors 2022-06-22 07:16:39 -07:00
Conight
2aadc23c7c Fix train decoder config example (#160) 2022-06-21 22:17:06 -07:00
Phil Wang
c098f57e09 EMA for vqgan vae comes from ema_pytorch now 2022-06-20 15:29:08 -07:00
6 changed files with 10 additions and 7 deletions

View File

@@ -15,7 +15,7 @@
"channels": 3,
"timesteps": 1000,
"loss_type": "l2",
"beta_schedule": "cosine",
"beta_schedule": ["cosine"],
"learned_variance": true
},
"data": {

View File

@@ -451,6 +451,8 @@ class DecoderTrainer(nn.Module):
lr, wd, eps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps))
+assert all([unet_lr < 1e-3 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
optimizers = []
for unet, unet_lr, unet_wd, unet_eps in zip(decoder.unets, lr, wd, eps):

View File

@@ -1 +1 @@
-__version__ = '0.11.3'
+__version__ = '0.11.5'

View File

@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
from einops import rearrange
-from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer
+from ema_pytorch import EMA
# helpers
def exists(val):
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
-ema_update_after_step = 2000,
+ema_update_after_step = 500,
ema_update_every = 10,
apply_grad_penalty_every = 4,
amp = False

View File

@@ -28,7 +28,7 @@ setup(
'click',
'clip-anytorch',
'coca-pytorch>=0.0.5',
-'ema-pytorch>=0.0.3',
+'ema-pytorch>=0.0.7',
'einops>=0.4',
'einops-exts>=0.0.3',
'embedding-reader',

View File

@@ -258,8 +258,8 @@ def train(
is_master = accelerator.process_index == 0
trainer = DecoderTrainer(
-accelerator,
-decoder,
+decoder=decoder,
+accelerator=accelerator,
**kwargs
)