EMA for vqgan vae comes from ema_pytorch now

This commit is contained in:
Phil Wang
2022-06-20 15:29:08 -07:00
parent 0021535c26
commit c098f57e09

View File

@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
from einops import rearrange from einops import rearrange
from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer from dalle2_pytorch.optimizer import get_optimizer
from ema_pytorch import EMA
# helpers # helpers
def exists(val): def exists(val):
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
valid_frac = 0.05, valid_frac = 0.05,
random_split_seed = 42, random_split_seed = 42,
ema_beta = 0.995, ema_beta = 0.995,
ema_update_after_step = 2000, ema_update_after_step = 500,
ema_update_every = 10, ema_update_every = 10,
apply_grad_penalty_every = 4, apply_grad_penalty_every = 4,
amp = False amp = False