diff --git a/dalle2_pytorch/vqgan_vae_trainer.py b/dalle2_pytorch/vqgan_vae_trainer.py index cbb6f1f..4047980 100644 --- a/dalle2_pytorch/vqgan_vae_trainer.py +++ b/dalle2_pytorch/vqgan_vae_trainer.py @@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image from einops import rearrange -from dalle2_pytorch.train import EMA from dalle2_pytorch.vqgan_vae import VQGanVAE from dalle2_pytorch.optimizer import get_optimizer +from ema_pytorch import EMA + # helpers def exists(val): @@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module): valid_frac = 0.05, random_split_seed = 42, ema_beta = 0.995, - ema_update_after_step = 2000, + ema_update_after_step = 500, ema_update_every = 10, apply_grad_penalty_every = 4, amp = False