From c098f57e09ab320ba3fb9942d7b7e84826d45de4 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 20 Jun 2022 15:29:08 -0700 Subject: [PATCH] EMA for vqgan vae comes from ema_pytorch now --- dalle2_pytorch/vqgan_vae_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dalle2_pytorch/vqgan_vae_trainer.py b/dalle2_pytorch/vqgan_vae_trainer.py index cbb6f1f..4047980 100644 --- a/dalle2_pytorch/vqgan_vae_trainer.py +++ b/dalle2_pytorch/vqgan_vae_trainer.py @@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image from einops import rearrange -from dalle2_pytorch.train import EMA from dalle2_pytorch.vqgan_vae import VQGanVAE from dalle2_pytorch.optimizer import get_optimizer +from ema_pytorch import EMA + # helpers def exists(val): @@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module): valid_frac = 0.05, random_split_seed = 42, ema_beta = 0.995, - ema_update_after_step = 2000, + ema_update_after_step = 500, ema_update_every = 10, apply_grad_penalty_every = 4, amp = False