From 83517849e5117360d094c9cff3d0ab880feb715d Mon Sep 17 00:00:00 2001 From: zion <51308183+nousr@users.noreply.github.com> Date: Fri, 3 Jun 2022 19:43:51 -0700 Subject: [PATCH] ema module fixes (#139) --- dalle2_pytorch/trainer.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py index 528d122..fa0422d 100644 --- a/dalle2_pytorch/trainer.py +++ b/dalle2_pytorch/trainer.py @@ -178,7 +178,7 @@ class EMA(nn.Module): def __init__( self, model, - beta = 0.99, + beta = 0.9999, update_after_step = 1000, update_every = 10, ): @@ -198,7 +198,8 @@ class EMA(nn.Module): self.ema_model.to(device) def copy_params_from_model_to_ema(self): - self.ema_model.state_dict(self.online_model.state_dict()) + for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())): + ma_param.data.copy_(current_param.data) def update(self): step = self.step.item() @@ -217,19 +218,17 @@ class EMA(nn.Module): self.update_moving_average(self.ema_model, self.online_model) + @torch.no_grad() def update_moving_average(self, ma_model, current_model): - def calculate_ema(beta, old, new): - if not exists(old): - return new - return old * beta + (1 - beta) * new + for current_params, ma_params in zip(list(current_model.parameters()), list(ma_model.parameters())): + difference = ma_params.data - current_params.data + difference.mul_(1.0 - self.beta) + ma_params.sub_(difference) - for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): - old_weight, up_weight = ma_params.data, current_params.data - ma_params.data.copy_(calculate_ema(self.beta, old_weight, up_weight)) - - for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()): - new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer) - ma_buffer.copy_(new_buffer_value) + for current_buffer, ma_buffer in zip(list(current_model.buffers()), list(ma_model.buffers())): + difference = ma_buffer - current_buffer + difference.mul_(1.0 - self.beta) + ma_buffer.sub_(difference) def __call__(self, *args, **kwargs): return self.ema_model(*args, **kwargs)