mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
EMA for vqgan vae comes from ema_pytorch now
This commit is contained in:
@@ -16,10 +16,11 @@ from torchvision.utils import make_grid, save_image
|
|||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from dalle2_pytorch.train import EMA
|
|
||||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||||
from dalle2_pytorch.optimizer import get_optimizer
|
from dalle2_pytorch.optimizer import get_optimizer
|
||||||
|
|
||||||
|
from ema_pytorch import EMA
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
|
|
||||||
def exists(val):
|
def exists(val):
|
||||||
@@ -97,7 +98,7 @@ class VQGanVAETrainer(nn.Module):
|
|||||||
valid_frac = 0.05,
|
valid_frac = 0.05,
|
||||||
random_split_seed = 42,
|
random_split_seed = 42,
|
||||||
ema_beta = 0.995,
|
ema_beta = 0.995,
|
||||||
ema_update_after_step = 2000,
|
ema_update_after_step = 500,
|
||||||
ema_update_every = 10,
|
ema_update_every = 10,
|
||||||
apply_grad_penalty_every = 4,
|
apply_grad_penalty_every = 4,
|
||||||
amp = False
|
amp = False
|
||||||
|
|||||||
Reference in New Issue
Block a user