From 64c2f9c4eb61d8b798b1ce67cb9d0c2c91de4568 Mon Sep 17 00:00:00 2001
From: zion <51308183+nousr@users.noreply.github.com>
Date: Sat, 4 Jun 2022 13:26:34 -0700
Subject: [PATCH] implement ema warmup from @crowsonkb (#140)

---
 dalle2_pytorch/trainer.py | 39 ++++++++++++++++++++++++++++++++++++---
 1 file changed, 36 insertions(+), 3 deletions(-)

diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index fa0422d..84c0957 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -175,12 +175,34 @@ def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embe
 # exponential moving average wrapper
 
 class EMA(nn.Module):
+    """
+    Implements exponential moving average shadowing for your model.
+
+    Utilizes an inverse decay schedule to manage longer-term training runs.
+    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
+
+    @crowsonkb's notes on EMA Warmup:
+
+    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
+    good values for models you plan to train for a million or more steps (reaches decay
+    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
+    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+    215.4K steps).
+
+    Args:
+        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+        power (float): Exponential factor of EMA warmup. Default: 1.
+        min_value (float): The minimum EMA decay rate. Default: 0.
+    """
     def __init__(
         self,
         model,
         beta = 0.9999,
-        update_after_step = 1000,
+        update_after_step = 10000,
         update_every = 10,
+        inv_gamma = 1.0,
+        power = 2/3,
+        min_value = 0.0,
     ):
         super().__init__()
         self.beta = beta
@@ -190,6 +212,10 @@ class EMA(nn.Module):
         self.update_every = update_every
         self.update_after_step = update_after_step
 
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+
         self.register_buffer('initted', torch.Tensor([False]))
         self.register_buffer('step', torch.tensor([0]))
 
@@ -201,6 +227,11 @@ class EMA(nn.Module):
         for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
             ma_param.data.copy_(current_param.data)
 
+    def get_current_decay(self):
+        epoch = max(0, self.step.item() - self.update_after_step - 1)
+        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
+        return 0. if epoch < 0 else min(self.beta, max(self.min_value, value))
+
     def update(self):
         step = self.step.item()
         self.step += 1
@@ -220,14 +251,16 @@ class EMA(nn.Module):
 
     @torch.no_grad()
     def update_moving_average(self, ma_model, current_model):
+        current_decay = self.get_current_decay()
+
         for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())):
             difference = ma_params.data - current_params.data
-            difference.mul_(1.0 - self.beta)
+            difference.mul_(1.0 - current_decay)
             ma_params.sub_(difference)
 
         for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())):
             difference = ma_buffer - current_buffer
-            difference.mul_(1.0 - self.beta)
+            difference.mul_(1.0 - current_decay)
             ma_buffer.sub_(difference)
 
     def __call__(self, *args, **kwargs):
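
Sanity check on the schedule introduced above: a minimal standalone sketch of
the get_current_decay() math using the defaults from this patch (beta = 0.9999,
update_after_step = 10000, inv_gamma = 1.0, power = 2/3). The ema_decay name
and the probe step values are illustrative, not part of the patch.

    # standalone re-derivation of get_current_decay() for inspection;
    # ema_decay is a hypothetical name, the math mirrors the patch
    def ema_decay(step, beta = 0.9999, update_after_step = 10000,
                  inv_gamma = 1.0, power = 2/3, min_value = 0.0):
        epoch = max(0, step - update_after_step - 1)
        value = 1 - (1 + epoch / inv_gamma) ** -power
        return min(beta, max(min_value, value))

    for step in (10_000, 11_000, 41_624, 1_010_001):
        print(step, round(ema_decay(step), 5))

    # 10000     0.0      still inside warmup: decay held at min_value
    # 11000     0.99     ~1K steps past warmup
    # 41624     0.999    ~31.6K steps past warmup, as quoted in the docstring
    # 1010001   0.9999   ~1M steps past warmup; clamped at beta from here on

Note the design shift: update_moving_average() now re-reads get_current_decay()
on every call, so beta acts as a ceiling the schedule saturates at rather than
a fixed decay rate.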
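
Call sites are unaffected, since the new constructor arguments all default to
warmup-enabled values. A hypothetical training-loop sketch follows; the linear
model, optimizer, and random batches are placeholders, and the assumption that
__call__ forwards to the shadowed EMA copy is based on the wrapper's existing
behavior, which this diff only shows the first line of.

    import torch
    from torch import nn
    from dalle2_pytorch.trainer import EMA

    model = nn.Linear(4, 4)             # stand-in for the real online model
    optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

    ema = EMA(model, beta = 0.9999, update_after_step = 10000, update_every = 10)

    for _ in range(100):                # stand-in training loop
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.update()                    # advances the step counter; past warmup it
                                        # folds online weights into the shadow copy

    out = ema(torch.randn(8, 4))        # presumably forwards to the EMA'd model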